from __future__ import annotations
import inspect
import logging
import math
import os
import pathlib
import sys
import typing as t
from functools import lru_cache
from functools import partial
import attrs
from simple_di import Provide
from simple_di import inject
from typing_extensions import Unpack
from bentoml import Runner
from bentoml._internal.bento.bento import Bento
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.context import ServiceContext
from bentoml._internal.models import Model
from bentoml._internal.utils import dict_filter_none
from bentoml.exceptions import BentoMLException
from ..method import APIMethod
from .config import ServiceConfig as Config
logger = logging.getLogger("bentoml.io")
T = t.TypeVar("T", bound=object)
if t.TYPE_CHECKING:
from _bentoml_impl.server.serving import Server
from bentoml._internal import external_typing as ext
from bentoml._internal.service.openapi.specification import OpenAPISpecification
from .dependency import Dependency
P = t.ParamSpec("P")
R = t.TypeVar("R")
class _ServiceDecorator(t.Protocol):
    """Structural type of the decorator returned by ``service(**config)``.

    It is a callable that takes the user-defined class and returns the
    ``Service`` wrapping it.
    """

    def __call__(self, inner: type[T]) -> Service[T]:
        """Wrap ``inner`` and return the resulting ``Service``."""
        ...
def with_config(
    func: t.Callable[t.Concatenate["Service[t.Any]", P], R],
) -> t.Callable[t.Concatenate["Service[t.Any]", P], R]:
    """Decorate a ``Service`` method so configuration is injected before it runs.

    The returned wrapper calls ``self.inject_config()`` and then delegates to
    ``func`` with the same arguments, returning whatever ``func`` returns.

    Args:
        func: A method whose first positional argument is the service instance.

    Returns:
        A wrapper with the same call signature as ``func``. The wrapper keeps
        ``func``'s metadata (``__name__``, ``__doc__``, ``__wrapped__`` ...) so
        introspection and generated documentation still see the real method.
    """
    # Local import: the module only imports ``lru_cache`` and ``partial``
    # from functools at file level.
    from functools import wraps

    @wraps(func)
    def wrapper(self: Service[t.Any], *args: P.args, **kwargs: P.kwargs) -> R:
        # Configuration must be loaded before the wrapped method reads it.
        self.inject_config()
        return func(self, *args, **kwargs)

    return wrapper
@attrs.define
class Service(t.Generic[T]):
    """Pairs a user-defined class (``inner``) with its service configuration.

    On construction the attributes of ``inner`` are scanned and every
    ``Dependency``, ``Model`` and ``APIMethod`` found is registered on the
    service. Calling the service instantiates ``inner``.
    """

    # Configuration mapping for this service and the wrapped user class.
    config: Config
    inner: type[T]

    # Set externally when the service is loaded from a Bento; None otherwise.
    bento: t.Optional[Bento] = attrs.field(init=False, default=None)
    # Registries; ``models`` and ``apis`` may be pre-seeded by the caller.
    models: list[Model] = attrs.field(factory=list)
    apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict)
    dependencies: dict[str, Dependency[t.Any]] = attrs.field(factory=dict, init=False)
    # Entries are (app, path, name) tuples.
    mount_apps: list[tuple[ext.ASGIApp, str, str]] = attrs.field(
        factory=list, init=False
    )
    # Entries are (middleware class, keyword options) tuples.
    middlewares: list[tuple[type[ext.AsgiMiddleware], dict[str, t.Any]]] = attrs.field(
        factory=list, init=False
    )
    # service context
    context: ServiceContext = attrs.field(init=False, factory=ServiceContext)
    working_dir: str = attrs.field(init=False, factory=os.getcwd)
    # import info
    _caller_module: str = attrs.field(init=False)
    _import_str: str | None = attrs.field(init=False, default=None)

    def __attrs_post_init__(self) -> None:
        """Register dependencies, models, APIs and pre-mounted apps found on ``inner``."""
        from .dependency import Dependency

        inner_cls = self.inner
        for attr_name in dir(inner_cls):
            member = getattr(inner_cls, attr_name)
            if isinstance(member, Dependency):
                self.dependencies[attr_name] = t.cast(Dependency[t.Any], member)
                continue
            if isinstance(member, Model):
                self.models.append(member)
                continue
            if isinstance(member, APIMethod):
                self.apis[attr_name] = t.cast("APIMethod[..., t.Any]", member)

        # Apps recorded on the class before it was wrapped are moved onto the
        # service, and the marker attribute is removed from the class.
        pending_mounts = getattr(inner_cls, "__bentoml_mounted_apps__", [])
        if not pending_mounts:
            return
        self.mount_apps.extend(pending_mounts)
        delattr(inner_cls, "__bentoml_mounted_apps__")

    def __hash__(self):
        """Hash by service name."""
        return hash(self.name)

    @_caller_module.default  # type: ignore
    def _get_caller_module(self) -> str:
        """Return the ``__name__`` of the first stack frame outside this module."""
        here = __name__
        if here == "__main__":
            return here
        frame = inspect.currentframe()
        while frame is not None:
            module_name = frame.f_globals["__name__"]
            if module_name != here:
                return module_name
            frame = frame.f_back
        # Every frame belongs to this module.
        return here

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"<{cls_name} name={self.name!r}>"

    # NOTE: ``lru_cache`` on a method keys on ``self`` and the cache is shared
    # by all instances of the class.
    @lru_cache
    def find_dependent(self, name_or_path: str) -> Service[t.Any]:
        """Find a service by name or by dotted dependency path.

        The first path segment is looked up among this service's direct
        dependencies; any remaining segments are resolved recursively on that
        dependency. If the first segment is not a direct dependency it is
        looked up by name among all reachable services.

        Raises:
            ValueError: if no matching service is found.
        """
        head, _, remainder = name_or_path.partition(".")
        if head in self.dependencies:
            target = self.dependencies[head].on
            if remainder:
                return target.find_dependent(remainder)
            return target

        reachable = self.all_services()
        if head in reachable:
            return reachable[head]
        raise ValueError(f"Service {head} not found")

    @lru_cache(maxsize=1)
    def all_services(self) -> dict[str, Service[t.Any]]:
        """Return a name -> service map of this service and all recursive dependencies."""
        collected: dict[str, Service[t.Any]] = {self.name: self}
        for dep in self.dependencies.values():
            collected.update(dep.on.all_services())
        return collected

    @property
    def doc(self) -> str:
        """The bento's doc when a bento is attached, else the default service readme."""
        from bentoml._internal.bento.bento import get_default_svc_readme

        attached = self.bento
        if attached is None:
            return get_default_svc_readme(self)
        return attached.doc

    def schema(self) -> dict[str, t.Any]:
        """Describe the service as a dict (name, type, routes, description).

        The dict is passed through ``dict_filter_none`` before being returned.
        """
        service_name = self.name
        routes = [method.schema() for method in self.apis.values()]
        description = getattr(self.inner, "__doc__", None)
        return dict_filter_none(
            {
                "name": service_name,
                "type": "service",
                "routes": routes,
                "description": description,
            }
        )

    @property
    def name(self) -> str:
        """Configured ``name`` if truthy, otherwise the wrapped class's ``__name__``."""
        return self.config.get("name") or self.inner.__name__

    @property
    def import_string(self) -> str:
        """Return ``"<module>:<variable>"`` locating this service object.

        The result is computed once and stored in ``_import_str``.

        Raises:
            BentoMLException: if the defining module cannot be determined, or
                the service is not bound to a module-level variable.
        """
        if self._import_str is not None:
            return self._import_str

        caller = self._caller_module
        module_ref = caller
        if caller == "__main__":
            main_module = sys.modules["__main__"]
            if not hasattr(main_module, "__file__"):
                raise BentoMLException(
                    "Failed to get service import origin, service object defined interactively in console or notebook is not supported"
                )
            module_ref = main_module.__file__
            assert isinstance(module_ref, str)
            # Turn the script path, relative to the working directory, into a
            # dotted module path.
            try:
                relative_path = pathlib.Path(module_ref).relative_to(self.working_dir)
            except ValueError:
                raise BentoMLException(
                    "Failed to get service import origin, service object defined in __main__ module is not supported"
                )
            module_ref = str(relative_path.with_suffix("")).replace(os.path.sep, ".")

        if caller not in sys.modules:
            raise BentoMLException(
                "Failed to get service import origin, service object must be defined in a module"
            )

        # Find the module-level variable this exact object is bound to.
        for var_name, candidate in vars(sys.modules[caller]).items():
            if candidate is self:
                self._import_str = f"{module_ref}:{var_name}"
                return self._import_str
        raise BentoMLException(
            "Failed to get service import origin, service object must be assigned to a variable at module level"
        )

    def to_asgi(self, is_main: bool = True, init: bool = False) -> ext.ASGIApp:
        """Inject configuration and build the ASGI app via ``ServiceAppFactory``.

        When ``init`` is true, ``create_instance()`` is called on the factory
        before the app is built.
        """
        from _bentoml_impl.server.app import ServiceAppFactory

        self.inject_config()
        app_factory = ServiceAppFactory(self, is_main=is_main)
        if init:
            app_factory.create_instance()
        return app_factory()

    def mount_asgi_app(
        self, app: ext.ASGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Record an ASGI app to be mounted at ``path``."""
        entry = (app, path, name)
        self.mount_apps.append(entry)  # type: ignore

    def mount_wsgi_app(
        self, app: ext.WSGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Wrap a WSGI app in ``WSGIMiddleware`` and record it to be mounted at ``path``."""
        # TODO: Migrate to a2wsgi
        from starlette.middleware.wsgi import WSGIMiddleware

        asgi_app = WSGIMiddleware(app)
        self.mount_apps.append((asgi_app, path, name))  # type: ignore

    def add_asgi_middleware(
        self, middleware_cls: type[ext.AsgiMiddleware], **options: t.Any
    ) -> None:
        """Record an ASGI middleware class together with its keyword options."""
        self.middlewares.append((middleware_cls, options))

    def __call__(self) -> T:
        """Instantiate ``inner`` and attach an ``_AsyncWrapper`` as ``to_async``.

        Any exception is logged and re-raised.
        """
        try:
            instance = self.inner()
            instance.to_async = _AsyncWrapper(instance, self.apis.keys())
        except Exception:
            logger.exception("Initializing service error")
            raise
        return instance

    @property
    def openapi_spec(self) -> OpenAPISpecification:
        """OpenAPI specification produced by ``generate_spec`` for this service."""
        from .openapi import generate_spec

        return generate_spec(self)

    def inject_config(self) -> None:
        """Load BentoML configuration and merge in this service's section.

        Every reachable service's config is passed to ``load_config`` as
        override defaults. This service's resolved section is then merged into
        the container config, with the API-server keys nested under
        ``"api_server"`` and all other keys merged at the top level.
        """
        from bentoml._internal.configuration import load_config
        from bentoml._internal.configuration.containers import BentoMLContainer
        from bentoml._internal.configuration.containers import config_merger

        # XXX: ensure at least one item to make `flatten_dict` work
        service_overrides = {}
        for svc_name, svc in self.all_services().items():
            service_overrides[svc_name] = svc.config or {"workers": 1}
        load_config(override_defaults={"services": service_overrides}, use_version=2)

        main_config = BentoMLContainer.config.services[self.name].get()
        api_server_keys = (
            "traffic",
            "metrics",
            "logging",
            "ssl",
            "http",
            "grpc",
            "backlog",
            "runner_probe",
            "max_runner_connections",
        )
        api_server_config = {}
        for key in api_server_keys:
            if main_config.get(key) is not None:
                api_server_config[key] = main_config[key]
        rest_config = {}
        for key in main_config:
            if key not in api_server_keys:
                rest_config[key] = main_config[key]

        existing = t.cast(t.Dict[str, t.Any], BentoMLContainer.config.get())
        config_merger.merge(existing, {"api_server": api_server_config, **rest_config})
        BentoMLContainer.config.set(existing)  # type: ignore

    @with_config
    @inject
    def serve_http(
        self,
        *,
        working_dir: str | None = None,
        port: int = Provide[BentoMLContainer.http.port],
        host: str = Provide[BentoMLContainer.http.host],
        backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
        timeout: int | None = None,
        ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
        ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
        ssl_keyfile_password: str | None = Provide[
            BentoMLContainer.ssl.keyfile_password
        ],
        ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
        ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
        ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
        ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
        bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
        development_mode: bool = False,
        reload: bool = False,
        threaded: bool = False,
    ) -> Server:
        """Configure logging and start an HTTP server for this service.

        Every keyword argument is forwarded unchanged to
        ``_bentoml_impl.server.serve_http``, whose return value is returned.
        """
        from _bentoml_impl.server import serve_http as launch_http_server
        from bentoml._internal.log import configure_logging

        configure_logging()
        return launch_http_server(
            self,
            working_dir=working_dir,
            host=host,
            port=port,
            backlog=backlog,
            timeout=timeout,
            ssl_certfile=ssl_certfile,
            ssl_keyfile=ssl_keyfile,
            ssl_keyfile_password=ssl_keyfile_password,
            ssl_version=ssl_version,
            ssl_cert_reqs=ssl_cert_reqs,
            ssl_ca_certs=ssl_ca_certs,
            ssl_ciphers=ssl_ciphers,
            bentoml_home=bentoml_home,
            development_mode=development_mode,
            reload=reload,
            threaded=threaded,
        )
@t.overload
def service(inner: type[T], /) -> Service[T]: ...


@t.overload
def service(inner: None = ..., /, **kwargs: Unpack[Config]) -> _ServiceDecorator: ...


def service(inner: type[T] | None = None, /, **kwargs: Unpack[Config]) -> t.Any:
    """Mark a class as a BentoML service.

    Can be used bare (``@service``) or with configuration
    (``@service(traffic={...})``).

    Example:

        @service(traffic={"timeout": 60})
        class InferenceService:
            @api
            def predict(self, input: str) -> str:
                return input

    Args:
        inner: The class to wrap. When omitted, a decorator is returned.
        **kwargs: Service configuration, stored as the service's ``config``.

    Returns:
        A ``Service`` wrapping ``inner`` when ``inner`` is given, otherwise a
        decorator that takes the class and returns the ``Service``.

    Raises:
        TypeError: if the decorator is applied to an object that is already a
            ``Service``.
    """
    config = kwargs

    def decorator(inner: type[T]) -> Service[T]:
        # Guard against stacking @service twice on the same class.
        if isinstance(inner, Service):
            raise TypeError("service() decorator can only be applied once")
        return Service(config=config, inner=inner)

    if inner is None:
        return decorator
    return decorator(inner)
def runner_service(runner: Runner, **kwargs: Unpack[Config]) -> Service[t.Any]:
    """Make a service from a legacy Runner.

    A subclass of the runner's runnable class is created (named after the
    runner), each runner method is turned into an ``APIMethod``, and a
    ``workers`` configuration is derived from the runner's resource config.

    Args:
        runner: The legacy ``Runner`` to convert.
        **kwargs: Extra service configuration; applied last, so it overrides
            any value derived from the runner.

    Returns:
        A ``Service`` wrapping the generated runnable subclass.

    Raises:
        ValueError: if ``runner`` is not a ``Runner`` instance.
    """
    if not isinstance(runner, Runner):  # type: ignore
        raise ValueError(f"Expect an instance of Runner, but got {type(runner)}")

    class RunnerHandle(runner.runnable_class):
        def __init__(self) -> None:
            super().__init__(**runner.runnable_init_params)

    RunnerHandle.__name__ = runner.name

    apis: dict[str, APIMethod[..., t.Any]] = {}
    assert runner.runnable_class.bentoml_runnable_methods__ is not None
    for method in runner.runner_methods:
        runnable_method = runner.runnable_class.bentoml_runnable_methods__[method.name]
        apis[method.name] = APIMethod(  # type: ignore
            func=runnable_method.func,
            name=method.name,
            batchable=runnable_method.config.batchable,
            batch_dim=runnable_method.config.batch_dim,
            max_batch_size=method.max_batch_size,
            max_latency_ms=method.max_latency_ms,
        )

    config: Config = {}
    resource_config = runner.resource_config or {}
    if (
        "nvidia.com/gpu" in runner.runnable_class.SUPPORTED_RESOURCES
        and "nvidia.com/gpu" in resource_config
    ):
        gpus: list[int] | str | int = resource_config["nvidia.com/gpu"]
        if isinstance(gpus, str):
            gpus = int(gpus)
        if runner.workers_per_resource > 1:
            # Several workers per GPU: one entry per (gpu, worker) pair.
            # This must be a list -- the previous ``{}`` had no ``extend`` and
            # raised AttributeError as soon as this branch ran.
            workers_per_resource = int(runner.workers_per_resource)
            if isinstance(gpus, int):
                gpus = list(range(gpus))
            workers: list[dict[str, t.Any]] = []
            for i in gpus:
                # Build a distinct dict per worker so entries are not aliased.
                # NOTE(review): ``i`` here is a GPU index, whereas the integer
                # form in the branch below is a GPU count -- confirm the
                # consumer of this config distinguishes the two.
                workers.extend({"gpus": i} for _ in range(workers_per_resource))
            config["workers"] = workers
        else:
            # One worker spans ``resources_per_worker`` GPUs.
            resources_per_worker = int(1 / runner.workers_per_resource)
            if isinstance(gpus, int):
                config["workers"] = [
                    {"gpus": resources_per_worker}
                    for _ in range(gpus // resources_per_worker)
                ]
            else:
                config["workers"] = [
                    {"gpus": gpus[i : i + resources_per_worker]}
                    for i in range(0, len(gpus), resources_per_worker)
                ]
    elif "cpus" in resource_config:
        config["workers"] = (
            math.ceil(resource_config["cpus"]) * runner.workers_per_resource
        )

    # Explicit keyword configuration wins over values derived above.
    config.update(kwargs)
    return Service(config=config, inner=RunnerHandle, models=runner.models, apis=apis)
class _AsyncWrapper:
    """Exposes the named API methods of a wrapped object as async callables.

    Methods that are already async are exposed unchanged; synchronous
    functions and generators are run through ``anyio.to_thread.run_sync``.
    """

    def __init__(self, wrapped: t.Any, apis: t.Iterable[str]) -> None:
        # ``__call__`` cannot be set per instance via setattr and still be
        # honoured by call syntax, so it is stored separately and dispatched
        # from the class-level ``__call__``.
        self.__call = None
        for api_name in apis:
            method = self.__make_method(wrapped, api_name)
            if api_name == "__call__":
                self.__call = method
            else:
                setattr(self, api_name, method)

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Invoke the wrapped ``__call__`` API, if one was registered."""
        target = self.__call
        if target is None:
            raise TypeError("This service is not callable.")
        return target(*args, **kwargs)

    def __make_method(self, inner: t.Any, name: str) -> t.Any:
        """Return an async-compatible callable for ``inner.<name>``."""
        import asyncio

        import anyio.to_thread

        func = getattr(inner, name)

        # Follow ``.func`` attributes (as exposed by e.g. functools.partial)
        # down to the innermost callable so its kind can be inspected.
        original_func = func
        while hasattr(original_func, "func"):
            original_func = original_func.func

        # Already asynchronous: hand back the attribute untouched.
        if asyncio.iscoroutinefunction(original_func):
            return func
        if callable(original_func) and asyncio.iscoroutinefunction(
            original_func.__call__  # type: ignore
        ):
            return func
        if inspect.isasyncgenfunction(original_func):
            return func

        if not inspect.isgeneratorfunction(original_func):

            async def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                # Keyword arguments are bound up front with ``partial``;
                # positional ones are forwarded by ``run_sync``.
                bound_call = partial(func, **kwargs)
                return await anyio.to_thread.run_sync(bound_call, *args)

            return wrapped

        async def wrapped_gen(
            *args: t.Any, **kwargs: t.Any
        ) -> t.AsyncGenerator[t.Any, None]:
            iterator = func(*args, **kwargs)
            advance = iterator.__next__
            while True:
                try:
                    # Each step of the sync generator runs in a worker thread.
                    yield await anyio.to_thread.run_sync(advance)
                except StopIteration:
                    break
                except RuntimeError as exc:
                    # Exhaustion may also surface as a RuntimeError whose
                    # message mentions StopIteration; treat it as the end.
                    if "raised StopIteration" in str(exc):
                        break
                    raise

        return wrapped_gen