PyTorch Lightning

Here’s a simple example of using PyTorch Lightning with BentoML:

import bentoml
import torch
import pytorch_lightning as pl

class AdditionModel(pl.LightningModule):
    """Toy LightningModule whose forward pass adds 1 to its input."""

    def forward(self, inputs):
        # Element-wise increment; returns a new object rather than mutating `inputs`.
        incremented = inputs.add(1)
        return incremented

# `save` a given model and retrieve the corresponding tag:
tag = bentoml.pytorch_lightning.save_model("addition_model", AdditionModel())

# retrieve metadata with `bentoml.models.get`:
metadata = bentoml.models.get(tag)

# `load` the model back in memory:
model = bentoml.pytorch_lightning.load_model("addition_model:latest")

# Run a given model under `Runner` abstraction with `to_runner`:
runner = bentoml.pytorch_lightning.get(tag).to_runner()
runner.init_local()
# Build the input tensor directly with torch, so no NumPy import is needed.
runner.run(torch.tensor([[1, 2, 3, 4]]))

Note

You can find more examples for PyTorch Lightning in our `bentoml/examples <https://github.com/bentoml/BentoML/tree/main/examples>`_ directory.

bentoml.pytorch_lightning.save_model(name: str, model: pl.LightningModule, *, signatures: ModelSignaturesType | None = None, labels: t.Dict[str, str] | None = None, custom_objects: t.Dict[str, t.Any] | None = None, external_modules: t.List[ModuleType] | None = None, metadata: t.Dict[str, t.Any] | None = None) bentoml.Model[source]#

Save a model instance to the BentoML model store.

Parameters
  • name (str) โ€“ Name for given model instance. This should pass Python identifier check.

  • model (pl.LightningModule) โ€“ Instance of model to be saved

  • labels (Dict[str, str], optional, default to None) โ€“ user-defined labels for managing models, e.g. team=nlp, stage=dev

  • custom_objects (Dict[str, Any]], optional, default to None) โ€“ user-defined additional python objects to be saved alongside the model, e.g. a tokenizer instance, preprocessor function, model configuration json

  • external_modules (List[ModuleType], optional, default to None) โ€“ user-defined additional python modules to be saved alongside the model or custom objects, e.g. a tokenizer module, preprocessor module, model configuration module

  • metadata (Dict[str, Any], optional, default to None) โ€“ Custom metadata for given model.

  • model_store (ModelStore, default to BentoMLContainer.model_store) โ€“ BentoML modelstore, provided by DI Container.

Returns

A bentoml.Model referencing the saved model. Its tag has the format name:version, where name is the user-defined model’s name and version is generated by BentoML.

Return type

bentoml.Model

Examples:

import bentoml
import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    """Example LightningModule: a two-layer MLP over flattened 28 * 28 inputs.

    The final linear layer has 10 outputs, which are scored with cross-entropy
    against the batch targets in the train/validation/test steps.

    Args:
        hidden_dim: Width of the hidden linear layer.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        # Records the constructor arguments so they are readable via `self.hparams`.
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        # Flatten every sample in the batch to a 28 * 28 = 784-feature vector.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one training batch of (inputs, targets)."""
        x, y = batch
        y_hat = self(x)
        # Fully qualified because this example only imports `torch`, not `F`.
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the cross-entropy loss of one validation batch as `valid_loss`."""
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the cross-entropy loss of one test batch as `test_loss`."""
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Save the classifier to the BentoML model store via the documented `save_model` API.
tag = bentoml.pytorch_lightning.save_model("lit_classifier", LitClassifier())
bentoml.pytorch_lightning.load_model(bentoml_model: str | bentoml._internal.tag.Tag | bentoml._internal.models.model.Model, device_id: Optional[str] = 'cpu') torch.ScriptModule[source]#

Load a model from the BentoML local model store by the given tag.

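Parameters
  • bentoml_model (str | Tag | Model) – Tag (e.g. "lit_classifier:latest") or Model object identifying a model saved in the BentoML local model store.

  • device_id (str, optional, default to 'cpu') – Device to load the model onto, e.g. "cpu" or "cuda:0".

Returns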

an instance of torch.ScriptModule from BentoML modelstore.

Return type

torch.ScriptModule

Examples:

import bentoml
# Load the saved Lightning model onto the first CUDA device.
lit = bentoml.pytorch_lightning.load_model('lit_classifier:latest', device_id="cuda:0")
bentoml.pytorch_lightning.get(tag_like: str | bentoml._internal.tag.Tag) Model[source]#