Source code for _bentoml_sdk.validators

from __future__ import annotations

import contextlib
import fnmatch
import functools
import io
import operator
import os
import tempfile
import typing as t
from pathlib import Path
from pathlib import PurePath

import attrs
from annotated_types import BaseMetadata
from pydantic_core import core_schema
from starlette.datastructures import UploadFile

from bentoml._internal.utils import dict_filter_none

from .typing_utils import is_file_like
from .typing_utils import is_image_type

if t.TYPE_CHECKING:
    import numpy as np
    import pandas as pd
    import tensorflow as tf
    import torch
    from pydantic import GetCoreSchemaHandler
    from pydantic import GetJsonSchemaHandler
    from typing_extensions import Literal

    TensorType = t.Union[np.ndarray[t.Any, t.Any], tf.Tensor, torch.Tensor]
    TensorFormat = Literal["numpy-array", "tf-tensor", "torch-tensor"]
    from PIL import Image as PILImage
else:
    from bentoml._internal.utils.lazy_loader import LazyLoader

    np = LazyLoader("np", globals(), "numpy")
    tf = LazyLoader("tf", globals(), "tensorflow")
    torch = LazyLoader("torch", globals(), "torch")
    pa = LazyLoader("pa", globals(), "pyarrow")
    pd = LazyLoader("pd", globals(), "pandas")
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")

T = t.TypeVar("T")

# Internal global flag, set to True while a model is being serialized to Arrow format
__in_arrow_serialization__ = False


@contextlib.contextmanager
def arrow_serialization():
    """Set the module-level flag so tensor encoders flatten arrays for Arrow."""
    global __in_arrow_serialization__
    __in_arrow_serialization__ = True
    try:
        yield
    finally:
        __in_arrow_serialization__ = False
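
# Illustrative note (not executable as-is): the flag only affects
# TensorSchema.encode below, which flattens arrays so they fit a flat Arrow
# column. Any pydantic serialization performed inside the block sees it:
#
#     with arrow_serialization():
#         payload = model.model_dump()  # `model` is a hypothetical pydantic model
#
# The finally-block restores the flag even if serialization raises.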


class PILImageEncoder:
    """Decode request payloads into PIL images and serialize them back to bytes."""

    def decode(self, obj: bytes | t.BinaryIO | UploadFile | PILImage.Image) -> t.Any:
        if is_image_type(type(obj)):
            return t.cast("PILImage.Image", obj)
        if isinstance(obj, UploadFile):
            formats = None
            if obj.headers.get("Content-Type", "").startswith("image/"):
                formats = [obj.headers["Content-Type"][6:].upper()]
            return PILImage.open(obj.file, formats=formats)
        if is_file_like(obj):
            return PILImage.open(obj)
        if isinstance(obj, bytes):
            return PILImage.open(io.BytesIO(obj))
        return obj

    def encode(self, obj: PILImage.Image) -> bytes:
        buffer = io.BytesIO()
        obj.save(buffer, format=obj.format or "PNG")
        return buffer.getvalue()

    def __get_pydantic_core_schema__(
        self, source: type[t.Any], handler: t.Callable[[t.Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(
            function=self.decode,
            schema=core_schema.any_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(self.encode),
        )

    def __get_pydantic_json_schema__(
        self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, t.Any]:
        value = handler(schema)
        if handler.mode == "validation":
            value.update({"type": "file", "format": "image"})
        else:
            value.update({"type": "string", "format": "binary"})
        return value
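
# Usage sketch (assumption, not from this module): the encoder is attached to a
# PIL image type via typing.Annotated, so pydantic routes validation and
# serialization through decode()/encode():
#
#     from typing import Annotated
#     from PIL import Image
#     from pydantic import TypeAdapter
#
#     ImageField = Annotated[Image.Image, PILImageEncoder()]
#     img = TypeAdapter(ImageField).validate_python(png_bytes)  # png_bytes: raw image bytes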
@attrs.define
class FileSchema:
    """Validate file uploads into ``Path`` objects inside the request directory."""

    format: str = "binary"
    content_type: str | None = None

    def __attrs_post_init__(self) -> None:
        if self.content_type is not None:
            self.format = self.content_type.split("/")[0]

    def __get_pydantic_json_schema__(
        self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, t.Any]:
        value = handler(schema)
        if handler.mode == "validation":
            value.update({"type": "file", "format": self.format})
            if self.content_type is not None:
                value.update({"content_type": self.content_type})
        else:
            value.update({"type": "string", "format": "binary"})
        return value

    def encode(self, obj: Path) -> bytes:
        return obj.read_bytes()

    def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any:
        from bentoml._internal.context import request_directory

        media_type: str | None = None
        if isinstance(obj, str):
            return obj
        if isinstance(obj, PurePath):
            return Path(obj)
        if isinstance(obj, UploadFile):
            body = obj.file.read()
            filename = obj.filename
            media_type = obj.content_type
        elif is_file_like(obj):
            body = obj.read()
            filename = (
                os.path.basename(fn)
                if (fn := getattr(obj, "name", None)) is not None
                else None
            )
        else:
            body = t.cast(bytes, obj)
            filename = None
        if media_type is not None and self.content_type is not None:
            if not fnmatch.fnmatch(media_type, self.content_type):
                raise ValueError(
                    f"Invalid content type {media_type}, expected {self.content_type}"
                )
        with tempfile.NamedTemporaryFile(
            suffix=filename, dir=request_directory.get(), delete=False
        ) as f:
            f.write(body)
        return Path(f.name)

    def __get_pydantic_core_schema__(
        self, source: type[t.Any], handler: t.Callable[[t.Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(
            function=self.decode,
            schema=core_schema.any_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(self.encode),
        )
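
# Usage sketch (assumption, not from this module): FileSchema turns uploads into
# Path objects pointing at a request-scoped temporary file, so decode() only
# works inside a request where `request_directory` is set.
#
#     from pathlib import Path
#     from typing import Annotated
#
#     PdfField = Annotated[Path, FileSchema(content_type="application/pdf")]
#     # A matching upload validates to a Path; a mismatched Content-Type raises
#     # ValueError via the fnmatch check against "application/pdf".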
@attrs.frozen(unsafe_hash=True)
class TensorSchema:
    """Validate and serialize numpy, TensorFlow, and PyTorch tensors."""

    format: TensorFormat
    dtype: t.Optional[str] = None
    shape: t.Optional[t.Tuple[int, ...]] = None

    @property
    def dim(self) -> int | None:
        if self.shape is None:
            return None
        return functools.reduce(operator.mul, self.shape, 1)

    def __get_pydantic_json_schema__(
        self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, t.Any]:
        value = handler(schema)
        if handler.mode == "validation":
            value.update(
                dict_filter_none(
                    {
                        "type": "tensor",
                        "format": self.format,
                        "dtype": self.dtype,
                        "shape": self.shape,
                        "dim": self.dim,
                    }
                )
            )
        else:
            dimension = 1 if self.shape is None else len(self.shape)
            child = {"type": "number"}
            for _ in range(dimension):
                child = {"type": "array", "items": child}
            value.update(child)
        return value

    def __get_pydantic_core_schema__(
        self, source_type: t.Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(
            self.validate,
            core_schema.any_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(
                self.encode, info_arg=True
            ),
        )

    def encode(self, arr: TensorType, info: core_schema.SerializationInfo) -> t.Any:
        if self.format == "numpy-array":
            assert isinstance(arr, np.ndarray)
            numpy_array = arr
        elif self.format == "tf-tensor":
            if not info.mode_is_json():
                # tf.Tensor supports pickle protocol 5 serialization
                return arr
            numpy_array = arr.numpy()
        else:
            assert isinstance(arr, torch.Tensor)
            if arr.device.type != "cpu":
                numpy_array = arr.cpu().numpy()
            else:
                numpy_array = arr.numpy()
        if __in_arrow_serialization__:
            numpy_array = numpy_array.flatten()
        if info.mode_is_json():
            return numpy_array.tolist()
        return numpy_array

    @property
    def framework_dtype(self) -> t.Any:
        dtype = self.dtype
        if dtype is None:
            return None
        if self.format == "numpy-array":
            return getattr(np, dtype)
        elif self.format == "tf-tensor":
            return getattr(tf, dtype)
        else:
            return getattr(torch, dtype)

    def validate(self, obj: t.Any) -> t.Any:
        arr: t.Any
        if self.format == "numpy-array":
            if isinstance(obj, np.ndarray):
                return obj
            arr = np.array(obj, dtype=self.framework_dtype)
            if self.shape is not None:
                arr = arr.reshape(self.shape)
            return arr
        elif self.format == "tf-tensor":
            if isinstance(obj, tf.Tensor):
                return obj
            else:
                return tf.constant(obj, dtype=self.framework_dtype, shape=self.shape)  # type: ignore
        else:
            if isinstance(obj, torch.Tensor):
                return obj
            if isinstance(obj, np.ndarray):
                return torch.from_numpy(obj)
            arr = torch.tensor(obj, dtype=self.framework_dtype)
            if self.shape is not None:
                arr = arr.reshape(self.shape)
            return arr
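
# Example for the numpy path (runs as written when numpy is installed):
#
#     schema = TensorSchema(format="numpy-array", dtype="float32", shape=(2, 2))
#     arr = schema.validate([1, 2, 3, 4])
#     assert arr.dtype == np.float32 and arr.shape == (2, 2)
#
# validate() passes existing tensors through untouched and coerces anything
# else; encode() turns tensors back into nested lists when serializing to JSON.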
@attrs.frozen(unsafe_hash=True)
class DataframeSchema:
    """Validate and serialize pandas DataFrames in 'records' or 'columns' orient."""

    orient: str = "records"
    columns: list[str] | None = None

    def __get_pydantic_json_schema__(
        self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, t.Any]:
        value = handler(schema)
        if handler.mode == "validation":
            value.update(
                dict_filter_none(
                    {
                        "type": "dataframe",
                        "orient": self.orient,
                        "columns": self.columns,
                    }
                )
            )
        else:
            if self.orient == "records":
                value.update(
                    {
                        "type": "array",
                        "items": {"type": "object"},
                    }
                )
            elif self.orient == "columns":
                value.update(
                    {
                        "type": "object",
                        "additionalProperties": {"type": "array"},
                    }
                )
            else:
                raise ValueError(
                    "Only 'records' and 'columns' are supported for orient"
                )
        return value

    def __get_pydantic_core_schema__(
        self, source_type: t.Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(
            self.validate,
            core_schema.any_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(
                self.encode, info_arg=True
            ),
        )

    def encode(self, df: pd.DataFrame, info: core_schema.SerializationInfo) -> t.Any:
        if not info.mode_is_json():
            return df
        if self.orient == "records":
            return df.to_dict(orient="records")
        elif self.orient == "columns":
            return df.to_dict(orient="list")
        else:
            raise ValueError("Only 'records' and 'columns' are supported for orient")

    def validate(self, obj: t.Any) -> pd.DataFrame:
        if isinstance(obj, pd.DataFrame):
            return obj
        return pd.DataFrame(obj, columns=self.columns)
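
# Example (runs as written when pandas is installed):
#
#     schema = DataframeSchema(orient="records", columns=["a", "b"])
#     df = schema.validate([(1, 2), (3, 4)])  # DataFrame with columns a and b
#
# For JSON output, orient="records" encodes to [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
# while orient="columns" encodes to {"a": [1, 3], "b": [2, 4]}.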
@attrs.frozen
class ContentType(BaseMetadata):
    content_type: str


@attrs.frozen
class Shape(BaseMetadata):
    dimensions: tuple[int, ...]


@attrs.frozen
class DType(BaseMetadata):
    dtype: str
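
# Usage sketch (assumption, based on the annotated_types convention): these
# frozen markers are meant to be combined with typing.Annotated on API
# signatures; the SDK maps them onto the schema validators defined above.
#
#     import numpy as np
#     from pathlib import Path
#     from typing import Annotated
#
#     document: Annotated[Path, ContentType("application/pdf")]
#     tensor: Annotated[np.ndarray, Shape((28, 28)), DType("float32")]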