Source code for _bentoml_sdk.service.factory

from __future__ import annotations

import asyncio
import inspect
import logging
import math
import os
import pathlib
import sys
import typing as t
from functools import lru_cache
from functools import partial
from urllib.parse import urlsplit

import anyio.to_thread
import attrs
from simple_di import Provide
from simple_di import inject
from starlette.applications import Starlette
from typing_extensions import Unpack

from bentoml._internal.bento.bento import Bento
from bentoml._internal.bento.build_config import BentoEnvSchema
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.context import ServiceContext
from bentoml._internal.models import Model as StoredModel
from bentoml._internal.utils import deprecated
from bentoml._internal.utils import dict_filter_none
from bentoml._internal.utils.uri import join_paths
from bentoml.exceptions import BentoMLConfigException
from bentoml.exceptions import BentoMLException
from bentoml.legacy import Runner

from ..images import Image
from ..method import APIMethod
from ..models import BentoModel
from ..models import HuggingFaceModel
from ..models import Model
from .config import ServiceConfig as Config

logger = logging.getLogger("bentoml.serve")

T = t.TypeVar("T")
T_co = t.TypeVar("T_co", covariant=True)

# Everything in this branch exists only for static type checkers: the imports and
# helper types are referenced from annotations elsewhere in this module.
if t.TYPE_CHECKING:
    from bentoml._internal import external_typing as ext
    from bentoml._internal.service.openapi.specification import OpenAPISpecification
    from bentoml._internal.utils.circus import Server

    from .dependency import Dependency

    # Parameter/return placeholders used in the annotations of `with_config`.
    P = t.ParamSpec("P")
    R = t.TypeVar("R")

    # Shape of the decorator returned by `service(...)` when called with keyword
    # options: it takes the user class and returns the wrapping `Service`.
    class _ServiceDecorator(t.Protocol):
        def __call__(self, inner: type[T]) -> Service[T]: ...


class PathMetadata(t.TypedDict):
    # Value type of the "paths" mapping built by `Service.schema()`.
    # `mounted` is False for API method routes and True for routes discovered
    # on mounted ASGI applications.
    mounted: bool


class ServiceEnvConfig(t.TypedDict, total=False):
    # Dict form of one environment variable entry; `convert_envs` expands each
    # entry into a `BentoEnvSchema`. Every key is optional (`total=False`).
    name: str
    value: str
    # Restricted to the literals "all", "build" or "runtime".
    stage: t.Literal["all", "build", "runtime"]


def with_config(
    func: t.Callable[t.Concatenate["Service[t.Any]", P], R],
) -> t.Callable[t.Concatenate["Service[t.Any]", P], R]:
    """Decorate a ``Service`` method so the service config is injected first.

    Args:
        func: A callable whose first positional argument is the ``Service``.

    Returns:
        A wrapper that calls ``self.inject_config()`` and then forwards all
        arguments to ``func``, returning its result. The wrapper keeps the
        name, docstring and ``__wrapped__`` reference of ``func``.
    """
    # Local import: the module only imports `lru_cache` and `partial` from functools.
    from functools import wraps

    @wraps(func)
    def wrapper(self: Service[t.Any], *args: P.args, **kwargs: P.kwargs) -> R:
        self.inject_config()
        return func(self, *args, **kwargs)

    return wrapper


def convert_envs(envs: t.List[ServiceEnvConfig]) -> t.List[BentoEnvSchema]:
    """Build one ``BentoEnvSchema`` per env mapping, preserving input order."""
    schemas = []
    for entry in envs:
        schemas.append(BentoEnvSchema(**entry))
    return schemas


class _DummyService:
    # Placeholder used as the default value of `Service.inner`.
    # NOTE: `Service.schema()` reports `inner.__doc__` as the description, so a
    # docstring added here would show up in generated schemas.
    pass


@attrs.define
class Service(t.Generic[T_co]):
    """A Bentoml service that can be served by BentoML server."""

    name: str
    config: Config = attrs.field(factory=Config)
    inner: type[T_co] = _DummyService
    image: t.Optional[Image] = None
    description: t.Optional[str] = None
    path_prefix: str = ""
    envs: t.List[BentoEnvSchema] = attrs.field(factory=list, converter=convert_envs)
    labels: t.Dict[str, str] = attrs.field(factory=dict)
    models: list[Model[t.Any]] = attrs.field(factory=list)
    cmd: t.Optional[t.List[str]] = None
    bento: t.Optional[Bento] = attrs.field(init=False, default=None)
    apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict)
    dependencies: dict[str, Dependency[t.Any]] = attrs.field(factory=dict, init=False)
    mount_apps: list[tuple[ext.ASGIApp, str, str]] = attrs.field(
        factory=list, init=False
    )
    middlewares: list[tuple[type[ext.AsgiMiddleware], dict[str, t.Any]]] = attrs.field(
        factory=list, init=False
    )
    # service context
    context: ServiceContext = attrs.field(init=False, factory=ServiceContext)
    working_dir: str = attrs.field(init=False, factory=os.getcwd)
    # import info
    _caller_module: str = attrs.field(init=False)
    _import_str: str | None = attrs.field(init=False, default=None)

    def __attrs_post_init__(self) -> None:
        """Populate derived state from ``inner`` and fill in config defaults.

        Scans every attribute of ``inner`` and registers dependencies, models and
        API methods. Then it:

        * enables the external queue (default concurrency 1) if any API is a task,
        * mounts apps recorded in ``inner.__bentoml_mounted_apps__`` and removes
          that attribute,
        * picks a default worker count for services with a custom command,
        * defaults the ``livez``/``readyz`` endpoints under ``path_prefix``.
        """
        from .dependency import Dependency

        task_found = False
        for attr_name in dir(self.inner):
            member = getattr(self.inner, attr_name)
            if isinstance(member, Dependency):
                self.dependencies[attr_name] = member
                continue
            # StoredModel must be checked before Model, as in the original chain.
            if isinstance(member, StoredModel):
                logger.warning(
                    "`bentoml.models.get()` as the class attribute is not recommended because it requires the model"
                    f" to exist at import time. Use `{member._attr} = BentoModel({str(member.tag)!r})` instead."
                )
                self.models.append(BentoModel(member.tag))
                continue
            if isinstance(member, Model):
                self.models.append(t.cast(Model[t.Any], member))
                continue
            if isinstance(member, APIMethod):
                if member.is_task:
                    task_found = True
                self.apis[attr_name] = t.cast("APIMethod[..., t.Any]", member)

        if task_found:
            traffic_cfg = self.config.setdefault("traffic", {})
            traffic_cfg["external_queue"] = True
            traffic_cfg.setdefault("concurrency", 1)

        premounted = getattr(self.inner, "__bentoml_mounted_apps__", [])
        if premounted:
            for asgi_app, mount_path, mount_name in premounted:
                self.mount_asgi_app(asgi_app, path=mount_path, name=mount_name)
            delattr(self.inner, "__bentoml_mounted_apps__")

        if self.config.get("workers") is None and self.has_custom_command():
            from bentoml._internal.resource import system_resources

            # Half of the available CPUs, at least 1 and at most 16.
            cpu_count = system_resources()["cpu"]
            self.config["workers"] = min(16, int(cpu_count / 2) or 1)

        if self.path_prefix:
            livez_path = join_paths(self.path_prefix, "livez")
            readyz_path = join_paths(self.path_prefix, "readyz")
            endpoints = self.config.setdefault("endpoints", {})
            endpoints.setdefault("livez", livez_path)
            endpoints.setdefault("readyz", readyz_path)

    def __hash__(self) -> int:
        """Hash the service by its name only."""
        # Keeps instances hashable, which the `lru_cache` on
        # `find_dependent_by_path` relies on for its `self` argument.
        return hash(self.name)

    def has_custom_command(self) -> bool:
        """Tell whether a custom start command is set on ``inner`` or via ``cmd``."""
        if hasattr(self.inner, "__command__"):
            return True
        return self.cmd is not None

    @_caller_module.default  # type: ignore
    def _get_caller_module(self) -> str:
        """Return the name of the first module on the call stack that is not this one.

        Falls back to this module's own name when it runs as ``__main__`` or when
        no frame from another module is found.
        """
        own_name = __name__
        if own_name == "__main__":
            return own_name
        frame = inspect.currentframe()
        while frame:
            module_name = frame.f_globals["__name__"]
            if module_name != own_name:
                return module_name
            frame = frame.f_back
        return own_name

    def __repr__(self) -> str:
        """Return ``<ClassName name='...'>`` for debugging output."""
        cls_name = self.__class__.__name__
        return f"<{cls_name} name={self.name!r}>"

    @lru_cache
    def find_dependent_by_path(self, path: str) -> Service[t.Any]:
        """Find a service by a dot-separated path of dependency attribute names.

        The first path segment is looked up in ``self.dependencies``. If it is not
        a dependency attribute, it is tried as a service name among
        ``self.all_services()``; in that case any remaining segments are ignored.
        Remaining segments are otherwise resolved recursively on the dependency's
        service.

        Args:
            path: Dot-separated path such as ``"dep_a.dep_b"``.

        Returns:
            The resolved ``Service``.

        Raises:
            BentoMLException: If a segment cannot be resolved, or the dependency
                has no local service (``on is None``).
        """
        attr_name, _, remainder = path.partition(".")
        if attr_name not in self.dependencies:
            # Not a dependency attribute: fall back to a lookup by service name.
            services = self.all_services()
            if attr_name in services:
                return services[attr_name]
            raise BentoMLException(f"Service {attr_name} not found")
        dependency = self.dependencies[attr_name]
        if dependency.on is None:
            raise BentoMLException(f"Service {attr_name} not found")
        if remainder:
            return dependency.on.find_dependent_by_path(remainder)
        # Return the service itself, not the Dependency wrapper, so the result
        # matches the declared return type and the recursive branch above.
        return dependency.on

    def find_dependent_by_name(self, name: str) -> Service[t.Any]:
        """Look up a service by name among this service and its dependencies.

        Raises:
            BentoMLException: If no service with that name is known.
        """
        try:
            found = self.all_services()[name]
        except KeyError:
            raise BentoMLException(f"Service {name} not found") from None
        return found

    @property
    def url(self) -> str | None:
        """URL registered for this service, or None if the service is not served.

        Every ``tcp://`` occurrence in the registered address is replaced with
        ``http://``.
        """
        mapping = BentoMLContainer.remote_runner_mapping.get()
        address = mapping.get(self.name)
        if not address:
            return None
        return address.replace("tcp://", "http://")

    async def get_hosts(self) -> list[str]:
        """Return the ``host:port`` entries of the service.

        For ``file://`` addresses and local HTTP serving, one entry is returned per
        worker reported by ``ResourceAllocator``. When the environment variable
        ``BENTOCLOUD_DEPLOYMENT_URL`` is set, the address is queried over HTTP and
        the entries are built from the ``ips`` and ``port`` fields of the JSON reply.

        Raises:
            BentoMLException: If the service has no registered address, is not
                served over HTTP, or the cloud lookup returns an error response.
        """
        import httpx

        from _bentoml_impl.server.allocator import ResourceAllocator

        address = BentoMLContainer.remote_runner_mapping.get().get(self.name)
        if not address:
            raise BentoMLException(f"Service {self.name} not found")
        address = address.replace("tcp://", "http://")
        worker_count, _ = ResourceAllocator().get_worker_env(self)
        if address.startswith("file://"):
            # UDS connections don't have a port number, use a fake one.
            return ["127.0.0.1:3000"] * worker_count
        if not address.startswith("http://"):
            raise BentoMLException(
                f"Unable to get hosts for service {self.name} because it is not served as HTTP"
            )
        if "BENTOCLOUD_DEPLOYMENT_URL" not in os.environ:
            # Serving locally, the hostname should be the IP
            return [urlsplit(address).netloc] * worker_count

        # BentoCloud environment, the url is to runner-lb
        lookup_headers = {"Runner-Name": self.name, "Resolve-Runner": "1"}
        async with httpx.AsyncClient() as client:
            response = await client.get(address, headers=lookup_headers)
            if response.is_error:
                raise BentoMLException(
                    f"Failed to get hosts for service {self.name} from cloud"
                )
            await response.aread()
            payload = response.json()
            return [f"{ip}:{payload['port']}" for ip in payload.get("ips", [])]

    def all_services(self, exclude_urls: bool = False) -> dict[str, Service[t.Any]]:
        """Map service names to services for this service and all its dependencies.

        Args:
            exclude_urls: When true, skip dependencies that have a ``url`` or a
                ``deployment`` set.

        Raises:
            BentoMLConfigException: If two different services share a name.
        """
        collected: dict[str, Service[t.Any]] = {self.name: self}
        for dep in self.dependencies.values():
            if dep.on is None:
                continue
            if exclude_urls and (dep.url or dep.deployment):
                continue
            nested = dep.on.all_services(exclude_urls=exclude_urls)
            # First name already collected but bound to a different service object.
            conflict = None
            for svc_name, svc in nested.items():
                if svc_name in collected and svc is not collected[svc_name]:
                    conflict = svc_name
                    break
            if conflict:
                raise BentoMLConfigException(
                    f"Dependency conflict: {conflict} is already defined by {collected[conflict].inner}"
                )
            collected.update(nested)
        return collected

    @property
    def doc(self) -> str:
        """Documentation text: the loaded bento's doc, else the default readme."""
        from bentoml._internal.bento.bento import get_default_svc_readme

        loaded_bento = self.bento
        if loaded_bento is None:
            return get_default_svc_readme(self)
        return loaded_bento.doc

    def schema(self) -> dict[str, t.Any]:
        """Describe the service as a plain dict.

        The result holds the service name, the constant type ``"service"``, the
        schema of every API method (routes prefixed with ``path_prefix`` when set),
        the ``inner`` class docstring as description, and a ``paths`` mapping of
        every known path to its ``PathMetadata``. The dict is passed through
        ``dict_filter_none`` before being returned.
        """
        # Add API method routes (these are already full paths from method.route)
        all_paths: dict[str, PathMetadata] = {}
        for api in self.apis.values():
            all_paths[join_paths(self.path_prefix, api.route)] = {"mounted": False}

        # Store id(app) to avoid reprocessing
        seen_app_ids: set[int] = set()

        def collect_routes(app: t.Any, prefix: str) -> None:
            if app is None or id(app) in seen_app_ids:
                return
            seen_app_ids.add(id(app))

            # Introspect ASGI app (FastAPI/Starlette)
            if not issubclass(app.__class__, Starlette):
                return
            for route in app.routes:
                route_path = getattr(route, "path", None)
                if not isinstance(route_path, str):
                    continue
                # `prefix` is the path *to* this app; `route_path` is relative to it.
                full_path = join_paths(prefix, route_path)
                all_paths[full_path] = {"mounted": True}

                sub_app = getattr(route, "app", None)
                if sub_app is not None:
                    collect_routes(sub_app, full_path)

        for mounted_app, mount_prefix, _ in self.mount_apps:
            collect_routes(mounted_app, join_paths(mount_prefix))

        routes = []
        for api in self.apis.values():
            routes.append(api.schema())
        if self.path_prefix:
            for entry in routes:
                entry["route"] = join_paths(self.path_prefix, entry["route"])

        description = getattr(self.inner, "__doc__", None)
        return dict_filter_none(
            {
                "name": self.name,
                "type": "service",
                "routes": routes,
                "description": description,
                "paths": all_paths,
            }
        )

    @property
    def import_string(self) -> str:
        """Return ``"<module>:<variable>"`` locating this service, computed once.

        When the service was created in ``__main__``, the module part is derived
        from the main script's path relative to ``working_dir``. The variable part
        is the module-level name bound to this exact object.

        Raises:
            BentoMLException: If the origin cannot be determined (interactive
                session, script outside ``working_dir``, module not imported, or
                the service is not bound to a module-level variable).
        """
        if self._import_str is not None:
            return self._import_str

        module_ref = self._caller_module
        if module_ref == "__main__":
            main_module = sys.modules["__main__"]
            if not hasattr(main_module, "__file__"):
                raise BentoMLException(
                    "Failed to get service import origin, service object defined interactively in console or notebook is not supported"
                )
            module_ref = main_module.__file__
            assert isinstance(module_ref, str)
            try:
                relative_path = pathlib.Path(module_ref).relative_to(self.working_dir)
            except ValueError:
                raise BentoMLException(
                    "Failed to get service import origin, service object defined in __main__ module is not supported"
                )
            # Turn "pkg/mod.py" into "pkg.mod".
            module_ref = str(relative_path.with_suffix("")).replace(os.path.sep, ".")

        if self._caller_module not in sys.modules:
            raise BentoMLException(
                "Failed to get service import origin, service object must be defined in a module"
            )

        module_vars = vars(sys.modules[self._caller_module])
        for var_name, candidate in module_vars.items():
            if candidate is self:
                self._import_str = f"{module_ref}:{var_name}"
                break
        else:
            raise BentoMLException(
                "Failed to get service import origin, service object must be assigned to a variable at module level"
            )
        return self._import_str

    def to_asgi(self, is_main: bool = True) -> ext.ASGIApp:
        """Inject the service config and build the ASGI app via ``ServiceAppFactory``."""
        from _bentoml_impl.server.app import ServiceAppFactory

        self.inject_config()
        build_app = ServiceAppFactory(self, is_main=is_main)
        return build_app()

    def mount_asgi_app(
        self, app: ext.ASGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Record an ASGI app to mount at ``path`` joined onto ``path_prefix``."""
        prefixed_path = join_paths(self.path_prefix, path)
        entry = (app, prefixed_path, name)
        self.mount_apps.append(entry)  # type: ignore

    def mount_wsgi_app(
        self, app: ext.WSGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Record a WSGI app to mount, wrapped as ASGI via ``a2wsgi.WSGIMiddleware``.

        Args:
            app: The WSGI application.
            path: Mount path. When the service has a ``path_prefix``, the path is
                joined onto it, matching ``mount_asgi_app``.
            name: Optional name stored alongside the mount entry.
        """
        from a2wsgi import WSGIMiddleware

        # Both mount methods feed the same `mount_apps` list, so the service prefix
        # must apply here too. Without a prefix the path is stored unchanged.
        mount_path = join_paths(self.path_prefix, path) if self.path_prefix else path
        self.mount_apps.append((WSGIMiddleware(app), mount_path, name))  # type: ignore

    def add_asgi_middleware(
        self, middleware_cls: type[ext.AsgiMiddleware], **options: t.Any
    ) -> None:
        """Record an ASGI middleware class together with its keyword options.

        The pair is appended to ``self.middlewares``; nothing is instantiated here.
        """
        self.middlewares.append((middleware_cls, options))

    def gradio_app_startup_hook(self, max_concurrency: int):
        """Run startup for gradio apps recorded in ``inner.__bentoml_gradio_apps__``.

        Each app's blocks are queued with ``max_concurrency`` as the default
        concurrency limit and their startup events are run. The attribute is
        removed from ``inner`` afterwards. Does nothing when no apps are recorded.
        """
        pending = getattr(self.inner, "__bentoml_gradio_apps__", [])
        if not pending:
            return
        for gradio_app, path, _ in pending:
            logger.info(f"Initializing gradio app at: {path or '/'}")
            blocks = gradio_app.get_blocks()
            blocks.queue(default_concurrency_limit=max_concurrency)
            if not hasattr(blocks, "startup_events"):
                # gradio >= 5.0
                blocks.run_startup_events()
            else:
                # gradio < 5.0
                blocks.startup_events()
        delattr(self.inner, "__bentoml_gradio_apps__")

    def __call__(self) -> T_co:
        """Instantiate ``inner`` and attach the service helpers to the instance.

        The instance receives ``to_async`` and ``to_sync`` wrappers over the API
        method names and a ``bento_service`` back-reference to this service. Any
        exception is logged and re-raised.
        """
        try:
            instance = self.inner()
            wrapper_types = (("to_async", _AsyncWrapper), ("to_sync", _SyncWrapper))
            for attr_name, wrapper_cls in wrapper_types:
                setattr(instance, attr_name, wrapper_cls(instance, self.apis.keys()))
            instance.bento_service = self
        except Exception:
            logger.exception("Initializing service error")
            raise
        return instance

    @property
    def openapi_spec(self) -> OpenAPISpecification:
        """OpenAPI specification for this service, built by ``.openapi.generate_spec``."""
        # Imported lazily inside the property rather than at module import time.
        from .openapi import generate_spec

        return generate_spec(self)

    def inject_config(self) -> None:
        """Load the BentoML config with per-service overrides and merge this service in.

        The config of every service from ``all_services()`` is passed to
        ``load_config`` as override defaults. This service's resolved config is
        then split: API-server keys go under ``"api_server"``, all other keys stay
        at top level, and the result is deep-merged into the container config.
        """
        from bentoml._internal.configuration import load_config
        from bentoml._internal.configuration.containers import BentoMLContainer
        from bentoml._internal.utils import deep_merge

        # XXX: ensure at least one item to make `flatten_dict` work
        per_service = {}
        for svc_name, svc in self.all_services().items():
            per_service[svc_name] = svc.config or {"workers": 1}

        load_config(override_defaults={"services": per_service}, use_version=2)
        main_config = BentoMLContainer.config.services[self.name].get()
        api_server_keys = (
            "traffic",
            "metrics",
            "logging",
            "ssl",
            "http",
            "grpc",
            "backlog",
            "runner_probe",
            "max_runner_connections",
        )

        # API-server settings that are actually set (not None).
        api_server_config = {}
        for key in api_server_keys:
            if main_config.get(key) is not None:
                api_server_config[key] = main_config[key]

        # Everything else is merged at the top level.
        rest_config = {}
        for key in main_config:
            if key not in api_server_keys:
                rest_config[key] = main_config[key]

        overrides = {"api_server": api_server_config}
        overrides.update(rest_config)
        existing = t.cast(t.Dict[str, t.Any], BentoMLContainer.config.get())
        deep_merge(existing, overrides)
        BentoMLContainer.config.set(existing)  # type: ignore

    @with_config
    @inject
    def serve_http(
        self,
        *,
        working_dir: str | None = None,
        port: int = Provide[BentoMLContainer.http.port],
        host: str = Provide[BentoMLContainer.http.host],
        backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
        timeout: int | None = None,
        ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
        ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
        ssl_keyfile_password: str | None = Provide[
            BentoMLContainer.ssl.keyfile_password
        ],
        ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
        ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
        ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
        ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
        bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
        development_mode: bool = False,
        reload: bool = False,
        threaded: bool = False,
    ) -> Server:
        """Serve this service over HTTP.

        ``with_config`` calls ``self.inject_config()`` before this body runs.
        Parameters declared with ``Provide[...]`` take their defaults from
        ``BentoMLContainer``. All arguments are keyword-only and are forwarded
        unchanged to ``_bentoml_impl.server.serve_http``.

        Returns:
            The ``Server`` returned by ``_bentoml_impl.server.serve_http``.
        """
        from _bentoml_impl.server import serve_http
        from bentoml._internal.log import configure_logging

        # Logging is configured before the server is started.
        configure_logging()

        return serve_http(
            self,
            working_dir=working_dir,
            host=host,
            port=port,
            backlog=backlog,
            timeout=timeout,
            ssl_certfile=ssl_certfile,
            ssl_keyfile=ssl_keyfile,
            ssl_keyfile_password=ssl_keyfile_password,
            ssl_version=ssl_version,
            ssl_cert_reqs=ssl_cert_reqs,
            ssl_ca_certs=ssl_ca_certs,
            ssl_ciphers=ssl_ciphers,
            bentoml_home=bentoml_home,
            development_mode=development_mode,
            reload=reload,
            threaded=threaded,
        )

    def on_load_bento(self, bento: Bento) -> None:
        """Attach a loaded bento to this service and all of its dependencies.

        For every service, each ``HuggingFaceModel`` is repointed at the model id
        and revision recorded in the bento, pairing models with the bento's model
        entries by position. The service's ``bento`` attribute is then set.

        Args:
            bento: The loaded bento whose ``info.services`` describes the services.

        Raises:
            BentoMLException: If the bento has no entry for one of the services.
        """
        for svc in self.all_services().values():
            service_info = next(
                (entry for entry in bento.info.services if entry.name == svc.name),
                None,
            )
            if service_info is None:
                # A bare `next()` would surface as an opaque StopIteration here.
                raise BentoMLException(
                    f"Service {svc.name} not found in the loaded bento"
                )
            for model, info in zip(svc.models, service_info.models):
                if not isinstance(model, HuggingFaceModel):
                    continue
                # Replace the model version with the one in the Bento
                model_id = info.metadata.get("model_id")  # use the case in bento info
                if not model_id:
                    model_id = info.tag.name.replace("--", "/")
                revision = info.metadata.get("revision", info.tag.version)
                model.model_id = model_id
                model.revision = revision
            svc.bento = bento

    def needs_task_db(self) -> bool:
        """Tell whether any API is a task, outside of a BentoCloud deployment.

        Always False when ``BENTOCLOUD_DEPLOYMENT_URL`` is set in the environment.
        """
        if "BENTOCLOUD_DEPLOYMENT_URL" in os.environ:
            return False
        for method in self.apis.values():
            if method.is_task:
                return True
        return False


@t.overload
def service(inner: type[T], /) -> Service[T]: ...


@t.overload
def service(
    *,
    name: str | None = None,
    image: Image | None = None,
    description: str | None = None,
    path_prefix: str | None = None,
    envs: list[ServiceEnvConfig] | None = None,
    labels: dict[str, str] | None = None,
    cmd: list[str] | None = None,
    service_class: type[Service[T]] = Service,
    **kwargs: Unpack[Config],
) -> _ServiceDecorator: ...


def service(
    inner: type[T] | None = None,
    /,
    *,
    name: str | None = None,
    image: Image | None = None,
    description: str | None = None,
    path_prefix: str | None = None,
    envs: list[ServiceEnvConfig] | None = None,
    labels: dict[str, str] | None = None,
    cmd: list[str] | None = None,
    service_class: type[Service[T]] = Service,
    **kwargs: Unpack[Config],
) -> t.Any:
    """Mark a class as a BentoML service.

    Can be used bare (``@service``) or with keyword options (``@service(...)``).

    Args:
        name: The name of the service. Defaults to the class name.
        image: The image to use for the service.
        description: A description of the service.
        path_prefix: A URL path prefix to apply to all API endpoints of this service.
            For example, setting ``path_prefix="/v1"`` will make an endpoint ``/predict``
            available at ``/v1/predict``. This also applies to mounted ASGI applications
            and health check endpoints.
        envs: Environment variables to set for the service.
        labels: Labels to attach to the service.
        cmd: A custom command to start the service.
        service_class: The class used to build the service object. Defaults to
            ``Service``.
        **kwargs: Additional service configurations such as ``traffic``, ``resources``,
            ``workers``, etc.

    Returns:
        The ``Service`` wrapping ``inner`` when used bare, otherwise a decorator
        that produces it.

    Raises:
        TypeError: If the decorated object is already a ``Service``.

    Example:

        @service(traffic={"timeout": 60})
        class InferenceService:
            @api
            def predict(self, input: str) -> str:
                return input
    """
    config = kwargs

    def decorator(inner: type[T]) -> Service[T]:
        if isinstance(inner, Service):
            raise TypeError("service() decorator can only be applied once")
        return service_class(
            name=name or inner.__name__,
            config=config,
            inner=inner,
            image=image,
            description=description,
            # `Service.path_prefix` is declared as `str`, so an omitted prefix is
            # passed as "" rather than None.
            path_prefix=path_prefix or "",
            envs=envs or [],
            labels=labels or {},
            cmd=cmd,
        )

    return decorator(inner) if inner is not None else decorator
@deprecated()
def runner_service(runner: Runner, **kwargs: Unpack[Config]) -> Service[t.Any]:
    """Make a service from a legacy Runner.

    Each runner method becomes an ``APIMethod`` carrying the runnable method's
    batching settings. A default ``workers`` count is derived from the runner's
    resource config (GPU entry first, then ``cpus``) and may be overridden by
    ``kwargs``.

    Args:
        runner: The legacy ``Runner`` to wrap.
        **kwargs: Service configuration overrides, applied last.

    Returns:
        A ``Service`` named after the runner.

    Raises:
        ValueError: If ``runner`` is not a ``Runner`` instance.
    """
    if not isinstance(runner, Runner):  # type: ignore
        raise ValueError(f"Expect an instance of Runner, but got {type(runner)}")

    # No docstring on purpose: `Service.schema()` reports `inner.__doc__`.
    class RunnerHandle(runner.runnable_class):
        def __init__(self) -> None:
            super().__init__(**runner.runnable_init_params)

    apis: dict[str, APIMethod[..., t.Any]] = {}
    assert runner.runnable_class.bentoml_runnable_methods__ is not None
    for method in runner.runner_methods:
        runnable_method = runner.runnable_class.bentoml_runnable_methods__[method.name]
        api = APIMethod(  # type: ignore
            func=runnable_method.func,
            batchable=runnable_method.config.batchable,
            batch_dim=runnable_method.config.batch_dim,
            max_batch_size=method.max_batch_size,
            max_latency_ms=method.max_latency_ms,
        )
        apis[method.name] = api

    config: Config = {}
    resource_config = runner.resource_config or {}
    if (
        "nvidia.com/gpu" in runner.runnable_class.SUPPORTED_RESOURCES
        and "nvidia.com/gpu" in resource_config
    ):
        gpus: list[int] | str | int = resource_config["nvidia.com/gpu"]
        if isinstance(gpus, str):
            gpus = int(gpus)
        elif isinstance(gpus, list):
            gpus = len(gpus)
        config["workers"] = int(gpus * runner.workers_per_resource)
    elif "cpus" in resource_config:
        # Coerce to int like the GPU branch: a float `workers_per_resource`
        # would otherwise yield a float worker count.
        config["workers"] = int(
            math.ceil(resource_config["cpus"]) * runner.workers_per_resource
        )
    config.update(kwargs)
    return Service(
        name=runner.name,
        config=config,
        inner=RunnerHandle,
        models=[BentoModel(m.tag) for m in runner.models],
        apis=apis,
    )
def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    """Return the calling thread's event loop, creating one when none is set.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` when the thread has no
    current loop (for example a worker thread that never set one). In that case a
    new loop is created and installed as the thread's current loop.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


class _Wrapper:
    """Expose the API methods of a service instance through converted callables.

    For every name in ``apis`` an attribute of the same name is set on the wrapper,
    built by ``_make_method``. The name ``"__call__"`` is stored separately and
    invoked when the wrapper itself is called.
    """

    def __init__(self, wrapped: t.Any, apis: t.Iterable[str]) -> None:
        self.__call = None
        for name in apis:
            if name == "__call__":
                self.__call = self._make_method(wrapped, name)
            else:
                setattr(self, name, self._make_method(wrapped, name))

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        if self.__call is None:
            raise TypeError("This service is not callable.")
        return self.__call(*args, **kwargs)

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        """Build the callable exposed for API ``name``; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _resolve(instance: t.Any, name: str) -> tuple[t.Any, t.Any, bool]:
        """Return ``(func, original_func, is_async)`` for API ``name`` of ``instance``.

        ``func`` is the ``.local`` attribute of the API. ``original_func`` is
        reached by following ``.func`` attributes (e.g. through partials).
        ``is_async`` is true for coroutine functions, objects whose ``__call__`` is
        a coroutine function, and async generator functions.
        """
        func = getattr(instance, name).local
        original_func = func
        while hasattr(original_func, "func"):
            original_func = original_func.func
        is_async_func = (
            asyncio.iscoroutinefunction(original_func)
            or (
                callable(original_func)
                and asyncio.iscoroutinefunction(original_func.__call__)  # type: ignore
            )
            or inspect.isasyncgenfunction(original_func)
        )
        return func, original_func, is_async_func


class _AsyncWrapper(_Wrapper):
    """Wrapper whose methods are all awaitable or async-iterable."""

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        func, original_func, is_async_func = self._resolve(instance, name)
        if is_async_func:
            # Already async: expose unchanged.
            return func

        if inspect.isgeneratorfunction(original_func):

            async def wrapped_gen(
                *args: t.Any, **kwargs: t.Any
            ) -> t.AsyncGenerator[t.Any, None]:
                gen = func(*args, **kwargs)
                next_fun = gen.__next__
                while True:
                    try:
                        # Each item is pulled in a worker thread.
                        yield await anyio.to_thread.run_sync(next_fun)
                    except StopIteration:
                        break
                    except RuntimeError as e:
                        # Exhaustion may surface as a RuntimeError mentioning
                        # "raised StopIteration"; treat that as the end as well.
                        if "raised StopIteration" in str(e):
                            break
                        raise

            return wrapped_gen
        else:

            async def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                # Keyword arguments are bound with `partial`; positional ones are
                # passed through `run_sync`.
                return await anyio.to_thread.run_sync(partial(func, **kwargs), *args)

            return wrapped


class _SyncWrapper(_Wrapper):
    """Wrapper whose methods are all plain synchronous callables."""

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        func, original_func, is_async_func = self._resolve(instance, name)
        if not is_async_func:
            # Already sync: expose unchanged.
            return func

        if inspect.isasyncgenfunction(original_func):

            def wrapped_gen(
                *args: t.Any, **kwargs: t.Any
            ) -> t.Generator[t.Any, None, None]:
                agen = func(*args, **kwargs)
                loop = _get_or_create_event_loop()
                while True:
                    try:
                        yield loop.run_until_complete(agen.__anext__())
                    except StopAsyncIteration:
                        break

            return wrapped_gen
        else:

            def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                loop = _get_or_create_event_loop()
                return loop.run_until_complete(func(*args, **kwargs))

            return wrapped