Source code for bentoml._internal.frameworks.flax

from __future__ import annotations

import functools
import logging
import typing as t
from pickle import UnpicklingError
from types import ModuleType
from typing import TYPE_CHECKING

import msgpack.exceptions

import bentoml

from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...exceptions import NotFound
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions as ModelOptions
from ..runner.utils import Params
from ..types import LazyType
from ..utils import LazyLoader
from ..utils.pkg import get_pkg_version
from .common.jax import JaxArrayContainer
from .common.jax import jax
from .common.jax import jnp

if TYPE_CHECKING:
    from flax import struct
    from flax.core import FrozenDict
    from jax._src.lib.xla_bridge import XlaBackend
    from jax.lib import xla_bridge

    from ...types import ModelSignature
    from .. import external_typing as ext
    from ..models.model import ModelSignaturesType
    from ..tag import Tag
else:
    xla_bridge = LazyLoader("xla_bridge", globals(), "jax.lib.xla_bridge")

try:
    from flax import linen as nn
    from flax import serialization
except ImportError:  # pragma: no cover
    raise MissingDependencyException(
        "flax is required in order to use with 'bentoml.flax'. See https://flax.readthedocs.io/en/latest/index.html#installation for instructions."
    )

# NOTE: tensorflow is required since jax depends on XLA, which is a part of TensorFlow.
try:
    import tensorflow as tf
except ImportError:  # pragma: no cover
    raise MissingDependencyException(
        "'tensorflow' is required in order to use module 'bentoml.flax', install tensorflow with 'pip install tensorflow'. For more information, refer to https://www.tensorflow.org/install"
    )


MODULE_NAME = "bentoml.flax"
MODEL_FILENAME = "saved_model.msgpack"
API_VERSION = "v1"

logger = logging.getLogger(__name__)


__all__ = ["load_model", "save_model", "get_runnable", "get", "JaxArrayContainer"]


def get(tag_like: str | Tag) -> bentoml.Model:
    """
    Get the BentoML model with the given tag.

    Args:
        tag_like: The tag of the model to retrieve from the model store.

    Returns:
        :obj:`~bentoml.Model`: A BentoML :obj:`~bentoml.Model` with the matching tag.

    Example:

    .. code-block:: python

       import bentoml

       model = bentoml.flax.get("mnist:latest")
    """
    model = bentoml.models.get(tag_like)
    if model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {model.tag} was saved with module {model.info.module}, failed to load with {MODULE_NAME}."
        )
    return model


def load_model(
    bento_model: str | Tag | bentoml.Model,
    init: bool = True,
    device: str | XlaBackend = "cpu",
) -> tuple[nn.Module, dict[str, t.Any]]:
    """
    Load the ``flax.linen.Module`` model instance with the given tag from the local BentoML model store.

    Args:
        bento_model: Either the tag of the model to get from the store, or a BentoML
                     :obj:`~bentoml.Model` instance to load the model from.
        init: Whether to initialize the state dict of the given ``flax.linen.Module``.
              By default, the weights and values will be converted to ``jnp.ndarray``.
              If ``init`` is set to ``False``, the state dict will only be put on the
              given accelerator device instead.
        device: The device to put the state dict on. By default, it will be put on ``cpu``.
                This is only used when ``init`` is set to ``False``.

    Returns:
        A tuple of ``flax.linen.Module`` as well as its ``state_dict`` from the model store.

    Example:

    .. code-block:: python

       import bentoml
       import jax
       import jax.numpy as jnp

       net, state_dict = bentoml.flax.load_model("mnist:latest")
       predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
       results = predict_fn(jnp.ones((1, 28, 28, 1)))
    """
    # NOTE: we need to hide all GPUs from TensorFlow, otherwise it will try to allocate
    # memory on the GPU and make it unavailable for JAX.
    tf.config.experimental.set_visible_devices([], "GPU")
    if not isinstance(bento_model, bentoml.Model):
        bento_model = get(bento_model)
    if bento_model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {bento_model.tag} was saved with module {bento_model.info.module}, failed loading with {MODULE_NAME}."
        )
    if "_module" not in bento_model.custom_objects:
        raise BentoMLException(
            f"Model {bento_model.tag} was either corrupt or not saved with 'bentoml.flax.save_model()'."
        )
    module: nn.Module = bento_model.custom_objects["_module"]
    serialized = bento_model.path_of(MODEL_FILENAME)
    try:
        with open(serialized, "rb") as f:
            state_dict: dict[str, t.Any] = serialization.from_bytes(module, f.read())
    except (UnpicklingError, msgpack.exceptions.ExtraData, UnicodeDecodeError) as err:
        raise BentoMLException(
            f"Unable to convert model {bento_model.tag}'s state_dict: {err}"
        ) from None
    # ensure that all arrays are restored as jnp.ndarray
    # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
    # https://github.com/google/flax/issues/1261
    if init:
        state_dict = jax.tree_util.tree_map(jnp.array, state_dict)
    else:
        # keep the params on the given device if we don't want to initialize
        state_dict = jax.tree_util.tree_map(
            lambda s: jax.device_put(s, jax.devices(device)[0]), state_dict
        )
    return module, state_dict
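

# NOTE: a minimal usage sketch for ``load_model`` with ``init=False``, kept as a
# comment so the module stays importable. The tag "mnist:latest" and the "gpu"
# device are assumptions for illustration:
#
#   import bentoml
#   import jax.numpy as jnp
#
#   # keep the restored state dict on the first GPU device instead of
#   # materializing it as jnp.ndarray on the default device
#   net, state_dict = bentoml.flax.load_model("mnist:latest", init=False, device="gpu")
#   logits = net.apply({"params": state_dict["params"]}, jnp.ones((1, 28, 28, 1)))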


def save_model(
    name: Tag | str,
    module: nn.Module,
    state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode,
    *,
    signatures: ModelSignaturesType | None = None,
    labels: dict[str, str] | None = None,
    custom_objects: dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Save a ``flax.linen.Module`` model instance to the BentoML model store.

    Args:
        name: The name to give to the model in the BentoML store. This must be a valid
              :obj:`~bentoml.Tag` name.
        module: ``flax.linen.Module`` to be saved.
        state: The state dict (or pytree of trained parameters) of the given module,
               serialized alongside it.
        signatures: Signatures of predict methods to be used. If not provided, the
                    signatures default to ``__call__``. See
                    :obj:`~bentoml.types.ModelSignature` for more details.
        labels: A default set of management labels to be associated with the model. An
                example is ``{"training-set": "data-1"}``.
        custom_objects: Custom objects to be saved with the model. An example is
                        ``{"my-normalizer": normalizer}``. Custom objects are currently
                        serialized with cloudpickle, but this implementation is subject
                        to change.
        external_modules: user-defined additional python modules to be saved alongside
                          the model or custom objects, e.g. a tokenizer module,
                          preprocessor module, model configuration module.
        metadata: Metadata to be associated with the model. An example is ``{"bias": 4}``.
                  Metadata is intended for display in a model management UI and therefore
                  must be a default Python type, such as :obj:`str` or :obj:`int`.

    Returns:
        :obj:`~bentoml.Model`: A BentoML model instance referencing the saved model in the
        BentoML model store; its ``tag`` attribute can be used to retrieve it later.

    Example:

    .. code-block:: python

       import jax

       rng = jax.random.PRNGKey(0)
       rng, init_rng = jax.random.split(rng)
       state = create_train_state(init_rng, config)
       for epoch in range(1, config.num_epochs + 1):
           rng, input_rng = jax.random.split(rng)
           state, train_loss, train_accuracy = train_epoch(
               state, train_ds, config.batch_size, input_rng
           )
           _, test_loss, test_accuracy = apply_model(
               state, test_ds["image"], test_ds["label"]
           )
           logger.info(
               "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f",
               epoch,
               train_loss,
               train_accuracy * 100,
               test_loss,
               test_accuracy * 100,
           )

       # Save the model with BentoML
       tag = bentoml.flax.save_model("mnist", CNN(), state)
    """
    if not isinstance(module, nn.Module):
        raise BentoMLException(
            f"'bentoml.flax.save_model()' only supports saving 'flax.linen.Module' objects. Got {module.__class__.__name__} instead."
        )

    context = ModelContext(
        framework_name="flax",
        framework_versions={
            "flax": get_pkg_version("flax"),
            "jax": get_pkg_version("jax"),
            "jaxlib": get_pkg_version("jaxlib"),
        },
    )

    if signatures is None:
        signatures = {"__call__": {"batchable": False}}
        logger.info(
            'Using the default model signature for Flax (%s) for model "%s".',
            signatures,
            name,
        )
    custom_objects = {} if custom_objects is None else custom_objects
    custom_objects["_module"] = module

    with bentoml.models._create(  # type: ignore
        name,
        module=MODULE_NAME,
        api_version=API_VERSION,
        signatures=signatures,
        labels=labels,
        options=ModelOptions(),
        custom_objects=custom_objects,
        external_modules=external_modules,
        metadata=metadata,
        context=context,
    ) as bento_model:
        with open(bento_model.path_of(MODEL_FILENAME), "wb") as f:
            f.write(serialization.to_bytes(state))
        return bento_model
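

# NOTE: a hedged round-trip sketch of what ``save_model`` produces: the state is
# serialized with ``flax.serialization.to_bytes`` into "saved_model.msgpack",
# while the module itself is stored in ``custom_objects["_module"]`` so that
# ``serialization.from_bytes`` has a matching target at load time. ``CNN`` and
# ``state`` are assumed to come from the user's training code:
#
#   bento_model = bentoml.flax.save_model("mnist", CNN(), state)
#   net, restored_state = bentoml.flax.load_model(bento_model.tag)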


def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
    """Private API: use :obj:`~bentoml.Model.to_runnable` instead."""
    partial_kwargs: dict[str, t.Any] = bento_model.info.options.partial_kwargs

    class FlaxRunnable(bentoml.Runnable):
        SUPPORTED_RESOURCES = ("tpu", "nvidia.com/gpu", "cpu")
        SUPPORTS_CPU_MULTI_THREADING = True

        def __init__(self):
            super().__init__()

            self.device = xla_bridge.get_backend().platform
            self.model, self.state_dict = load_model(bento_model, device=self.device)
            self.params = self.state_dict["params"]
            self.methods_cache: t.Dict[str, t.Callable[..., t.Any]] = {}

    def gen_run_method(self: FlaxRunnable, method_name: str):
        method = getattr(self.model, method_name)
        method_partial_kwargs = partial_kwargs.get(method_name)
        if method_partial_kwargs is not None:
            method = functools.partial(method, **method_partial_kwargs)

        def mapping(
            item: jnp.ndarray | ext.NpNDArray | ext.PdDataFrame,
        ) -> jnp.ndarray:
            if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(item):
                return jnp.asarray(item)
            if LazyType["ext.PdDataFrame"]("pandas.DataFrame").isinstance(item):
                # NOTE: only to_numpy() performs a copy in memory here.
                return jnp.asarray(item.to_numpy())
            return item

        def run_method(
            self: FlaxRunnable,
            *args: jnp.ndarray | ext.NpNDArray | ext.PdDataFrame,
        ):
            params = Params[jnp.ndarray](*args).map(mapping)
            arg = params.args[0] if len(params.args) == 1 else params.args
            # NOTE: can we jit this? No, since we should not interfere with JAX
            # tracing across multiple threads:
            # https://jax.readthedocs.io/en/latest/concurrency.html?highlight=concurrency
            return self.model.apply({"params": self.params}, arg, method=method)

        return run_method

    def add_runnable_method(method_name: str, options: ModelSignature):
        def run_method(self: FlaxRunnable, *args: jnp.ndarray):
            fn = self.methods_cache.get(method_name)
            if not fn:
                fn = gen_run_method(self, method_name)
                self.methods_cache[method_name] = fn
            return fn(self, *args)

        FlaxRunnable.add_method(
            run_method,
            name=method_name,
            batchable=options.batchable,
            batch_dim=options.batch_dim,
            input_spec=options.input_spec,
            output_spec=options.output_spec,
        )

    for method_name, options in bento_model.info.signatures.items():
        add_runnable_method(method_name, options)

    return FlaxRunnable
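

# NOTE: a minimal sketch of how the runnable generated above is typically
# consumed, assuming a model saved under the hypothetical tag "mnist:latest";
# kept as a comment so the module stays importable:
#
#   import bentoml
#   import jax.numpy as jnp
#
#   runner = bentoml.flax.get("mnist:latest").to_runner()
#   runner.init_local()  # for local debugging only; a Bento service manages this
#   result = runner.run(jnp.ones((1, 28, 28, 1)))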