Flax

About this page

This is an API reference for Flax in BentoML. Please refer to /frameworks/flax for more information about how to use Flax in BentoML.

Note

You can find more examples for Flax in our bentoml/examples (https://github.com/bentoml/BentoML/tree/main/examples) directory.

bentoml.flax.save_model(name: Tag | str, module: nn.Module, state: dict[str, t.Any] | FrozenDict[str, t.Any] | struct.PyTreeNode, *, signatures: ModelSignaturesType | None = None, labels: dict[str, str] | None = None, custom_objects: dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: dict[str, t.Any] | None = None) → bentoml.Model

Save a flax.linen.Module model instance to the BentoML model store.

Parameters:
  • name – The name to give to the model in the BentoML store. This must be a valid Tag name.

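  • module – flax.linen.Module to be saved.

  • state – The state of module to be saved, such as its variables dict (for example, the output of module.init) or a struct.PyTreeNode such as a flax.training.train_state.TrainState.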

  • signatures – Signatures of predict methods to be used. If not provided, the signatures default to predict. See ModelSignature for more details.

  • labels – A default set of management labels to be associated with the model. An example is {"training-set": "data-1"}.

  • custom_objects – Custom objects to be saved with the model. An example is {"my-normalizer": normalizer}. Custom objects are currently serialized with cloudpickle, but this implementation is subject to change.

  • external_modules – User-defined additional Python modules to be saved alongside the model or custom objects, such as a tokenizer, preprocessor, or model configuration module.

  • metadata – Metadata to be associated with the model. An example is {"bias": 4}. Metadata is intended for display in a model management UI and therefore must be a built-in Python type, such as str or int.

Returns:

The saved model. Its tag can be used to access the model from the BentoML model store.

Return type:

Model

Example:

import logging

import bentoml
import jax

# create_train_state, train_epoch, apply_model, CNN, config, train_ds and test_ds
# are assumed to be defined as in Flax's MNIST example.
logger = logging.getLogger(__name__)
rng = jax.random.PRNGKey(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)

for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    _, test_loss, test_accuracy = apply_model(
        state, test_ds["image"], test_ds["label"]
    )

    logger.info(
        "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f",
        epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100
    )

# Save the trained module and its state with BentoML
bento_model = bentoml.flax.save_model("mnist", CNN(), state)

bentoml.flax.load_model(bento_model: str | Tag | bentoml.Model, init: bool = True, device: str | XlaBackend = 'cpu') → tuple[nn.Module, dict[str, t.Any]]

Load the flax.linen.Module model instance with the given tag from the local BentoML model store.

Parameters:
  • bento_model – Either the tag of the model to get from the store, or a bentoml.Model instance to load the model from.

  • init – Whether to initialize the state dict of the given flax.linen.Module. By default, the weights and values are converted to jnp.ndarray. If init is set to False, the state dict is instead only placed on the given accelerator device.

  • device – The device to place the state dict on. Defaults to "cpu". This is only used when init is set to False.

Returns:

A tuple of the flax.linen.Module and its state dict from the model store.

Example:

import bentoml
import jax
import jax.numpy as jnp

net, state_dict = bentoml.flax.load_model("mnist:latest")
# JIT-compile the forward pass with the restored parameters
predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 28, 28, 1)))

bentoml.flax.get(tag_like: str | Tag) → bentoml.Model

Get the BentoML model with the given tag.

Parameters:

tag_like – The tag of the model to retrieve from the model store.

Returns:

A BentoML Model with the matching tag.

Return type:

Model

Example:

import bentoml

model = bentoml.flax.get("mnist:latest")