# Source code for bentoml._internal.frameworks.pytorch

from __future__ import annotations

import logging
import typing as t
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING

import cloudpickle

import bentoml
from bentoml import Tag

from ...exceptions import NotFound
from ..models import Model
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions as ModelOptions
from ..types import LazyType
from ..utils.pkg import get_pkg_version
from .common.pytorch import PyTorchTensorContainer
from .common.pytorch import torch

# Names exported by ``from <this module> import *``.
__all__ = [
    "load_model",
    "save_model",
    "get_runnable",
    "get",
    "PyTorchTensorContainer",
]


# Module identifier recorded on every model saved by this framework; models whose
# recorded module differs are rejected when fetched/loaded through this module.
MODULE_NAME = "bentoml.pytorch"
# File name (inside the model-store directory) holding the serialized torch module.
MODEL_FILENAME = "saved_model.pt"
# API version string passed when a model entry is created.
API_VERSION = "v1"

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from ..models.model import ModelSignaturesType


def get(tag_like: str | Tag) -> Model:
    """
    Fetch a model from the local BentoML model store and make sure it was saved
    with this framework module.

    Args:
        tag_like (:code:`Union[str, Tag]`):
            Tag, or tag string, identifying the model in the model store.

    Returns:
        :obj:`~bentoml.Model`: the stored model entry.

    Raises:
        :obj:`~bentoml.exceptions.NotFound`:
            if the model's recorded module is neither ``MODULE_NAME`` nor this
            module's ``__name__``.
    """
    found = bentoml.models.get(tag_like)
    saved_with = found.info.module
    if saved_with in (MODULE_NAME, __name__):
        return found
    raise NotFound(
        f"Model {found.tag} was saved with module {saved_with}, not loading with {MODULE_NAME}."
    )
def load_model(
    bentoml_model: str | Tag | Model,
    device_id: t.Optional[str] = "cpu",
) -> torch.nn.Module:
    """
    Load a model from a BentoML Model with given name.

    Args:
        bentoml_model (:code:`Union[str, Tag, Model]`):
            Either the tag of a saved model in the BentoML local modelstore, or an
            already-fetched :obj:`~bentoml.Model`.
        device_id (:code:`str`, `optional`, default to :code:`cpu`):
            Optional devices to put the given model on; forwarded to ``torch.load``
            as ``map_location``. Refer to `device attributes
            <https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device>`_.

    Returns:
        :obj:`torch.nn.Module`: an instance of :code:`torch.nn.Module` from BentoML modelstore.

    Raises:
        :obj:`~bentoml.exceptions.NotFound`:
            if the model was saved with a different framework module.

    Examples:

    .. code-block:: python

        import bentoml
        model = bentoml.pytorch.load_model('lit_classifier:latest', device_id="cuda:0")
    """
    import inspect

    if isinstance(bentoml_model, (str, Tag)):
        # ``get`` already rejects models saved by another framework module.
        bentoml_model = get(bentoml_model)
    elif bentoml_model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {bentoml_model.tag} was saved with module {bentoml_model.info.module}, not loading with {MODULE_NAME}."
        )

    load_kwargs: t.Dict[str, t.Any] = {"map_location": device_id}
    # PyTorch >= 2.6 defaults ``weights_only=True``, which refuses to unpickle a
    # whole ``nn.Module`` object such as the one ``save_model`` writes. Opt out
    # explicitly when this torch version knows the keyword (it does not exist
    # before torch 1.13, where passing it would be forwarded to the unpickler).
    # NOTE(security): this unpickles arbitrary objects; only load from a trusted store.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    weight_file = bentoml_model.path_of(MODEL_FILENAME)
    with Path(weight_file).open("rb") as file:
        model: "torch.nn.Module" = torch.load(file, **load_kwargs)
    return model
def save_model(
    name: Tag | str,
    model: "torch.nn.Module",
    *,
    signatures: ModelSignaturesType | None = None,
    labels: t.Dict[str, str] | None = None,
    custom_objects: t.Dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: t.Dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Save a model instance to BentoML modelstore.

    Args:
        name (:code:`Union[str, Tag]`):
            Name for given model instance. This should pass Python identifier check.
        model (:code:`torch.nn.Module`):
            Instance of model to be saved. The whole module object is serialized
            with ``torch.save`` (using ``cloudpickle``), not only its ``state_dict``.
        signatures (:code:`ModelSignaturesType`, `optional`, default to :code:`None`):
            A dictionary of method names and their corresponding signatures. When
            omitted, ``{"__call__": {"batchable": False}}`` is used.
        labels (:code:`Dict[str, str]`, `optional`, default to :code:`None`):
            user-defined labels for managing models, e.g. team=nlp, stage=dev
        custom_objects (:code:`Dict[str, Any]`, `optional`, default to :code:`None`):
            user-defined additional python objects to be saved alongside the model,
            e.g. a tokenizer instance, preprocessor function, model configuration json
        external_modules (:code:`List[ModuleType]`, `optional`, default to :code:`None`):
            user-defined additional python modules to be saved alongside the model or
            custom objects, e.g. a tokenizer module, preprocessor module, model
            configuration module
        metadata (:code:`Dict[str, Any]`, `optional`, default to :code:`None`):
            Custom metadata for given model.

    Returns:
        :obj:`~bentoml.Model`: the newly created model entry. Its ``tag`` has the
        format ``name:version``, where ``name`` is the user-defined model's name and
        ``version`` is generated by BentoML.

    Raises:
        TypeError: if ``model`` is not a ``torch.nn.Module``.

    Examples:

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import bentoml

        class NGramLanguageModeler(nn.Module):

            def __init__(self, vocab_size, embedding_dim, context_size):
                super(NGramLanguageModeler, self).__init__()
                self.embeddings = nn.Embedding(vocab_size, embedding_dim)
                self.linear1 = nn.Linear(context_size * embedding_dim, 128)
                self.linear2 = nn.Linear(128, vocab_size)

            def forward(self, inputs):
                embeds = self.embeddings(inputs).view((1, -1))
                out = F.relu(self.linear1(embeds))
                out = self.linear2(out)
                log_probs = F.log_softmax(out, dim=1)
                return log_probs

        bento_model = bentoml.pytorch.save_model(
            "ngrams", NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
        )
        # bento_model.tag -> ngrams:<version generated by BentoML>

    Integration with Torch Hub and BentoML:

    .. code-block:: python

        import torch
        import bentoml

        resnet50 = torch.hub.load("pytorch/vision", "resnet50", pretrained=True)
        ...
        # trained a custom resnet50

        bento_model = bentoml.pytorch.save_model("resnet50", resnet50)
    """
    if not LazyType("torch.nn.Module").isinstance(model):
        raise TypeError(f"Given model ({model}) is not a torch.nn.Module.")

    context: ModelContext = ModelContext(
        framework_name="torch",
        framework_versions={"torch": get_pkg_version("torch")},
    )

    if signatures is None:
        signatures = {"__call__": {"batchable": False}}
        logger.info(
            'Using the default model signature for PyTorch (%s) for model "%s".',
            signatures,
            name,
        )

    with bentoml.models._create(  # type: ignore
        name,
        module=MODULE_NAME,
        api_version=API_VERSION,
        labels=labels,
        signatures=signatures,
        custom_objects=custom_objects,
        external_modules=external_modules,
        options=ModelOptions(),
        context=context,
        metadata=metadata,
    ) as bento_model:
        weight_file = bento_model.path_of(MODEL_FILENAME)
        # NOTE(review): cloudpickle is presumably used so that model classes defined
        # in __main__ / notebooks can be pickled by value — confirm.
        with open(weight_file, "wb") as file:
            torch.save(model, file, pickle_module=cloudpickle)  # type: ignore

        return bento_model
def get_runnable(bento_model: Model):
    """
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.

    Build a runnable class for ``bento_model``. ``PytorchModelRunnable`` is passed
    through ``partial_class`` together with ``bento_model`` and ``load_model`` (as
    ``loader``). Then one method is registered for every entry in the model's
    signatures. Each method receives that method's partial kwargs from the model
    options (``None`` when none were stored) and the signature's batching and
    spec settings.

    Returns:
        The resulting runnable class.
    """
    from .common.pytorch import (
        PytorchModelRunnable,
        make_pytorch_runnable_method,
        partial_class,
    )

    kwargs_per_method: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs  # type: ignore

    runnable_cls: type[PytorchModelRunnable] = partial_class(
        PytorchModelRunnable,
        bento_model=bento_model,
        loader=load_model,
    )

    for sig_name, sig in bento_model.info.signatures.items():
        runnable_method = make_pytorch_runnable_method(
            sig_name, kwargs_per_method.get(sig_name)
        )
        runnable_cls.add_method(
            runnable_method,
            name=sig_name,
            batchable=sig.batchable,
            batch_dim=sig.batch_dim,
            input_spec=sig.input_spec,
            output_spec=sig.output_spec,
        )

    return runnable_cls