from __future__ import annotations
import asyncio
import inspect
import logging
import math
import os
import pathlib
import sys
import typing as t
from functools import lru_cache
from functools import partial
from urllib.parse import urlsplit
import anyio.to_thread
import attrs
from simple_di import Provide
from simple_di import inject
from starlette.applications import Starlette
from typing_extensions import Unpack
from bentoml._internal.bento.bento import Bento
from bentoml._internal.bento.build_config import BentoEnvSchema
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.context import ServiceContext
from bentoml._internal.models import Model as StoredModel
from bentoml._internal.utils import deprecated
from bentoml._internal.utils import dict_filter_none
from bentoml._internal.utils.uri import join_paths
from bentoml.exceptions import BentoMLConfigException
from bentoml.exceptions import BentoMLException
from bentoml.legacy import Runner
from ..images import Image
from ..method import APIMethod
from ..models import BentoModel
from ..models import HuggingFaceModel
from ..models import Model
from .config import ServiceConfig as Config
# Module logger; note it uses the fixed "bentoml.serve" channel rather than
# ``__name__``.
logger = logging.getLogger("bentoml.serve")

# ``T`` is the user class wrapped by ``service()``; ``T_co`` is the covariant
# variant used as the type parameter of ``Service``.
T = t.TypeVar("T")
T_co = t.TypeVar("T_co", covariant=True)
if t.TYPE_CHECKING:
    # Everything in this branch is referenced from annotations only. The
    # module uses ``from __future__ import annotations``, so none of these
    # names are needed at runtime.
    from bentoml._internal import external_typing as ext
    from bentoml._internal.service.openapi.specification import OpenAPISpecification
    from bentoml._internal.utils.circus import Server

    from .dependency import Dependency

    # Parameter/return placeholders used to type ``with_config`` and the
    # sync/async wrapper helpers.
    P = t.ParamSpec("P")
    R = t.TypeVar("R")

    class _ServiceDecorator(t.Protocol):
        # Type of the decorator that ``service(...)`` returns when it is
        # called with keyword arguments only: it takes the user class and
        # returns the ``Service`` wrapping it.
        def __call__(self, inner: type[T]) -> Service[T]: ...
class PathMetadata(t.TypedDict):
    """Per-path metadata stored under the ``"paths"`` key of ``Service.schema()``."""

    # False for routes of the service's own API methods, True for routes
    # discovered on a mounted ASGI application.
    mounted: bool
class ServiceEnvConfig(t.TypedDict, total=False):
    """One environment variable entry accepted by ``service(envs=...)``.

    All keys are optional (``total=False``). Entries are converted to
    ``BentoEnvSchema`` by ``convert_envs``.
    """

    # Name of the environment variable.
    name: str
    # Value of the environment variable.
    value: str
    # One of "all", "build" or "runtime"; presumably the stage at which the
    # variable applies -- confirm against ``BentoEnvSchema``.
    stage: t.Literal["all", "build", "runtime"]
def with_config(
    func: t.Callable[t.Concatenate["Service[t.Any]", P], R],
) -> t.Callable[t.Concatenate["Service[t.Any]", P], R]:
    """Decorate a ``Service`` method so the service config is injected first.

    The returned wrapper calls ``self.inject_config()`` and then forwards all
    positional and keyword arguments to ``func`` unchanged, returning its
    result.

    Args:
        func: The method to wrap. Its first parameter is the service instance.

    Returns:
        A wrapper with the same call signature as ``func``. Metadata such as
        ``__name__``, ``__doc__`` and ``__wrapped__`` is preserved so that
        introspection and documentation tools still see the original method.
    """
    # Imported locally so this helper does not depend on a new module-level
    # import.
    from functools import wraps

    @wraps(func)
    def wrapper(self: Service[t.Any], *args: P.args, **kwargs: P.kwargs) -> R:
        # Configuration must be in place before the wrapped method reads it.
        self.inject_config()
        return func(self, *args, **kwargs)

    return wrapper
def convert_envs(envs: t.List[ServiceEnvConfig]) -> t.List[BentoEnvSchema]:
    """Build a ``BentoEnvSchema`` from every env mapping, preserving order.

    Used as the attrs converter of ``Service.envs``. Each entry is expanded
    into keyword arguments of ``BentoEnvSchema``.
    """
    schemas: t.List[BentoEnvSchema] = []
    for entry in envs:
        schemas.append(BentoEnvSchema(**entry))
    return schemas
class _DummyService:
    # Empty placeholder used as the default value of ``Service.inner``.
    # Deliberately has no docstring: ``Service.schema()`` reports
    # ``inner.__doc__`` as the service description.
    ...
@attrs.define
class Service(t.Generic[T_co]):
    """A BentoML service that can be served by the BentoML server.

    Wraps the user-defined class ``inner`` and records what is discovered on
    it at construction time: API methods, models, dependencies on other
    services and pre-mounted ASGI apps, together with the service-level
    configuration.
    """

    # Service name; also the sole input of ``__hash__``.
    name: str
    # Service configuration; mutated in ``__attrs_post_init__`` for tasks,
    # default workers and health-check endpoints.
    config: Config = attrs.field(factory=Config)
    # The user class being wrapped.
    inner: type[T_co] = _DummyService
    image: t.Optional[Image] = None
    description: t.Optional[str] = None
    # Prefix joined in front of API routes, mounted apps and the
    # livez/readyz endpoints.
    path_prefix: str = ""
    envs: t.List[BentoEnvSchema] = attrs.field(factory=list, converter=convert_envs)
    labels: t.Dict[str, str] = attrs.field(factory=dict)
    models: list[Model[t.Any]] = attrs.field(factory=list)
    # Optional custom start command, see ``has_custom_command``.
    cmd: t.Optional[t.List[str]] = None
    # Set by ``on_load_bento``.
    bento: t.Optional[Bento] = attrs.field(init=False, default=None)
    # API methods keyed by attribute name on ``inner``.
    apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict)
    # Dependencies keyed by attribute name on ``inner``.
    dependencies: dict[str, Dependency[t.Any]] = attrs.field(factory=dict, init=False)
    # ``(app, full mount path, name)`` triples.
    mount_apps: list[tuple[ext.ASGIApp, str, str]] = attrs.field(
        factory=list, init=False
    )
    # ``(middleware class, keyword options)`` pairs.
    middlewares: list[tuple[type[ext.AsgiMiddleware], dict[str, t.Any]]] = attrs.field(
        factory=list, init=False
    )
    # service context
    context: ServiceContext = attrs.field(init=False, factory=ServiceContext)
    working_dir: str = attrs.field(init=False, factory=os.getcwd)
    # import info
    _caller_module: str = attrs.field(init=False)
    _import_str: str | None = attrs.field(init=False, default=None)

    def __attrs_post_init__(self) -> None:
        """Scan ``inner`` and finalize the configuration.

        Runs after the attrs-generated ``__init__``. It registers the
        dependencies, models and API methods found as attributes of
        ``inner``, mounts pre-declared ASGI apps, and fills in config
        defaults for task queues, worker count and health endpoints.
        """
        from .dependency import Dependency

        has_task = False
        for field in dir(self.inner):
            value = getattr(self.inner, field)
            if isinstance(value, Dependency):
                self.dependencies[field] = value
            elif isinstance(value, StoredModel):
                logger.warning(
                    "`bentoml.models.get()` as the class attribute is not recommended because it requires the model"
                    f" to exist at import time. Use `{value._attr} = BentoModel({str(value.tag)!r})` instead."
                )
                self.models.append(BentoModel(value.tag))
            elif isinstance(value, Model):
                self.models.append(t.cast(Model[t.Any], value))
            elif isinstance(value, APIMethod):
                if value.is_task:
                    has_task = True
                self.apis[field] = t.cast("APIMethod[..., t.Any]", value)

        if has_task:
            # Task endpoints force the external queue on; concurrency only
            # gets a default of 1 if the user did not set one.
            traffic = self.config.setdefault("traffic", {})
            traffic["external_queue"] = True
            traffic.setdefault("concurrency", 1)

        pre_mount_apps = getattr(self.inner, "__bentoml_mounted_apps__", [])
        if pre_mount_apps:
            for app, path, name in pre_mount_apps:
                self.mount_asgi_app(app, path=path, name=name)
            # Consume the marker so the apps are not mounted a second time.
            delattr(self.inner, "__bentoml_mounted_apps__")

        if self.config.get("workers") is None and self.has_custom_command():
            from bentoml._internal.resource import system_resources

            resources = system_resources()
            # Half of the available CPUs, at least 1 and at most 16.
            workers = min(16, int(resources["cpu"] / 2) or 1)
            self.config["workers"] = workers

        if self.path_prefix:
            livez_endpoint = join_paths(self.path_prefix, "livez")
            readyz_endpoint = join_paths(self.path_prefix, "readyz")
            # ``setdefault`` keeps any endpoint the user configured explicitly.
            self.config.setdefault("endpoints", {}).setdefault("livez", livez_endpoint)
            self.config.setdefault("endpoints", {}).setdefault(
                "readyz", readyz_endpoint
            )

    def __hash__(self) -> int:
        """Hash by service name only."""
        return hash(self.name)

    def has_custom_command(self) -> bool:
        """Return True if ``inner`` defines ``__command__`` or ``cmd`` is set."""
        return hasattr(self.inner, "__command__") or self.cmd is not None

    @_caller_module.default  # type: ignore
    def _get_caller_module(self) -> str:
        """Return the name of the first module up the stack that is not this one.

        Falls back to this module's own ``__name__`` when this module runs
        as ``__main__`` or when no other module is found on the stack.
        """
        if __name__ == "__main__":
            return __name__
        current_frame = inspect.currentframe()
        frame = current_frame
        while frame:
            this_name = frame.f_globals["__name__"]
            if this_name != __name__:
                return this_name
            frame = frame.f_back
        return __name__

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} name={self.name!r}>"

    @lru_cache
    def find_dependent_by_path(self, path: str) -> Service[t.Any]:
        """Find a service by a dotted path of dependency attribute names.

        The first path segment is looked up among this service's
        dependencies and the remaining segments are resolved recursively on
        the dependent service. A segment that is not a dependency attribute
        is tried as a service name among ``all_services()``.

        Results are memoized with ``lru_cache``, keyed on ``(self, path)``.

        Raises:
            BentoMLException: If the segment matches neither a dependency nor
                a known service name, or the dependency has no local service
                (``on`` is None).
        """
        attr_name, _, path = path.partition(".")
        if attr_name not in self.dependencies:
            if attr_name in self.all_services():
                return self.all_services()[attr_name]
            else:
                raise BentoMLException(f"Service {attr_name} not found")
        dependent = self.dependencies[attr_name]
        if dependent.on is None:
            raise BentoMLException(f"Service {attr_name} not found")
        if path:
            return dependent.on.find_dependent_by_path(path)
        # Return the service itself, not the ``Dependency`` wrapper, so the
        # leaf case agrees with the recursive case and the annotation.
        return dependent.on

    def find_dependent_by_name(self, name: str) -> Service[t.Any]:
        """Find a service by name among this service and its dependencies.

        Raises:
            BentoMLException: If no service with that name exists.
        """
        try:
            return self.all_services()[name]
        except KeyError:
            raise BentoMLException(f"Service {name} not found") from None

    @property
    def url(self) -> str | None:
        """Get the URL of the service, or None if the service is not served.

        A ``tcp://`` scheme in the registered address is reported as
        ``http://``.
        """
        dependency_map = BentoMLContainer.remote_runner_mapping.get()
        url = dependency_map.get(self.name)
        return url.replace("tcp://", "http://") if url else None

    async def get_hosts(self) -> list[str]:
        """Return the ``host:port`` addresses serving this service.

        Locally, the address from the remote-runner mapping is repeated once
        per worker. When ``BENTOCLOUD_DEPLOYMENT_URL`` is set, the addresses
        are resolved through an HTTP request to the registered URL.

        Raises:
            BentoMLException: If the service has no registered URL, is not
                served over HTTP, or the cloud lookup returns an error.
        """
        import httpx

        from _bentoml_impl.server.allocator import ResourceAllocator

        url = BentoMLContainer.remote_runner_mapping.get().get(self.name)
        if not url:
            raise BentoMLException(f"Service {self.name} not found")
        url = url.replace("tcp://", "http://")
        workers, _ = ResourceAllocator().get_worker_env(self)
        if url.startswith("file://"):
            # UDS connections don't have a port number, use a fake one.
            return ["127.0.0.1:3000"] * workers
        if not url.startswith("http://"):
            raise BentoMLException(
                f"Unable to get hosts for service {self.name} because it is not served as HTTP"
            )
        if "BENTOCLOUD_DEPLOYMENT_URL" in os.environ:
            # BentoCloud environment, the url is to runner-lb
            headers = {"Runner-Name": self.name, "Resolve-Runner": "1"}
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers)
                if response.is_error:
                    raise BentoMLException(
                        f"Failed to get hosts for service {self.name} from cloud"
                    )
                await response.aread()
                result = response.json()
                return [f"{ip}:{result['port']}" for ip in result.get("ips", [])]
        # Serving locally, the hostname should be the IP
        return [urlsplit(url).netloc] * workers

    def all_services(self, exclude_urls: bool = False) -> dict[str, Service[t.Any]]:
        """Get a map of the service and all recursive dependencies.

        Args:
            exclude_urls: When True, skip dependencies that declare a ``url``
                or a ``deployment``.

        Returns:
            A dict from service name to service, starting with this service.

        Raises:
            BentoMLConfigException: If two different service objects share
                the same name.
        """
        services: dict[str, Service[t.Any]] = {self.name: self}
        for dependency in self.dependencies.values():
            if dependency.on is None:
                continue
            if exclude_urls and (dependency.url or dependency.deployment):
                continue
            dependents = dependency.on.all_services(exclude_urls=exclude_urls)
            # The same name may appear twice only if it is the same object.
            conflict = next(
                (
                    k
                    for k in dependents
                    if k in services and dependents[k] is not services[k]
                ),
                None,
            )
            if conflict:
                raise BentoMLConfigException(
                    f"Dependency conflict: {conflict} is already defined by {services[conflict].inner}"
                )
            services.update(dependents)
        return services

    @property
    def doc(self) -> str:
        """Return the Bento's doc when a Bento is loaded, else the default readme."""
        from bentoml._internal.bento.bento import get_default_svc_readme

        if self.bento is not None:
            return self.bento.doc
        return get_default_svc_readme(self)

    def schema(self) -> dict[str, t.Any]:
        """Describe the service as a plain dict.

        The result holds ``name``, ``type``, ``routes`` (one schema per API
        method, with ``path_prefix`` applied), ``description`` (the docstring
        of ``inner``) and ``paths`` (every known URL path mapped to
        ``PathMetadata``). Entries whose value is None are dropped by
        ``dict_filter_none``.
        """
        # Add API method routes (these are already full paths from method.route)
        all_paths: dict[str, PathMetadata] = {
            join_paths(self.path_prefix, method.route): {"mounted": False}
            for method in self.apis.values()
        }
        # Store id(app) to avoid reprocessing
        processed_mounted_apps: set[int] = set()

        def extract_routes_from_asgi_app(app: t.Any, current_prefix: str) -> None:
            # Record every string route of a Starlette-based app and recurse
            # into sub-applications mounted on those routes.
            if app is None or id(app) in processed_mounted_apps:
                return
            processed_mounted_apps.add(id(app))
            # Introspect ASGI app (FastAPI/Starlette)
            if isinstance(app, Starlette):
                for route_item in app.routes:
                    if hasattr(route_item, "path") and isinstance(route_item.path, str):
                        item_specific_path = route_item.path
                        # current_prefix is the path *to* this app.
                        # item_specific_path is relative to this app.
                        full_item_path = join_paths(current_prefix, item_specific_path)
                        all_paths[full_item_path] = {"mounted": True}
                        if hasattr(route_item, "app") and route_item.app is not None:
                            extract_routes_from_asgi_app(route_item.app, full_item_path)

        for mounted_app, mount_path_prefix, _ in self.mount_apps:
            normalized_base_prefix = join_paths(mount_path_prefix)
            extract_routes_from_asgi_app(mounted_app, normalized_base_prefix)

        routes = [method.schema() for method in self.apis.values()]
        if self.path_prefix:
            for route in routes:
                route["route"] = join_paths(self.path_prefix, route["route"])
        return dict_filter_none(
            {
                "name": self.name,
                "type": "service",
                "routes": routes,
                "description": getattr(self.inner, "__doc__", None),
                "paths": all_paths,
            }
        )

    @property
    def import_string(self) -> str:
        """Return the ``module:variable`` string that imports this service.

        The value is computed once and cached in ``_import_str``. When the
        service was created in ``__main__``, the module part is derived from
        the script path relative to ``working_dir``.

        Raises:
            BentoMLException: If the defining module cannot be determined or
                the service is not bound to a module-level variable.
        """
        if self._import_str is None:
            import_module = self._caller_module
            if import_module == "__main__":
                if hasattr(sys.modules["__main__"], "__file__"):
                    import_module = sys.modules["__main__"].__file__
                    assert isinstance(import_module, str)
                    try:
                        import_module_path = pathlib.Path(import_module).relative_to(
                            self.working_dir
                        )
                    except ValueError:
                        # The pathlib error adds nothing for the user, so
                        # suppress it like ``find_dependent_by_name`` does.
                        raise BentoMLException(
                            "Failed to get service import origin, service object defined in __main__ module is not supported"
                        ) from None
                    import_module = str(import_module_path.with_suffix("")).replace(
                        os.path.sep, "."
                    )
                else:
                    raise BentoMLException(
                        "Failed to get service import origin, service object defined interactively in console or notebook is not supported"
                    )

            if self._caller_module not in sys.modules:
                raise BentoMLException(
                    "Failed to get service import origin, service object must be defined in a module"
                )

            # Find the module-level name bound to this exact object.
            for name, value in vars(sys.modules[self._caller_module]).items():
                if value is self:
                    self._import_str = f"{import_module}:{name}"
                    break
            else:
                raise BentoMLException(
                    "Failed to get service import origin, service object must be assigned to a variable at module level"
                )
        return self._import_str

    def to_asgi(self, is_main: bool = True) -> ext.ASGIApp:
        """Inject the config and build the ASGI application for this service."""
        from _bentoml_impl.server.app import ServiceAppFactory

        self.inject_config()
        factory = ServiceAppFactory(self, is_main=is_main)
        return factory()

    def mount_asgi_app(
        self, app: ext.ASGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Register an ASGI app to be mounted at ``path`` under ``path_prefix``."""
        self.mount_apps.append((app, join_paths(self.path_prefix, path), name))  # type: ignore

    def mount_wsgi_app(
        self, app: ext.WSGIApp, path: str = "/", name: str | None = None
    ) -> None:
        """Register a WSGI app to be mounted at ``path`` under ``path_prefix``.

        The app is adapted with ``a2wsgi.WSGIMiddleware`` and then registered
        through ``mount_asgi_app``, so WSGI and ASGI apps get the same path
        handling.
        """
        from a2wsgi import WSGIMiddleware

        self.mount_asgi_app(WSGIMiddleware(app), path=path, name=name)  # type: ignore

    def add_asgi_middleware(
        self, middleware_cls: type[ext.AsgiMiddleware], **options: t.Any
    ) -> None:
        """Register an ASGI middleware class together with its keyword options."""
        self.middlewares.append((middleware_cls, options))

    def gradio_app_startup_hook(self, max_concurrency: int) -> None:
        """Queue and start every Gradio app declared on ``inner``.

        Args:
            max_concurrency: Passed to ``blocks.queue`` as
                ``default_concurrency_limit``.
        """
        gradio_apps = getattr(self.inner, "__bentoml_gradio_apps__", [])
        if gradio_apps:
            for gradio_app, path, _ in gradio_apps:
                logger.info(f"Initializing gradio app at: {path or '/'}")
                blocks = gradio_app.get_blocks()
                blocks.queue(default_concurrency_limit=max_concurrency)
                if hasattr(blocks, "startup_events"):
                    # gradio < 5.0
                    blocks.startup_events()
                else:
                    # gradio >= 5.0
                    blocks.run_startup_events()
            # Consume the marker so the hook does not start the apps twice.
            delattr(self.inner, "__bentoml_gradio_apps__")

    def __call__(self) -> T_co:
        """Instantiate ``inner`` and attach the BentoML helpers to the instance.

        The instance gains ``to_async`` and ``to_sync`` wrappers around its
        API methods and a ``bento_service`` back-reference. Any failure is
        logged and re-raised.
        """
        try:
            instance = self.inner()
            instance.to_async = _AsyncWrapper(instance, self.apis.keys())
            instance.to_sync = _SyncWrapper(instance, self.apis.keys())
            instance.bento_service = self
            return instance
        except Exception:
            logger.exception("Initializing service error")
            raise

    @property
    def openapi_spec(self) -> OpenAPISpecification:
        """Generate the OpenAPI specification of this service."""
        from .openapi import generate_spec

        return generate_spec(self)

    def inject_config(self) -> None:
        """Load the config of every service and merge this one into the container.

        Server-level keys of this service's config are merged under
        ``api_server``; every other key is merged at the top level of the
        global configuration.
        """
        from bentoml._internal.configuration import load_config
        from bentoml._internal.configuration.containers import BentoMLContainer
        from bentoml._internal.utils import deep_merge

        # XXX: ensure at least one item to make `flatten_dict` work
        override_defaults = {
            "services": {
                name: (svc.config or {"workers": 1})
                for name, svc in self.all_services().items()
            }
        }

        load_config(override_defaults=override_defaults, use_version=2)
        main_config = BentoMLContainer.config.services[self.name].get()
        api_server_keys = (
            "traffic",
            "metrics",
            "logging",
            "ssl",
            "http",
            "grpc",
            "backlog",
            "runner_probe",
            "max_runner_connections",
        )
        # Unset (None) server keys are skipped so they do not override
        # values already present in the container.
        api_server_config = {
            k: main_config[k] for k in api_server_keys if main_config.get(k) is not None
        }
        rest_config = {
            k: main_config[k] for k in main_config if k not in api_server_keys
        }
        existing = t.cast(t.Dict[str, t.Any], BentoMLContainer.config.get())
        deep_merge(existing, {"api_server": api_server_config, **rest_config})
        BentoMLContainer.config.set(existing)  # type: ignore

    @with_config
    @inject
    def serve_http(
        self,
        *,
        working_dir: str | None = None,
        port: int = Provide[BentoMLContainer.http.port],
        host: str = Provide[BentoMLContainer.http.host],
        backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
        timeout: int | None = None,
        ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
        ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
        ssl_keyfile_password: str | None = Provide[
            BentoMLContainer.ssl.keyfile_password
        ],
        ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
        ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
        ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
        ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
        bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
        development_mode: bool = False,
        reload: bool = False,
        threaded: bool = False,
    ) -> Server:
        """Start an HTTP server for this service and return the server handle.

        The service config is injected first (``with_config``). Parameters
        declared with ``Provide[...]`` default to values from the BentoML
        configuration container. All arguments are forwarded unchanged to
        ``_bentoml_impl.server.serve_http``.
        """
        from _bentoml_impl.server import serve_http
        from bentoml._internal.log import configure_logging

        configure_logging()

        return serve_http(
            self,
            working_dir=working_dir,
            host=host,
            port=port,
            backlog=backlog,
            timeout=timeout,
            ssl_certfile=ssl_certfile,
            ssl_keyfile=ssl_keyfile,
            ssl_keyfile_password=ssl_keyfile_password,
            ssl_version=ssl_version,
            ssl_cert_reqs=ssl_cert_reqs,
            ssl_ca_certs=ssl_ca_certs,
            ssl_ciphers=ssl_ciphers,
            bentoml_home=bentoml_home,
            development_mode=development_mode,
            reload=reload,
            threaded=threaded,
        )

    def on_load_bento(self, bento: Bento) -> None:
        """Attach ``bento`` to this service and every recursive dependency.

        For each service, Hugging Face models are pinned to the model id and
        revision recorded in the Bento. Models are paired with the Bento's
        model info by position.
        """
        for svc in self.all_services().values():
            service_info = next(
                info for info in bento.info.services if info.name == svc.name
            )
            for model, info in zip(svc.models, service_info.models):
                # Replace the model version with the one in the Bento
                if not isinstance(model, HuggingFaceModel):
                    continue
                model_id = info.metadata.get("model_id")  # use the case in bento info
                if not model_id:
                    model_id = info.tag.name.replace("--", "/")
                revision = info.metadata.get("revision", info.tag.version)
                model.model_id = model_id
                model.revision = revision
            svc.bento = bento

    def needs_task_db(self) -> bool:
        """Return True if any API is a task, unless running on BentoCloud.

        BentoCloud is detected by the ``BENTOCLOUD_DEPLOYMENT_URL``
        environment variable, in which case this always returns False.
        """
        if "BENTOCLOUD_DEPLOYMENT_URL" in os.environ:
            return False
        return any(method.is_task for method in self.apis.values())
@t.overload
def service(inner: type[T], /) -> Service[T]: ...


@t.overload
def service(
    *,
    name: str | None = None,
    image: Image | None = None,
    description: str | None = None,
    path_prefix: str | None = None,
    envs: list[ServiceEnvConfig] | None = None,
    labels: dict[str, str] | None = None,
    cmd: list[str] | None = None,
    service_class: type[Service[T]] = Service,
    **kwargs: Unpack[Config],
) -> _ServiceDecorator: ...


def service(
    inner: type[T] | None = None,
    /,
    *,
    name: str | None = None,
    image: Image | None = None,
    description: str | None = None,
    path_prefix: str | None = None,
    envs: list[ServiceEnvConfig] | None = None,
    labels: dict[str, str] | None = None,
    cmd: list[str] | None = None,
    service_class: type[Service[T]] = Service,
    **kwargs: Unpack[Config],
) -> t.Any:
    """Mark a class as a BentoML service.

    Can be used bare (``@service``) or with keyword arguments
    (``@service(...)``).

    Args:
        inner: The class to wrap when the decorator is used bare.
        name: The name of the service. Defaults to the class name.
        image: The image to use for the service.
        description: A description of the service.
        path_prefix: A URL path prefix to apply to all API endpoints of this service.
            For example, setting ``path_prefix="/v1"`` will make an endpoint
            ``/predict`` available at ``/v1/predict``. This also applies to mounted
            ASGI applications and health check endpoints.
        envs: Environment variables to set for the service.
        labels: Labels to attach to the service.
        cmd: A custom command to start the service.
        service_class: The ``Service`` class (or subclass) to instantiate.
        **kwargs: Additional service configurations such as ``traffic``, ``resources``,
            ``workers``, etc.

    Returns:
        The ``Service`` wrapping ``inner`` when used bare, otherwise a
        decorator that produces it.

    Raises:
        TypeError: If applied to an object that is already a ``Service``.

    Example:

        @service(traffic={"timeout": 60})
        class InferenceService:
            @api
            def predict(self, input: str) -> str:
                return input
    """
    config = kwargs

    def decorator(inner: type[T]) -> Service[T]:
        if isinstance(inner, Service):
            raise TypeError("service() decorator can only be applied once")
        return service_class(
            name=name or inner.__name__,
            config=config,
            inner=inner,
            image=image,
            description=description,
            # ``Service.path_prefix`` is a ``str`` defaulting to "", so an
            # omitted prefix is normalized the same way as envs and labels.
            path_prefix=path_prefix or "",
            envs=envs or [],
            labels=labels or {},
            cmd=cmd,
        )

    return decorator(inner) if inner is not None else decorator
@deprecated()
def runner_service(runner: Runner, **kwargs: Unpack[Config]) -> Service[t.Any]:
    """Make a service from a legacy Runner.

    The runner's runnable class becomes the service's inner class, each
    runner method becomes an API method, and a default worker count is
    derived from the runner's resource configuration.

    Args:
        runner: The legacy ``Runner`` to convert.
        **kwargs: Service configuration; overrides the derived defaults.

    Returns:
        A ``Service`` named after the runner.

    Raises:
        ValueError: If ``runner`` is not a ``Runner`` instance.
    """
    if not isinstance(runner, Runner):  # type: ignore
        raise ValueError(f"Expect an instance of Runner, but got {type(runner)}")

    class RunnerHandle(runner.runnable_class):
        # Subclass that can be constructed without arguments by binding the
        # runner's stored init parameters.
        def __init__(self) -> None:
            super().__init__(**runner.runnable_init_params)

    apis: dict[str, APIMethod[..., t.Any]] = {}
    assert runner.runnable_class.bentoml_runnable_methods__ is not None
    for method in runner.runner_methods:
        runnable_method = runner.runnable_class.bentoml_runnable_methods__[method.name]
        api = APIMethod(  # type: ignore
            func=runnable_method.func,
            batchable=runnable_method.config.batchable,
            batch_dim=runnable_method.config.batch_dim,
            max_batch_size=method.max_batch_size,
            max_latency_ms=method.max_latency_ms,
        )
        apis[method.name] = api

    config: Config = {}
    resource_config = runner.resource_config or {}
    if (
        "nvidia.com/gpu" in runner.runnable_class.SUPPORTED_RESOURCES
        and "nvidia.com/gpu" in resource_config
    ):
        # The GPU resource may be a count, a count given as a string, or a
        # list of entries; all are reduced to a count.
        gpus: list[int] | str | int = resource_config["nvidia.com/gpu"]
        if isinstance(gpus, str):
            gpus = int(gpus)
        elif isinstance(gpus, list):
            gpus = len(gpus)
        config["workers"] = int(gpus * runner.workers_per_resource)
    elif "cpus" in resource_config:
        # Cast to int like the GPU branch so a fractional
        # ``workers_per_resource`` cannot produce a float worker count.
        config["workers"] = int(
            math.ceil(resource_config["cpus"]) * runner.workers_per_resource
        )
    # Explicit keyword configuration wins over the derived defaults.
    config.update(kwargs)
    return Service(
        name=runner.name,
        config=config,
        inner=RunnerHandle,
        models=[BentoModel(m.tag) for m in runner.models],
        apis=apis,
    )
class _Wrapper:
    """Expose one generated callable per API name of a wrapped instance.

    Each name in ``apis`` is turned into a callable by ``_make_method``,
    which subclasses must implement. The callable is stored as an instance
    attribute of the same name, except ``"__call__"``, which is kept
    privately and reached by calling the wrapper object itself.
    """

    def __init__(self, wrapped: t.Any, apis: t.Iterable[str]) -> None:
        self.__call = None
        for api_name in apis:
            method = self._make_method(wrapped, api_name)
            if api_name != "__call__":
                setattr(self, api_name, method)
                continue
            # Special methods are looked up on the type, so the generated
            # callable is stored here and dispatched from ``__call__``.
            self.__call = method

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Invoke the wrapped ``__call__`` API, if the service defines one."""
        target = self.__call
        if target is not None:
            return target(*args, **kwargs)
        raise TypeError("This service is not callable.")

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        """Build the callable for API ``name`` of ``instance`` (subclass hook)."""
        raise NotImplementedError
class _AsyncWrapper(_Wrapper):
    """Wrapper whose methods are always usable from async code.

    Async APIs are returned unchanged. Sync functions and sync generators
    are run in a worker thread via ``anyio.to_thread.run_sync``.
    """

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        """Return an async-compatible callable for API ``name`` of ``instance``."""
        # ``.local`` is read from the API attribute of the instance; it is
        # presumably the in-process callable of the bound API -- confirm
        # against ``APIMethod``.
        original_func = func = getattr(instance, name).local
        # Peel off wrappers that expose the wrapped callable as ``.func``
        # (e.g. ``functools.partial``) to inspect the underlying function.
        while hasattr(original_func, "func"):
            original_func = original_func.func
        is_async_func = (
            asyncio.iscoroutinefunction(original_func)
            or (
                callable(original_func)
                and asyncio.iscoroutinefunction(original_func.__call__)  # type: ignore
            )
            or inspect.isasyncgenfunction(original_func)
        )
        if is_async_func:
            return func

        if inspect.isgeneratorfunction(original_func):
            # Sentinel returned by ``next`` once the generator is exhausted.
            # ``StopIteration`` cannot cross a thread/future boundary, and
            # matching on error text would also hide unrelated
            # ``RuntimeError``s raised by the generator itself.
            exhausted = object()

            async def wrapped_gen(
                *args: t.Any, **kwargs: t.Any
            ) -> t.AsyncGenerator[t.Any, None]:
                gen = func(*args, **kwargs)
                while True:
                    # Each step runs in a worker thread.
                    item = await anyio.to_thread.run_sync(next, gen, exhausted)
                    if item is exhausted:
                        break
                    yield item

            return wrapped_gen
        else:

            async def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                # ``run_sync`` forwards positional arguments only, so keyword
                # arguments are bound with ``partial``.
                return await anyio.to_thread.run_sync(partial(func, **kwargs), *args)

            return wrapped
class _SyncWrapper(_Wrapper):
    """Wrapper whose methods are always usable from sync code.

    Sync APIs are returned unchanged. Coroutine functions and async
    generators are driven to completion on an event loop with
    ``run_until_complete``, so they must not be called from a thread whose
    event loop is already running.
    """

    @staticmethod
    def _get_event_loop() -> asyncio.AbstractEventLoop:
        """Return the current thread's event loop, creating one if needed.

        ``asyncio.get_event_loop()`` raises ``RuntimeError`` when the current
        thread has no loop set (always the case in worker threads, and in the
        main thread on newer Python versions). In that case a new loop is
        created and installed for the thread.
        """
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def _make_method(self, instance: t.Any, name: str) -> t.Any:
        """Return a sync-compatible callable for API ``name`` of ``instance``."""
        # ``.local`` is read from the API attribute of the instance; it is
        # presumably the in-process callable of the bound API -- confirm
        # against ``APIMethod``.
        original_func = func = getattr(instance, name).local
        # Peel off wrappers that expose the wrapped callable as ``.func``
        # (e.g. ``functools.partial``) to inspect the underlying function.
        while hasattr(original_func, "func"):
            original_func = original_func.func
        is_async_func = (
            asyncio.iscoroutinefunction(original_func)
            or (
                callable(original_func)
                and asyncio.iscoroutinefunction(original_func.__call__)  # type: ignore
            )
            or inspect.isasyncgenfunction(original_func)
        )
        if not is_async_func:
            return func

        get_loop = self._get_event_loop

        if inspect.isasyncgenfunction(original_func):

            def wrapped_gen(
                *args: t.Any, **kwargs: t.Any
            ) -> t.Generator[t.Any, None, None]:
                agen = func(*args, **kwargs)
                loop = get_loop()
                while True:
                    try:
                        # Advance the async generator one item at a time.
                        yield loop.run_until_complete(agen.__anext__())
                    except StopAsyncIteration:
                        break

            return wrapped_gen
        else:

            def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
                loop = get_loop()
                return loop.run_until_complete(func(*args, **kwargs))

            return wrapped