Source code for bentoml._internal.frameworks.diffusers

from __future__ import annotations

import logging
import os
import shutil
import typing as t
from typing import TYPE_CHECKING

import attr

import bentoml
from bentoml import Tag
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import MissingDependencyException
from bentoml.exceptions import NotFound
from bentoml.models import ModelContext

from ..models.model import PartialKwargsModelOptions
from .utils.transformers import extract_commit_hash

    from types import ModuleType

    from bentoml.types import ModelSignature
    from bentoml.types import ModelSignatureDict

    import diffusers
    import torch
    from diffusers.loaders import LoraLoaderMixin
    from diffusers.loaders import TextualInversionLoaderMixin
    from diffusers.utils.import_utils import is_accelerate_available
    from diffusers.utils.import_utils import is_torch_version
    from diffusers.utils.import_utils import is_xformers_available
except ImportError:  # pragma: no cover
    raise MissingDependencyException(
        "'diffusers' and 'transformers' is required in order to use module 'bentoml.diffusers', install diffusers and its dependencies with 'pip install --upgrade diffusers transformers accelerate'. For more information, refer to",

MODULE_NAME = "bentoml.diffusers"
# Sub-directory of a BentoML model directory that holds the diffusers pipeline.
DIFFUSION_MODEL_FOLDER = "diffusion_model"
# ``import_model`` only accepts a source directory containing this file.
DIFFUSION_MODEL_CONFIG_FILE = "model_index.json"
# Passed as ``api_version`` to ``bentoml.models.create`` by ``import_model`` and
# ``save_model``.
API_VERSION = "v1"

logger = logging.getLogger(__name__)

# A LoRA spec is either a string (a weight-file path, or a
# "<user>/<repo>/<weight file>" Hugging Face identifier) or a dict with a
# ``model_name`` key plus extra kwargs for ``pipeline.load_lora_weights``.
LoraOptionType = t.Union[str, t.Dict[str, str]]
# A textual-inversion spec is either a weight-file name, or a dict with a
# ``model_name`` key plus extra kwargs for ``pipeline.load_textual_inversion``.
TextualInversionOptionType = t.Union[str, t.Dict[str, str]]

@attr.define
class DiffusersOptions(PartialKwargsModelOptions):
    """Options for the diffusers model.

    Stored on a BentoML model (``import_model`` fills ``pipeline_class`` and
    ``variant``) and read by ``get_runnable`` to configure how the pipeline is
    loaded. Fields left as ``None`` fall back to the defaults chosen by
    ``get_runnable`` / ``load_model``.
    """

    # Pipeline class, or its dotted import path (resolved with ``_str2cls``);
    # ``get_runnable`` uses ``diffusers.DiffusionPipeline`` when unset.
    pipeline_class: str | type[diffusers.DiffusionPipeline] | None = None
    # Scheduler class, or its dotted import path, that replaces the loaded
    # pipeline's scheduler.
    scheduler_class: str | type[diffusers.SchedulerMixin] | None = None
    # dtype passed to ``from_pretrained``; ``get_runnable`` uses
    # ``torch.float16`` on CUDA when unset.
    torch_dtype: str | torch.dtype | None = None
    # Forwarded to ``from_pretrained``.
    device_map: str | dict[str, int | str | torch.device] | None = None
    # Forwarded to ``from_pretrained``.
    custom_pipeline: str | None = None
    # ``None`` lets ``get_runnable`` decide from CUDA / xformers availability.
    enable_xformers: bool | None = None
    # Argument for ``pipeline.enable_attention_slicing``; ``None`` skips it.
    enable_attention_slicing: int | str | None = None
    # Call ``pipeline.enable_model_cpu_offload()`` after loading.
    enable_model_cpu_offload: bool | None = None
    # Call ``pipeline.enable_sequential_cpu_offload()`` after loading.
    enable_sequential_cpu_offload: bool | None = None
    # Wrap ``pipeline.unet`` with ``torch.compile``.
    enable_torch_compile: bool | None = None
    # ``None`` lets ``load_model`` choose based on torch / accelerate.
    low_cpu_mem_usage: bool | None = None
    # Weight variant forwarded to ``from_pretrained`` (e.g. "fp16").
    variant: str | None = None
    # Extra keyword arguments forwarded to ``from_pretrained``.
    load_pretrained_extra_kwargs: dict[str, t.Any] | None = None
    # Base directory for relative LoRA paths passed at request time.
    lora_dir: str | None = None
    # LoRA weights loaded when the runnable starts.
    lora_weights: LoraOptionType | list[LoraOptionType] | None = None
    # Textual inversions loaded when the runnable starts.
    textual_inversions: (
        TextualInversionOptionType | list[TextualInversionOptionType] | None
    ) = None

def _prepare_lora_args(
    raw_arg: LoraOptionType, lora_dir: str | None = None
) -> tuple[str, dict[str, str]]:
    """Translate a LoRA spec into ``(model_name, kwargs)`` for ``load_lora_weights``.

    A string spec is resolved in this order:

    1. an existing weight file, either an absolute path or a path relative
       to ``lora_dir`` (``~`` is expanded in both);
    2. otherwise a Hugging Face identifier ``"<user>/<repo>/<weight path>"``.

    A dict spec must contain ``model_name``. Its remaining items are returned
    as the kwargs. The caller's dict is left unmodified.

    Args:
        raw_arg: The LoRA spec, as a string or a dict.
        lora_dir: Base directory for relative paths. Defaults to the current
            working directory.

    Returns:
        ``(model_name, kwargs)``. For file paths ``model_name`` is the file's
        directory, and ``kwargs`` holds its ``weight_name``.

    Raises:
        ValueError: A string spec is neither an existing file nor a
            ``user/repo/weight`` identifier.
        KeyError: A dict spec has no ``model_name`` key.
    """
    if lora_dir is None:
        lora_dir = os.getcwd()
    lora_dir = os.path.expanduser(lora_dir)

    if isinstance(raw_arg, str):
        raw_arg = raw_arg.strip()
        arg_path = os.path.expanduser(raw_arg)

        # os.path.join returns ``arg_path`` unchanged when it is absolute, so
        # this single check covers both the absolute and relative path cases.
        candidate = os.path.join(lora_dir, arg_path)
        if os.path.exists(candidate):
            return (
                os.path.dirname(candidate),
                {"weight_name": os.path.basename(candidate)},
            )

        # Hugging Face repo id case: "<user>/<repo>/<weight file path>".
        parts = raw_arg.split("/")
        if len(parts) <= 2:
            raise ValueError(f"{raw_arg} is not a valid huggingface LoRA path")
        return ("/".join(parts[:2]), {"weight_name": "/".join(parts[2:])})

    # Work on a copy. ``get_runnable`` passes the same option dict on every
    # runnable start, so popping from the original would break the next load.
    kwargs = dict(raw_arg)
    model_name = kwargs.pop("model_name")
    return (model_name, kwargs)

def _load_lora_weights_to_pipeline(
    pipeline: diffusers.DiffusionPipeline,
    lora_weights: LoraOptionType | list[LoraOptionType],
    lora_dir: str | None = None,
) -> None:
    """Load a LoRA adapter into ``pipeline`` via ``pipeline.load_lora_weights``.

    Args:
        pipeline: A pipeline supporting ``load_lora_weights``.
        lora_weights: A single LoRA spec or a list of specs (see
            ``_prepare_lora_args``). Only the first spec in a list is loaded.
            The rest are discarded with a warning. An empty list is a no-op.
        lora_dir: Base directory for resolving relative weight paths.
    """
    if not isinstance(lora_weights, list):
        lora_weights = [lora_weights]

    if not lora_weights:
        # Nothing to load (e.g. a request passed ``lora_weights=[]``).
        return

    if len(lora_weights) > 1:
        logger.warning(
            "Currently diffusers only support single lora weight loading. The first lora weight will be loaded and the rest will be discarded"
        )

    model_name, kwargs = _prepare_lora_args(lora_weights[0], lora_dir=lora_dir)
    pipeline.load_lora_weights(model_name, **kwargs)

def _prepare_textual_inversion_args(
    raw_arg: TextualInversionOptionType,
) -> tuple[str, dict[str, str]]:
    """Translate a textual-inversion spec into ``(model_name, kwargs)``.

    The result is meant for ``pipeline.load_textual_inversion(model_name, **kwargs)``.

    Args:
        raw_arg: Either a weight-file name, or a dict with a ``model_name`` key
            (a Hugging Face repository or local directory). Other dict items
            are returned as kwargs. The caller's dict is left unmodified.

    Returns:
        ``(model_name, kwargs)``. For a string spec this is
        ``(".", {"weight_name": raw_arg})``.

    Raises:
        KeyError: A dict spec has no ``model_name`` key.
    """
    if isinstance(raw_arg, str):
        # A bare string names a weight file. "." is a relative directory,
        # presumably resolved against the current working directory by
        # diffusers (TODO confirm).
        return (".", {"weight_name": raw_arg})

    # Work on a copy. The same option dict is reused each time a runnable
    # starts, so popping from the original would break the next load.
    kwargs = dict(raw_arg)
    model_name = kwargs.pop("model_name")
    return (model_name, kwargs)

def _str2cls(
    full_cls_str: str,
) -> type[diffusers.DiffusionPipeline | diffusers.SchedulerMixin]:
    """Resolve a dotted ``"package.module.ClassName"`` string to the object it names.

    A bare name with no module part (e.g. ``"StableDiffusionPipeline"``) is
    looked up on the top-level ``diffusers`` package.

    Raises:
        ModuleNotFoundError: The module part cannot be imported.
        AttributeError: The module has no attribute with that name.
    """
    import importlib

    owner_path, _dot, cls_name = full_cls_str.rpartition(".")
    owner = importlib.import_module(owner_path or "diffusers")
    return getattr(owner, cls_name)

def get(tag_like: str | Tag) -> bentoml.Model:
    """
    Get the BentoML model with the given tag.

    Args:
        tag_like: The tag of the model to retrieve from the model store.

    Returns:
        :obj:`~bentoml.Model`: A BentoML :obj:`~bentoml.Model` with the matching tag.

    Raises:
        NotFound: No model has the tag, or the model was saved by a different
            framework module than ``bentoml.diffusers``.

    Example:

    .. code-block:: python

       import bentoml
       # target model must be from the BentoML model store
       model = bentoml.diffusers.get("my_stable_diffusion_model")
    """
    model = bentoml.models.get(tag_like)
    if model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {model.tag} was saved with module {model.info.module}, not loading with {MODULE_NAME}."
        )
    return model
def load_model(
    bento_model: str | Tag | bentoml.Model,
    device_id: str | torch.device | None = None,
    pipeline_class: str | type[diffusers.pipelines.DiffusionPipeline] = diffusers.DiffusionPipeline,
    device_map: str | dict[str, int | str | torch.device] | None = None,
    custom_pipeline: str | None = None,
    scheduler_class: str | type[diffusers.SchedulerMixin] | None = None,
    torch_dtype: str | torch.dtype | None = None,
    low_cpu_mem_usage: bool | None = None,
    enable_xformers: bool = False,
    enable_attention_slicing: int | str | None = None,
    enable_model_cpu_offload: bool | None = None,
    enable_sequential_cpu_offload: bool | None = None,
    enable_torch_compile: bool | None = None,
    variant: str | None = None,
    lora_weights: LoraOptionType | list[LoraOptionType] | None = None,
    textual_inversions: TextualInversionOptionType | list[TextualInversionOptionType] | None = None,
    load_pretrained_extra_kwargs: dict[str, t.Any] | None = None,
) -> diffusers.DiffusionPipeline:
    """
    Load a Diffusion model from the local BentoML model store as a diffusers pipeline.

    Args:
        bento_model: Either the tag of the model to get from the store, or a BentoML
            :obj:`~bentoml.Model` instance to load the model from.
        device_id (:code:`str | torch.device`, `optional`, default to :code:`None`):
            Device to move the pipeline to with ``pipeline.to(device_id)``. For CUDA
            devices the move is skipped when ``device_map`` or either CPU-offload
            option is set.
        pipeline_class (:code:`str | type[diffusers.DiffusionPipeline]`, `optional`):
            Pipeline class, or its dotted import path, used to load the saved model.
            Defaults to ``diffusers.DiffusionPipeline``.
        device_map (:code:`None | str | Dict[str, Union[int, str, torch.device]]`, `optional`):
            Forwarded to ``from_pretrained``. It specifies where each submodule should go.
        custom_pipeline (:code:`None | str`, `optional`):
            Identifier of a custom (e.g. community-maintained) pipeline, forwarded to
            ``from_pretrained``.
        scheduler_class (:code:`str | type[diffusers.SchedulerMixin]`, `optional`):
            Scheduler class, or its dotted import path. It replaces the pipeline's
            scheduler and is built from the loaded scheduler's config.
        torch_dtype (:code:`str | torch.dtype`, `optional`):
            Override the default ``torch.dtype`` and load the model under this dtype.
        low_cpu_mem_usage (:code:`bool`, `optional`):
            Forwarded to ``from_pretrained``. When ``None``, it defaults to ``True`` if
            torch >= 1.9.0 and ``accelerate`` is available, otherwise ``False``.
        enable_xformers (:code:`bool`, `optional`):
            Call ``pipeline.enable_xformers_memory_efficient_attention()``.
        enable_attention_slicing (:code:`int | str`, `optional`):
            If not ``None``, passed to ``pipeline.enable_attention_slicing``.
        enable_model_cpu_offload (:code:`bool`, `optional`):
            Call ``pipeline.enable_model_cpu_offload()``.
        enable_sequential_cpu_offload (:code:`bool`, `optional`):
            Call ``pipeline.enable_sequential_cpu_offload()``.
        enable_torch_compile (:code:`bool`, `optional`):
            Compile ``pipeline.unet`` with
            ``torch.compile(mode="reduce-overhead", fullgraph=True)``.
        variant (:code:`str`, *optional*):
            If specified, load weights from the ``variant`` filename, *e.g.*
            pytorch_model.<variant>.bin.
        lora_weights (:code:`LoraOptionType | list[LoraOptionType]`, *optional*):
            LoRA weights to load (only the first entry of a list is used). A
            :code:`LoraOptionType` is either a string or a dictionary:

            * a string is a path to the weight file, or a
              ``"<user>/<repo>/<weight file>"`` Hugging Face identifier;
            * a dictionary contains a key :code:`"model_name"` (a Hugging Face
              repository or a local directory), a key :code:`"weight_name"` (the
              weight file), and any other keys for the pipeline's
              :code:`load_lora_weights` method.

            Relative paths are resolved against the current working directory.
        textual_inversions (:code:`TextualInversionOptionType | list[TextualInversionOptionType]`, *optional*):
            Textual inversions to load. A :code:`TextualInversionOptionType` is
            either a string or a dictionary:

            * a string is the name of the weight file;
            * a dictionary contains a key :code:`"model_name"` (a Hugging Face
              repository or a local directory), a key :code:`"weight_name"` (the
              weight file), and any other keys for the pipeline's
              :code:`load_textual_inversion` method.
        load_pretrained_extra_kwargs (:code:`dict[str, t.Any]`, *optional*):
            Extra kwargs passed to the pipeline class's :code:`from_pretrained` method.

    Returns:
        The Diffusion model loaded as a diffusers pipeline from the BentoML model store.

    Raises:
        NotFound: The model was not saved with ``bentoml.diffusers``.
        NotImplementedError: LoRA weights or textual inversions were requested,
            but ``pipeline_class`` does not support loading them.

    Example:

    .. code-block:: python

        import bentoml
        pipeline = bentoml.diffusers.load_model('my_diffusers_model:latest')
        pipeline(prompt)
    """  # noqa
    if not isinstance(bento_model, bentoml.Model):
        bento_model = get(bento_model)

    # ``get`` validates the module for tags; a ``bentoml.Model`` passed in
    # directly is checked here.
    if bento_model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {bento_model.tag} was saved with module {bento_model.info.module}, not loading with {MODULE_NAME}."
        )

    if isinstance(pipeline_class, str):
        pipeline_class = t.cast(
            type[diffusers.DiffusionPipeline], _str2cls(pipeline_class)
        )

    # Accept a dotted import path for the scheduler too, as DiffusersOptions does.
    if isinstance(scheduler_class, str):
        scheduler_class = t.cast(
            type[diffusers.SchedulerMixin], _str2cls(scheduler_class)
        )

    if lora_weights and not issubclass(pipeline_class, LoraLoaderMixin):
        raise NotImplementedError(
            f"Class {pipeline_class} is not a subclass of LoraLoaderMixin, cannot load lora weights"
        )

    if textual_inversions and not issubclass(
        pipeline_class, TextualInversionLoaderMixin
    ):
        raise NotImplementedError(
            f"Class {pipeline_class} is not a subclass of TextualInversionLoaderMixin, cannot load textual inversions"
        )

    diffusion_model_dir = bento_model.path_of(DIFFUSION_MODEL_FOLDER)

    if low_cpu_mem_usage is None:
        low_cpu_mem_usage = bool(
            is_torch_version(">=", "1.9.0") and is_accelerate_available()
        )

    pipeline: diffusers.DiffusionPipeline = pipeline_class.from_pretrained(
        diffusion_model_dir,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device_map=device_map,
        custom_pipeline=custom_pipeline,
        variant=variant,
        **(load_pretrained_extra_kwargs or {}),
    )

    if scheduler_class:
        # Rebuild the requested scheduler from the loaded scheduler's config.
        scheduler: diffusers.SchedulerMixin = scheduler_class.from_config(
            pipeline.scheduler.config
        )
        pipeline.scheduler = scheduler

    if device_id is not None:
        move_model_to_device = True
        if str(device_id).lower().startswith("cuda"):
            # A pipeline loaded with ``device_map`` must not be moved to the GPU
            # again. The CPU-offload modes enabled below presumably manage
            # device placement themselves, so they skip the explicit move too.
            if (
                device_map is not None
                or enable_sequential_cpu_offload
                or enable_model_cpu_offload
            ):
                move_model_to_device = False

        if move_model_to_device:
            pipeline = pipeline.to(device_id)

    if enable_xformers:
        pipeline.enable_xformers_memory_efficient_attention()

    if enable_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()

    if enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()

    if enable_attention_slicing is not None:
        pipeline.enable_attention_slicing(enable_attention_slicing)

    if enable_torch_compile:
        logger.info("Run torch compile on unet")
        pipeline.unet = torch.compile(
            pipeline.unet, mode="reduce-overhead", fullgraph=True
        )

    if lora_weights:
        _load_lora_weights_to_pipeline(pipeline, lora_weights)

    if textual_inversions:
        if not isinstance(textual_inversions, list):
            textual_inversions = [textual_inversions]
        for textual_inversion in textual_inversions:
            model_name, kwargs = _prepare_textual_inversion_args(textual_inversion)
            pipeline.load_textual_inversion(model_name, **kwargs)

    return pipeline
def import_model(
    name: Tag | str,
    model_name_or_path: str | os.PathLike[str],
    *,
    proxies: dict[str, str] | None = None,
    revision: str = "main",
    variant: str | None = None,
    pipeline_class: str | type[diffusers.DiffusionPipeline] | None = None,
    sync_with_hub_version: bool = False,
    signatures: dict[str, ModelSignatureDict | ModelSignature] | None = None,
    labels: dict[str, str] | None = None,
    custom_objects: dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Import a Diffusion model from an artifact URI into the BentoML model store.

    If a model with the given tag already exists in the store, it is returned
    unchanged and nothing is downloaded.

    Args:
        name: The name to give to the model in the BentoML store. This must be a valid
            :obj:`~bentoml.Tag` name.
        model_name_or_path: Can be either:

            - A string, the *repo id* of a pretrained pipeline hosted inside a model
              repo on the Hugging Face Hub. Valid repo ids have to be located under a
              user or organization name, like `CompVis/ldm-text2im-large-256`.
            - A path to a *directory* containing pipeline weights saved using
              [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g.,
            `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies
            are used on each request.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or
            a commit id. The Hub uses a git-based system to store models and other
            artifacts, so `revision` can be any identifier allowed by git.
        variant (`str`, *optional*):
            Variant of the model to import, for example "fp16" or "fp32" for
            "DeepFloyd/IF-I-XL-v1.0". This may save download bandwidth and local disk
            space. It only affects the download when ``pipeline_class`` is given.
        pipeline_class (`str | type[diffusers.DiffusionPipeline]`, *optional*):
            Pipeline class, or its dotted import path. It is used to download the
            model and is recorded in the model's options for loading.
        sync_with_hub_version (`bool`, default to False):
            If true, the model's version is set to the Hugging Face Hub commit hash
            of the downloaded files. ``-<variant>`` is appended when
            ``pipeline_class`` and ``variant`` are given. This cannot be used with a
            local directory.
        signatures: Signatures of predict methods to be used. If not provided, the
            signatures default to {"__call__": {"batchable": False}}. See
            :obj:`~bentoml.types.ModelSignature` for more details.
        labels: A default set of management labels to be associated with the model.
            For example: ``{"training-set": "data-v1"}``.
        custom_objects: Custom objects to be saved with the model. An example is
            ``{"my-normalizer": normalizer}``. Custom objects are serialized with
            cloudpickle.
        external_modules: Extra Python modules to be saved with the model.
        metadata: Metadata to be associated with the model. An example is
            ``{"param_a": .2}``. Metadata is intended for display in a model
            management UI. All values in the metadata dictionary must therefore be a
            primitive Python type, such as ``str`` or ``int``.

    Returns:
        A :obj:`~bentoml.Model` instance referencing a saved model in the local
        BentoML model store.

    Raises:
        BentoMLException: ``sync_with_hub_version`` is used with a local directory,
            or the source is not a diffusers pipeline (it has no ``model_index.json``).

    Example:

    .. code-block:: python

        import bentoml

        bentoml.diffusers.import_model(
            'my_sd15_model',
            "runwayml/stable-diffusion-v1-5",
            signatures={
                "__call__": {"batchable": False},
            }
        )
    """
    tag = Tag.from_taglike(name)

    try:
        return bentoml.models.get(tag)
    except bentoml.exceptions.NotFound:
        pass

    if sync_with_hub_version and tag.version is not None:
        logger.warning(
            f"sync_with_hub_version is True, user provided version {tag.version} may be overridden by huggingface hub's commit hash"
        )

    context = ModelContext(
        framework_name="diffusers",
        framework_versions={"diffusers": diffusers.__version__},
    )

    if signatures is None:
        signatures = {
            "__call__": {"batchable": False},
        }
        logger.info(
            'Using the default model signature for diffusers (%s) for model "%s".',
            signatures,
            name,
        )

    if pipeline_class and isinstance(pipeline_class, str):
        pipeline_class = t.cast(
            type[diffusers.DiffusionPipeline], _str2cls(pipeline_class)
        )

    options_dict: dict[str, str] = {}
    if pipeline_class:
        options_dict["pipeline_class"] = (
            f"{pipeline_class.__module__}.{pipeline_class.__name__}"
        )
    if variant:
        options_dict["variant"] = variant
    options = DiffusersOptions(**options_dict) if options_dict else None

    if os.path.isdir(model_name_or_path):
        if sync_with_hub_version:
            raise BentoMLException(
                "Cannot sync version with huggingface hub when importing a local model"
            )
        src_dir = model_name_or_path
    else:
        if pipeline_class:
            # Only the pipeline-class download honours ``variant``.
            src_dir = pipeline_class.download(
                model_name_or_path, proxies=proxies, revision=revision, variant=variant
            )
        else:
            from huggingface_hub import snapshot_download

            src_dir = snapshot_download(
                model_name_or_path,
                proxies=proxies,
                revision=revision,
            )

        if sync_with_hub_version:
            from huggingface_hub.file_download import REGEX_COMMIT_HASH

            version = extract_commit_hash(src_dir, REGEX_COMMIT_HASH)
            if version is not None:
                if pipeline_class and variant is not None:
                    version = version + "-" + variant
                tag.version = version

    # Validate before creating a store entry rather than inside it.
    model_config_file = os.path.join(src_dir, DIFFUSION_MODEL_CONFIG_FILE)
    if not os.path.exists(model_config_file):
        raise BentoMLException(f'artifact "{src_dir}" is not a Diffusion model')

    with bentoml.models.create(
        tag,
        module=MODULE_NAME,
        api_version=API_VERSION,
        signatures=signatures,
        labels=labels,
        options=options,
        custom_objects=custom_objects,
        external_modules=external_modules,
        metadata=metadata,
        context=context,
    ) as bento_model:
        diffusion_model_dir = bento_model.path_of(DIFFUSION_MODEL_FOLDER)
        # symlinks=False copies the link targets rather than the links, so the
        # stored model does not depend on the source directory (the download
        # cache or a local path).
        shutil.copytree(
            src_dir,
            diffusion_model_dir,
            symlinks=False,
            ignore=shutil.ignore_patterns(".git"),
        )

        return bento_model
def save_model(
    name: Tag | str,
    pipeline: diffusers.DiffusionPipeline,
    *,
    signatures: dict[str, ModelSignatureDict | ModelSignature] | None = None,
    labels: dict[str, str] | None = None,
    custom_objects: dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """
    Save a DiffusionPipeline to the BentoML model store.

    Args:
        name: The name to give to the model in the BentoML store. This must be a valid
            :obj:`~bentoml.Tag` name.
        pipeline: Instance of the Diffusers pipeline to be saved.
        signatures: Signatures of predict methods to be used. If not provided, the
            signatures default to {"__call__": {"batchable": False}}. See
            :obj:`~bentoml.types.ModelSignature` for more details.
        labels: A default set of management labels to be associated with the model.
            For example: ``{"training-set": "data-v1"}``.
        custom_objects: Custom objects to be saved with the model. An example is
            ``{"my-normalizer": normalizer}``. Custom objects are serialized with
            cloudpickle.
        external_modules: Extra Python modules to be saved with the model.
        metadata: Metadata to be associated with the model. An example is
            ``{"param_a": .2}``. Metadata is intended for display in a model
            management UI. All values in the metadata dictionary must therefore be a
            primitive Python type, such as ``str`` or ``int``.

    Returns:
        A :obj:`~bentoml.Model` instance referencing a saved model in the local
        BentoML model store.

    Raises:
        BentoMLException: ``pipeline`` is not a ``diffusers.DiffusionPipeline``.
    """
    if not isinstance(pipeline, diffusers.DiffusionPipeline):
        raise BentoMLException(
            "'pipeline' must be an instance of 'diffusers.DiffusionPipeline'."
        )

    context = ModelContext(
        framework_name="diffusers",
        framework_versions={"diffusers": diffusers.__version__},
    )

    if signatures is None:
        signatures = {
            "__call__": {"batchable": False},
        }
        logger.info(
            'Using the default model signature for diffusers (%s) for model "%s".',
            signatures,
            name,
        )

    with bentoml.models.create(
        name,
        module=MODULE_NAME,
        api_version=API_VERSION,
        signatures=signatures,
        labels=labels,
        options=None,
        custom_objects=custom_objects,
        external_modules=external_modules,
        metadata=metadata,
        context=context,
    ) as bento_model:
        diffusion_model_dir = bento_model.path_of(DIFFUSION_MODEL_FOLDER)
        pipeline.save_pretrained(diffusion_model_dir)

        return bento_model
def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
    """
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.

    Build a :class:`bentoml.Runnable` subclass for ``bento_model``. On start, it
    loads the pipeline with :func:`load_model`, configured from the model's
    :class:`DiffusersOptions`. It also adds one runnable method per model
    signature, each forwarding to the pipeline method of the same name.

    Raises:
        NotImplementedError: The options request LoRA weights or textual
            inversions that the configured pipeline class cannot load.
    """
    bento_options = t.cast(DiffusersOptions, bento_model.info.options)
    partial_kwargs: dict[str, t.Any] = bento_options.partial_kwargs  # type: ignore

    pipeline_class: str | type[diffusers.DiffusionPipeline] = (
        bento_options.pipeline_class or diffusers.DiffusionPipeline
    )
    if isinstance(pipeline_class, str):
        pipeline_class = t.cast(
            type[diffusers.DiffusionPipeline], _str2cls(pipeline_class)
        )

    scheduler_class: str | type[diffusers.SchedulerMixin] | None = (
        bento_options.scheduler_class
    )
    if scheduler_class and isinstance(scheduler_class, str):
        scheduler_class = t.cast(
            type[diffusers.SchedulerMixin], _str2cls(scheduler_class)
        )

    custom_pipeline: str | None = bento_options.custom_pipeline
    _enable_xformers: bool | None = bento_options.enable_xformers
    enable_attention_slicing: int | str | None = bento_options.enable_attention_slicing
    enable_sequential_cpu_offload: bool | None = (
        bento_options.enable_sequential_cpu_offload
    )
    enable_model_cpu_offload: bool | None = bento_options.enable_model_cpu_offload
    enable_torch_compile: bool | None = bento_options.enable_torch_compile
    low_cpu_mem_usage: bool | None = bento_options.low_cpu_mem_usage
    variant: str | None = bento_options.variant
    _torch_dtype: str | torch.dtype | None = bento_options.torch_dtype
    device_map: str | dict[str, int | str | torch.device] | None = (
        bento_options.device_map
    )
    load_pretrained_extra_kwargs = bento_options.load_pretrained_extra_kwargs
    lora_dir = bento_options.lora_dir
    lora_weights = bento_options.lora_weights
    textual_inversions = bento_options.textual_inversions

    support_lora = issubclass(pipeline_class, LoraLoaderMixin)
    support_textual_inversion = issubclass(pipeline_class, TextualInversionLoaderMixin)

    if not support_lora and lora_weights:
        raise NotImplementedError(
            f"Class {pipeline_class} is not a subclass of LoraLoaderMixin, cannot load lora weights. "
            "Try using `bento_model.with_options(pipeline_class=diffusers.StableDiffusionPipeline)` to specify the pipeline's class"
        )

    if not support_textual_inversion and textual_inversions:
        raise NotImplementedError(
            f"Class {pipeline_class} is not a subclass of TextualInversionLoaderMixin, cannot load textual inversions. "
            "Try using `bento_model.with_options(pipeline_class=diffusers.StableDiffusionPipeline)` to specify the pipeline's class"
        )

    class DiffusersRunnable(bentoml.Runnable):
        SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
        SUPPORTS_CPU_MULTI_THREADING = True

        def __init__(self):
            super().__init__()

            cuda_available = torch.cuda.is_available()

            # Default to half precision on GPU unless a dtype was configured.
            if cuda_available and _torch_dtype is None:
                torch_dtype = torch.float16
            else:
                torch_dtype = _torch_dtype

            # An explicit enable_xformers option wins. When it is unset, enable
            # xformers on CUDA only if it is installed.
            if _enable_xformers is None:
                enable_xformers = cuda_available and is_xformers_available()
            else:
                enable_xformers = bool(_enable_xformers)

            device_id: str | None = "cuda" if cuda_available else None

            self.lora_dir = lora_dir
            # NOTE(review): ``lora_dir`` is not forwarded to load_model. Startup
            # ``lora_weights`` given as relative paths therefore resolve against
            # the current working directory; only request-time LoRA loads use
            # ``self.lora_dir``.
            self.pipeline: diffusers.DiffusionPipeline = load_model(
                bento_model,
                device_id=device_id,
                device_map=device_map,
                pipeline_class=pipeline_class,
                scheduler_class=scheduler_class,
                torch_dtype=torch_dtype,
                custom_pipeline=custom_pipeline,
                enable_xformers=enable_xformers,
                enable_attention_slicing=enable_attention_slicing,
                enable_sequential_cpu_offload=enable_sequential_cpu_offload,
                enable_model_cpu_offload=enable_model_cpu_offload,
                enable_torch_compile=enable_torch_compile,
                low_cpu_mem_usage=low_cpu_mem_usage,
                variant=variant,
                lora_weights=lora_weights,
                textual_inversions=textual_inversions,
                load_pretrained_extra_kwargs=load_pretrained_extra_kwargs,
            )

        @bentoml.Runnable.method(batchable=False)
        def _replace_scheduler(self, scheduler_txt: str):
            """Swap the pipeline's scheduler for the class named by ``scheduler_txt``.

            Returns a dict with ``success`` and, on failure, ``error_message``.
            """
            try:
                scheduler_cls = _str2cls(scheduler_txt)
                if isinstance(self.pipeline.scheduler, scheduler_cls):
                    return dict(success=True)
                if scheduler_cls in self.pipeline.scheduler.compatibles:
                    self.pipeline.scheduler = scheduler_cls.from_config(
                        self.pipeline.scheduler.config,
                    )
                    return dict(success=True)
                return dict(
                    success=False,
                    error_message="scheduler class is incompatible to this pipeline",
                )
            except (ModuleNotFoundError, ValueError, AttributeError):
                logger.warning("Cannot import %s", scheduler_txt)
                return dict(
                    success=False,
                    error_message="cannot import scheduler class",
                )

    if support_lora:

        def _load_lora_weights(
            self: DiffusersRunnable,
            lora_weights: LoraOptionType | list[LoraOptionType],
        ):
            _load_lora_weights_to_pipeline(self.pipeline, lora_weights, self.lora_dir)

        def _unload_lora_weights(
            self: DiffusersRunnable,
        ):
            self.pipeline.unload_lora_weights()
            # clear cached lora weights from GPU memory
            torch.cuda.empty_cache()

    else:

        def _load_lora_weights(
            self: DiffusersRunnable,
            lora_args: LoraOptionType | list[LoraOptionType],
        ):
            raise NotImplementedError(
                f"Class {pipeline_class} is not a subclass of LoraLoaderMixin, cannot load lora weights"
            )

        def _unload_lora_weights(
            self: DiffusersRunnable,
        ):
            raise NotImplementedError(
                f"Class {pipeline_class} is not a subclass of LoraLoaderMixin, cannot unload lora weights"
            )

    setattr(DiffusersRunnable, "_load_lora_weights", _load_lora_weights)
    setattr(DiffusersRunnable, "_unload_lora_weights", _unload_lora_weights)

    def make_run_method(
        method_name: str, method_partial_kwargs: dict[str, t.Any] | None
    ) -> t.Callable[..., t.Any]:
        """Create a runnable method forwarding to ``pipeline.<method_name>``.

        ``method_partial_kwargs`` are default kwargs merged under the call's
        kwargs. It is a parameter so that each generated method closes over its
        own value. Reading the enclosing loop variable instead would late-bind
        every method to the last signature's kwargs.
        """

        def _to_serializable(res: t.Any) -> t.Any:
            # handle BaseOutput cannot be serialized yet
            if isinstance(res, diffusers.utils.BaseOutput):
                return res.to_tuple()
            return res

        if support_lora:

            def _run_method(
                runnable_self: DiffusersRunnable,
                *args: t.Any,
                **kwargs: t.Any,
            ) -> t.Any:
                if method_partial_kwargs is not None:
                    kwargs = dict(method_partial_kwargs, **kwargs)

                # A request may carry LoRA weights that apply only to this call.
                request_lora = kwargs.pop("lora_weights", None)
                try:
                    if request_lora is not None:
                        runnable_self._load_lora_weights(request_lora)
                    raw_method = getattr(runnable_self.pipeline, method_name)
                    res = raw_method(*args, **kwargs)
                finally:
                    torch.cuda.empty_cache()
                    if request_lora is not None:
                        runnable_self._unload_lora_weights()

                return _to_serializable(res)

        else:

            def _run_method(
                runnable_self: DiffusersRunnable,
                *args: t.Any,
                **kwargs: t.Any,
            ) -> t.Any:
                if method_partial_kwargs is not None:
                    kwargs = dict(method_partial_kwargs, **kwargs)

                raw_method = getattr(runnable_self.pipeline, method_name)
                return _to_serializable(raw_method(*args, **kwargs))

        return _run_method

    for method_name, options in bento_model.info.signatures.items():
        signature_partial_kwargs = partial_kwargs.get(method_name)
        DiffusersRunnable.add_method(
            make_run_method(method_name, signature_partial_kwargs),
            name=method_name,
            batchable=options.batchable,
            batch_dim=options.batch_dim,
            input_spec=options.input_spec,
            output_spec=options.output_spec,
        )

    return DiffusersRunnable