# Source code for _bentoml_impl.client.http

from __future__ import annotations

import inspect
import io
import json
import logging
import mimetypes
import os
import pathlib
import tempfile
import time
import typing as t
from abc import abstractmethod
from functools import cached_property
from http import HTTPStatus
from urllib.parse import urljoin
from urllib.parse import urlparse

import attr
import httpx
from pydantic import RootModel

from _bentoml_sdk import IODescriptor
from _bentoml_sdk.typing_utils import is_image_type
from bentoml import __version__
from bentoml._internal.utils.uri import uri_to_path
from bentoml.exceptions import BentoMLException

from ..serde import Payload
from .base import AbstractClient
from .base import ClientEndpoint

# Names below are only needed by type checkers. Because this module uses
# ``from __future__ import annotations``, annotations are never evaluated at
# runtime, so ``T`` and ``A`` may be referenced in signatures even though they
# do not exist when the module actually runs.
if t.TYPE_CHECKING:
    from httpx._types import RequestFiles

    from _bentoml_sdk import Service

    from ..serde import Serde

    # Self-type for the context-manager methods of HTTPClient subclasses.
    T = t.TypeVar("T", bound="HTTPClient[t.Any]")
    # Element type carried through ``to_async_iterable``.
    A = t.TypeVar("A")

# Constrained type variable: the httpx client flavour an HTTPClient wraps.
C = t.TypeVar("C", httpx.Client, httpx.AsyncClient)
# Same constraint, used by ``HTTPClient._make_client`` so that the returned
# client has the type of the class that was passed in.
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
# Module logger, registered under the "bentoml.io" name.
logger = logging.getLogger("bentoml.io")
# NOTE(review): not referenced anywhere in the visible code; presumably a retry
# budget for requests — confirm before relying on it.
MAX_RETRIES = 3


def is_http_url(url: str) -> bool:
    """Tell whether *url* uses the ``http`` or ``https`` scheme."""
    scheme = urlparse(url).scheme
    return scheme == "http" or scheme == "https"


def to_async_iterable(iterable: t.Iterable[A]) -> t.AsyncIterable[A]:
    """Expose a synchronous iterable through the async-iteration protocol.

    The returned async generator walks *iterable* lazily and yields its items
    in order; nothing is consumed until the caller starts iterating.
    """

    async def _replay(source: t.Iterable[A]) -> t.AsyncIterator[A]:
        for element in source:
            yield element

    return _replay(iterable)


@attr.define
class HTTPClient(AbstractClient, t.Generic[C]):
    """Shared implementation of the BentoML HTTP client.

    Subclasses set ``client_cls`` to choose the httpx client flavour and
    implement ``_call`` / ``_get_stream`` for that flavour. This base class
    handles endpoint discovery, request building and response decoding.
    """

    client_cls: t.ClassVar[type[httpx.Client] | type[httpx.AsyncClient]]

    url: str
    endpoints: dict[str, ClientEndpoint] = attr.field(factory=dict)
    media_type: str = "application/json"
    timeout: float = 30
    default_headers: dict[str, str] = attr.field(factory=dict)

    # Files opened by ``_build_multipart`` for upload. The subclasses' ``_call``
    # implementations close and clear this list once the request has been sent.
    _opened_files: list[io.BufferedReader] = attr.field(init=False, factory=list)
    # Directory the subclasses write downloaded file responses into.
    _temp_dir: tempfile.TemporaryDirectory[str] = attr.field(init=False)

    @staticmethod
    def _make_client(
        client_cls: type[AnyClient],
        url: str,
        headers: t.Mapping[str, str],
        timeout: float,
    ) -> AnyClient:
        """Instantiate an httpx client of type *client_cls* for *url*.

        Besides plain ``http(s)://`` URLs, two extra schemes are understood:

        * ``file://`` -- the path is used as a Unix domain socket and requests
          are sent to a placeholder base URL over that socket.
        * ``tcp://`` -- rewritten to ``http://`` with the same host and port.
        """
        parsed = urlparse(url)
        transport = None
        if parsed.scheme == "file":
            uds = uri_to_path(url)
            if client_cls is httpx.Client:
                transport = httpx.HTTPTransport(uds=uds)
            else:
                transport = httpx.AsyncHTTPTransport(uds=uds)
            # The host is irrelevant when talking over a Unix socket, but httpx
            # still needs a syntactically valid base URL.
            url = "http://127.0.0.1:3000"
        elif parsed.scheme == "tcp":
            url = f"http://{parsed.netloc}"
        return client_cls(
            base_url=url,
            transport=transport,  # type: ignore
            headers=headers,
            timeout=timeout,
            follow_redirects=True,
        )

    @staticmethod
    def _is_file_schema(schema: t.Mapping[str, t.Any]) -> bool:
        """Return True if a schema property describes a file or a list of files."""
        schema_type = schema.get("type")
        if schema_type == "file":
            return True
        return (
            schema_type == "array" and schema.get("items", {}).get("type") == "file"
        )

    @_temp_dir.default  # type: ignore
    def default_temp_dir(self) -> tempfile.TemporaryDirectory[str]:
        """Create the temporary directory backing ``_temp_dir``."""
        return tempfile.TemporaryDirectory(prefix="bentoml-client-")

    def __init__(
        self,
        url: str,
        *,
        media_type: str = "application/json",
        service: Service[t.Any] | None = None,
        server_ready_timeout: float | None = None,
        token: str | None = None,
        timeout: float = 30,
    ) -> None:
        """Create a client instance from a URL.

        Args:
            url: The URL of the BentoML service.
            media_type: The media type to use for serialization. Defaults to
                "application/json".
            service: When given, endpoints are built from ``service.apis`` and
                the remote ``/schema.json`` is not fetched.
            server_ready_timeout: Seconds to wait for ``/readyz`` to report
                ready. ``None`` waits for up to ``timeout`` seconds; zero or a
                negative value skips the wait entirely.
            token: Bearer token sent in the ``Authorization`` header. Falls
                back to the ``BENTO_CLOUD_API_KEY`` environment variable.
            timeout: Per-request timeout in seconds for the underlying client.

        Raises:
            BentoMLException: If the server does not become ready in time or
                the schema cannot be fetched.

        .. note::

            The client created with this method can only return primitive types without a model.
        """
        routes: dict[str, ClientEndpoint] = {}
        default_headers = {"User-Agent": f"BentoML HTTP Client/{__version__}"}
        if token is None:
            token = os.getenv("BENTO_CLOUD_API_KEY")
        if token:
            default_headers["Authorization"] = f"Bearer {token}"

        if service is not None:
            for name, method in service.apis.items():
                routes[name] = ClientEndpoint(
                    name=name,
                    route=method.route,
                    input=method.input_spec.model_json_schema(),
                    output=method.output_spec.model_json_schema(),
                    doc=method.doc,
                    input_spec=method.input_spec,
                    output_spec=method.output_spec,
                    stream_output=method.is_stream,
                )

            from bentoml._internal.context import server_context

            default_headers.update(
                {
                    "Bento-Name": server_context.bento_name,
                    "Bento-Version": server_context.bento_version,
                    "Runner-Name": service.name,
                    "Yatai-Bento-Deployment-Name": server_context.yatai_bento_deployment_name,
                    "Yatai-Bento-Deployment-Namespace": server_context.yatai_bento_deployment_namespace,
                }
            )
        # The class defines its own __init__, so the attrs-generated
        # initializer has to be invoked explicitly.
        self.__attrs_init__(  # type: ignore
            url=url,
            endpoints=routes,
            media_type=media_type,
            default_headers=default_headers,
            timeout=timeout,
        )
        if server_ready_timeout is None or server_ready_timeout > 0:
            self.wait_until_server_ready(server_ready_timeout)
        if service is None:
            # Only used to give the error message below a full URL.
            schema_url = urljoin(url, "/schema.json")

            with self._make_client(
                httpx.Client, url, default_headers, timeout
            ) as client:
                resp = client.get("/schema.json")

                if resp.is_error:
                    raise BentoMLException(f"Failed to fetch schema from {schema_url}")
                for route in resp.json()["routes"]:
                    self.endpoints[route["name"]] = ClientEndpoint(
                        name=route["name"],
                        route=route["route"],
                        input=route["input"],
                        output=route["output"],
                        doc=route.get("doc"),
                        stream_output=route["output"].get("is_stream", False),
                    )
        super().__init__()

    @cached_property
    def client(self) -> C:
        """The underlying httpx client, created on first access."""
        return self._make_client(
            self.client_cls, self.url, self.default_headers, self.timeout
        )

    @cached_property
    def serde(self) -> Serde:
        """The serializer/deserializer registered for ``media_type``."""
        from ..serde import ALL_SERDE

        return ALL_SERDE[self.media_type]()

    def _build_request(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        headers: t.Mapping[str, str],
    ) -> httpx.Request:
        """Build the POST request that invokes *endpoint*.

        When the endpoint has an ``input_spec`` the arguments are validated
        through it. Otherwise they are matched against the endpoint's JSON
        schema: positional arguments are assigned to properties in declaration
        order, then unknown and missing names are rejected.

        Inputs containing files are sent as multipart form data when the
        client's media type is JSON.

        Raises:
            TypeError: On duplicate, surplus, unknown or missing arguments.
        """
        headers = httpx.Headers({"Content-Type": self.media_type, **headers})
        # An async client needs an async iterable as streaming content.
        is_async = self.client_cls is httpx.AsyncClient

        if endpoint.input_spec is not None:
            model = endpoint.input_spec.from_inputs(*args, **kwargs)
            if model.multipart_fields and self.media_type == "application/json":
                return self._build_multipart(endpoint, model, headers)
            payload = self.serde.serialize_model(model)
            headers.update(payload.headers)
            return self.client.build_request(
                "POST",
                endpoint.route,
                headers=headers,
                content=to_async_iterable(payload.data) if is_async else payload.data,
            )

        properties = endpoint.input["properties"]
        if len(args) > len(properties):
            # zip() below would otherwise drop the surplus values silently.
            raise TypeError(
                f"Endpoint {endpoint.name} takes at most {len(properties)} "
                f"positional arguments but {len(args)} were given"
            )
        # Work on a copy so the caller's mapping is left untouched.
        kwargs = dict(kwargs)
        for name, value in zip(properties, args):
            if name in kwargs:
                raise TypeError(f"Duplicate argument {name}")
            kwargs[name] = value

        non_exist_args = set(kwargs) - set(properties)
        if non_exist_args:
            raise TypeError(
                f"Arguments not found in endpoint {endpoint.name}: {non_exist_args}"
            )
        missing_args = set(endpoint.input.get("required", [])) - set(kwargs)
        if missing_args:
            raise TypeError(
                f"Missing required arguments in endpoint {endpoint.name}: {missing_args}"
            )
        has_file = any(self._is_file_schema(schema) for schema in properties.values())
        if has_file and self.media_type == "application/json":
            return self._build_multipart(endpoint, kwargs, headers)
        payload = self.serde.serialize(kwargs, endpoint.input)
        headers.update(payload.headers)
        return self.client.build_request(
            "POST",
            endpoint.route,
            content=to_async_iterable(payload.data) if is_async else payload.data,
            headers=headers,
        )

    def wait_until_server_ready(self, timeout: float | None = None) -> None:
        """Block until ``/readyz`` answers 200 or *timeout* seconds elapse.

        Args:
            timeout: Maximum time to wait, in seconds. Defaults to the
                client's request timeout.

        Raises:
            BentoMLException: If the server is still not ready at the deadline.
        """
        if timeout is None:
            timeout = self.timeout
        # Pause between attempts: a refused connection or a non-200 reply
        # returns immediately and would otherwise turn this into a busy loop.
        poll_interval = 0.2
        with self._make_client(
            httpx.Client, self.url, self.default_headers, timeout
        ) as client:
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                try:
                    resp = client.get("/readyz")
                    if resp.status_code == 200:
                        return
                except (httpx.TimeoutException, httpx.ConnectError):
                    pass
                pause = min(poll_interval, deadline - time.monotonic())
                if pause > 0:
                    time.sleep(pause)
        raise BentoMLException(f"Server is not ready after {timeout} seconds")

    def _build_multipart(
        self,
        endpoint: ClientEndpoint,
        model: IODescriptor | dict[str, t.Any],
        headers: httpx.Headers,
    ) -> httpx.Request:
        """Build a multipart/form-data request for an input that contains files.

        Non-file fields are JSON-encoded into form fields. Each value of a file
        field is handled by kind:

        * a string that is not an http(s) URL is treated as a local path;
        * an image object is uploaded from its underlying file pointer;
        * a path is opened in binary mode and tracked in ``_opened_files``;
        * an http(s) URL string is sent as a plain form value;
        * anything else must be a readable binary file object.

        Raises:
            TypeError: If a file field holds a value of an unsupported type.
        """

        def is_file_field(k: str) -> bool:
            if isinstance(model, IODescriptor):
                return k in model.multipart_fields
            return self._is_file_schema(endpoint.input["properties"].get(k, {}))

        if isinstance(model, dict):
            fields = model
        else:
            fields = {k: getattr(model, k) for k in model.model_fields}
        data: dict[str, t.Any] = {}
        files: RequestFiles = []

        for name, value in fields.items():
            if not is_file_field(name):
                data[name] = json.dumps(value)
                continue
            if not isinstance(value, (list, tuple)):
                value = [value]

            for v in value:
                if isinstance(v, str) and not is_http_url(v):
                    v = pathlib.Path(v)
                if is_image_type(type(v)):
                    files.append(
                        (
                            name,
                            (
                                None,
                                getattr(v, "_fp", v.fp),
                                f"image/{v.format.lower()}",
                            ),
                        )
                    )
                elif isinstance(v, pathlib.PurePath):
                    file = open(v, "rb")
                    # Register the handle first so it is closed even if the
                    # lines below raise.
                    self._opened_files.append(file)
                    files.append((name, (v.name, file, mimetypes.guess_type(v)[0])))
                elif isinstance(v, str):
                    data.setdefault(name, []).append(v)
                else:
                    # Duck-type the stream: objects such as io.BytesIO or the
                    # result of open() are not instances of typing.BinaryIO,
                    # so an isinstance() check against it rejects them.
                    if not hasattr(v, "read"):
                        raise TypeError(
                            f"Unsupported value for file field {name!r}: expected "
                            "a path, URL, image or binary file object, got "
                            f"{type(v).__name__}"
                        )
                    filename = (
                        pathlib.Path(fn).name
                        if (fn := getattr(v, "name", None))
                        else None
                    )
                    content_type = (
                        mimetypes.guess_type(filename)[0] if filename else None
                    )
                    files.append((name, (filename, v, content_type)))
        # Let httpx generate the multipart Content-Type (with its boundary).
        headers.pop("content-type", None)
        return self.client.build_request(
            "POST", endpoint.route, data=data, files=files, headers=headers
        )

    def _deserialize_output(self, payload: Payload, endpoint: ClientEndpoint) -> t.Any:
        """Decode a response payload according to the endpoint's output type.

        With an ``output_spec`` the payload is parsed into that model (root
        models are unwrapped to their root value). Without one, ``string`` and
        ``bytes`` outputs are read from the first chunk of the payload and
        everything else goes through the schema-driven deserializer.
        """
        data = iter(payload.data)
        if endpoint.output_spec is not None:
            model = self.serde.deserialize_model(payload, endpoint.output_spec)
            if isinstance(model, RootModel):
                return model.root  # type: ignore
            return model
        elif (ot := endpoint.output.get("type")) == "string":
            return bytes(next(data)).decode("utf-8")
        elif ot == "bytes":
            return bytes(next(data))
        else:
            return self.serde.deserialize(payload, endpoint.output)

    def call(self, __name: str, /, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Invoke the endpoint called *__name* with the given arguments.

        Streaming endpoints are dispatched to ``_get_stream``; all others to
        ``_call``.

        Raises:
            BentoMLException: If no endpoint with that name is known.
        """
        try:
            endpoint = self.endpoints[__name]
        except KeyError:
            raise BentoMLException(f"Endpoint {__name} not found") from None
        if endpoint.stream_output:
            return self._get_stream(endpoint, args, kwargs)
        else:
            return self._call(endpoint, args, kwargs)

    @abstractmethod
    def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any:
        """Send the request for *endpoint* and return the decoded result."""
        ...

    @abstractmethod
    def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.Any:
        """Call a streaming *endpoint* and return an iterator over its output."""
        ...


class SyncHTTPClient(HTTPClient[httpx.Client]):
    """A synchronous client for BentoML service.

    .. note:: Inner usage ONLY
    """

    client_cls = httpx.Client

    def __enter__(self: T) -> T:
        return self

    def __exit__(self, exc_type: t.Any, exc: t.Any, tb: t.Any) -> None:
        return self.close()

    def is_ready(self, timeout: int | None = None) -> bool:
        """Return True if ``/readyz`` answers 200, False if the request times out.

        A falsy *timeout* (``None`` or ``0``) uses the client's default timeout.
        """
        try:
            resp = self.client.get(
                "/readyz", timeout=timeout or httpx.USE_CLIENT_DEFAULT
            )
            return resp.status_code == 200
        except httpx.TimeoutException:
            logger.warning("Timed out waiting for runner to be ready")
            return False

    def close(self) -> None:
        """Close the underlying httpx client if it was ever created."""
        # ``client`` is a cached_property: it only shows up in the instance
        # dict after first use, so this avoids creating a client just to
        # close it.
        if "client" in vars(self):
            self.client.close()

    def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.Generator[t.Any, None, None]:
        """Call a streaming endpoint and yield each decoded chunk."""
        resp = self._call(endpoint, args, kwargs)
        for data in resp:
            yield data

    def request(self, method: str, url: str, **kwargs: t.Any) -> httpx.Response:
        """Send an arbitrary request through the underlying httpx client."""
        return self.client.request(method, url, **kwargs)

    def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any:
        """Send the request for *endpoint* and decode the response.

        Returns a generator for streaming endpoints, a path to a temporary
        file for file outputs (when the media type is JSON), and the
        deserialized value otherwise.

        Raises:
            BentoMLException: If the server answers with an error status.
        """
        try:
            req = self._build_request(endpoint, args, kwargs, headers or {})
            resp = self.client.send(req, stream=endpoint.stream_output)
            if resp.is_error:
                # A streamed body must be read before ``resp.text`` is usable.
                resp.read()
                raise BentoMLException(
                    f"Error making request: {resp.status_code}: {resp.text}",
                    error_code=HTTPStatus(resp.status_code),
                )
            if endpoint.stream_output:
                return self._parse_stream_response(endpoint, resp)
            elif (
                endpoint.output.get("type") == "file"
                and self.media_type == "application/json"
            ):
                return self._parse_file_response(endpoint, resp)
            else:
                return self._parse_response(endpoint, resp)
        finally:
            # Release any upload files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    def _parse_response(self, endpoint: ClientEndpoint, resp: httpx.Response) -> t.Any:
        """Read the whole response body and deserialize it."""
        payload = Payload((resp.read(),), resp.headers)
        return self._deserialize_output(payload, endpoint)

    def _parse_stream_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.Generator[t.Any, None, None]:
        """Yield one deserialized value per received chunk, then close *resp*."""
        try:
            for data in resp.iter_bytes():
                yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
        finally:
            resp.close()

    def _parse_file_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> pathlib.Path:
        """Save the response body to a file in the client's temp directory.

        The file name from the ``Content-Disposition`` header, when present,
        is used as the suffix of the temporary file's name.
        """
        from multipart.multipart import parse_options_header

        content_disposition = resp.headers.get("content-disposition")
        filename: str | None = None
        if content_disposition:
            _, options = parse_options_header(content_disposition)
            if b"filename" in options:
                filename = str(
                    options[b"filename"],
                    resp.charset_encoding or "utf-8",
                    errors="ignore",
                )
        with tempfile.NamedTemporaryFile(
            "wb", suffix=filename, dir=self._temp_dir.name, delete=False
        ) as f:
            f.write(resp.read())
        return pathlib.Path(f.name)


class AsyncHTTPClient(HTTPClient[httpx.AsyncClient]):
    """An asynchronous client for BentoML service.

    .. note:: Inner usage ONLY
    """

    client_cls = httpx.AsyncClient

    async def is_ready(self, timeout: int | None = None) -> bool:
        """Return True if ``/readyz`` answers 200, False if the request times out.

        A falsy *timeout* (``None`` or ``0``) uses the client's default timeout.
        """
        try:
            resp = await self.client.get(
                "/readyz", timeout=timeout or httpx.USE_CLIENT_DEFAULT
            )
            return resp.status_code == 200
        except httpx.TimeoutException:
            logger.warning("Timed out waiting for runner to be ready")
            return False

    async def _get_stream(
        self, endpoint: ClientEndpoint, args: t.Any, kwargs: t.Any
    ) -> t.AsyncGenerator[t.Any, None]:
        """Call a streaming endpoint and yield each decoded chunk."""
        resp = await self._call(endpoint, args, kwargs)
        assert inspect.isasyncgen(resp)
        async for data in resp:
            yield data

    async def __aenter__(self: T) -> T:
        return self

    async def __aexit__(self, *args: t.Any) -> None:
        return await self.close()

    async def request(self, method: str, url: str, **kwargs: t.Any) -> httpx.Response:
        """Send an arbitrary request through the underlying httpx client."""
        return await self.client.request(method, url, **kwargs)

    async def _call(
        self,
        endpoint: ClientEndpoint,
        args: t.Sequence[t.Any],
        kwargs: dict[str, t.Any],
        *,
        headers: t.Mapping[str, str] | None = None,
    ) -> t.Any:
        """Send the request for *endpoint* and decode the response.

        Returns an async generator for streaming endpoints, a path to a
        temporary file for file outputs (when the media type is JSON), and the
        deserialized value otherwise.

        Raises:
            BentoMLException: If the server answers with an error status.
        """
        try:
            req = self._build_request(endpoint, args, kwargs, headers or {})
            resp = await self.client.send(req, stream=endpoint.stream_output)
            if resp.is_error:
                # A streamed body must be read before ``resp.text`` is usable.
                await resp.aread()
                raise BentoMLException(
                    f"Error making request: {resp.status_code}: {resp.text}",
                    error_code=HTTPStatus(resp.status_code),
                )
            if endpoint.stream_output:
                return self._parse_stream_response(endpoint, resp)
            elif (
                endpoint.output.get("type") == "file"
                and self.media_type == "application/json"
            ):
                return await self._parse_file_response(endpoint, resp)
            else:
                return await self._parse_response(endpoint, resp)
        finally:
            # Release any upload files opened while building the request.
            for f in self._opened_files:
                f.close()
            self._opened_files.clear()

    async def _parse_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.Any:
        """Read the whole response body and deserialize it."""
        data = await resp.aread()
        return self._deserialize_output(Payload((data,), resp.headers), endpoint)

    async def _parse_stream_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> t.AsyncGenerator[t.Any, None]:
        """Yield one deserialized value per received chunk, then close *resp*."""
        try:
            async for data in resp.aiter_bytes():
                yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
        finally:
            await resp.aclose()

    async def _parse_file_response(
        self, endpoint: ClientEndpoint, resp: httpx.Response
    ) -> pathlib.Path:
        """Save the response body to a file in the client's temp directory.

        The file name from the ``Content-Disposition`` header, when present,
        is used as the suffix of the temporary file's name.
        """
        from multipart.multipart import parse_options_header

        content_disposition = resp.headers.get("content-disposition")
        filename: str | None = None
        if content_disposition:
            _, options = parse_options_header(content_disposition)
            if b"filename" in options:
                filename = str(
                    options[b"filename"],
                    resp.charset_encoding or "utf-8",
                    errors="ignore",
                )
        with tempfile.NamedTemporaryFile(
            "wb", suffix=filename, dir=self._temp_dir.name, delete=False
        ) as f:
            f.write(await resp.aread())
        return pathlib.Path(f.name)

    async def close(self) -> None:
        """Close the underlying httpx client if it was ever created."""
        # ``client`` is a cached_property: it only shows up in the instance
        # dict after first use, so this avoids creating a client just to
        # close it.
        if "client" in vars(self):
            await self.client.aclose()