9 changes: 7 additions & 2 deletions src/diffusers/models/attention_dispatch.py
@@ -235,6 +235,10 @@ def decorator(func):
    def get_active_backend(cls):
        return cls._active_backend, cls._backends[cls._active_backend]

+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
    @classmethod
    def list_backends(cls):
        return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
    _maybe_download_kernel_for_backend(backend)

    old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)

    try:
        yield
    finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)

def dispatch_attention_fn(
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
            check(**kwargs)

    kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
    return backend_fn(**kwargs)


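For reference, a minimal usage sketch of the `attention_backend` context manager after this change. This is not part of the diff; the "flash" backend name and the checkpoint id are illustrative assumptions.

# Hedged sketch: the context manager now routes through set_active_backend(),
# and the previously active backend is restored in the finally block on exit.
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import attention_backend

# "some/checkpoint" is a placeholder, not a real model id.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16).to("cuda")

with attention_backend("flash"):  # calls set_active_backend("flash") on entry
    image = pipe("a photo of a cat").images[0]
# On exit, set_active_backend(old_backend) restores whatever was active before.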
25 changes: 23 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -599,6 +599,7 @@ def set_attention_backend(self, backend: str) -> None:
        from .attention import AttentionModuleMixin
        from .attention_dispatch import (
            AttentionBackendName,
+            _AttentionBackendRegistry,
            _check_attention_backend_requirements,
            _maybe_download_kernel_for_backend,
        )
@@ -607,17 +608,34 @@ def set_attention_backend(self, backend: str) -> None:
        from .attention_processor import Attention, MochiAttention

        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        parallel_config_set = False
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if getattr(processor, "_parallel_config", None) is not None:
+                parallel_config_set = True
+                break
+
        backend = backend.lower()
        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

        backend = AttentionBackendName(backend)
+        if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but current attention backend '{backend.value}' "
+                f"does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
+            )

        _check_attention_backend_requirements(backend)
        _maybe_download_kernel_for_backend(backend)

-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
        for module in self.modules():
            if not isinstance(module, attention_classes):
                continue
@@ -626,6 +644,9 @@ def set_attention_backend(self, backend: str) -> None:
                continue
            processor._attention_backend = backend

+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
    def reset_attention_backend(self) -> None:
        """
        Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
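To illustrate the validation added above, a hedged sketch of calling `set_attention_backend` once a parallel config is attached. The model id and backend names are assumptions, not part of the diff.

# Hedged sketch: if any attention processor already carries a _parallel_config
# (e.g. after enable_parallelism()), set_attention_backend() now raises for
# backends outside _AttentionBackendRegistry._supports_context_parallel.
from diffusers import AutoModel  # assumption: any ModelMixin subclass works

# "some/transformer-checkpoint" is a placeholder id.
model = AutoModel.from_pretrained("some/transformer-checkpoint", subfolder="transformer")

model.set_attention_backend("native")  # validated, then applied module-by-module
# With context parallelism already enabled and a backend that does not support
# it, the same call raises ValueError listing the compatible backends instead
# of silently configuring an unusable setup.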
@@ -1538,7 +1559,7 @@ def enable_parallelism(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
f"calling `model.enable_parallelism()`."
)

        # All modules use the same attention processor and backend. We don't need to
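The corrected message points at the intended call order. A short hedged sketch of that order; the model id, backend name, and config values are assumptions.

# Hedged sketch: pick a context-parallel-capable backend first, then enable
# parallelism, matching the updated error message.
from diffusers import AutoModel, ContextParallelConfig

model = AutoModel.from_pretrained("some/transformer-checkpoint", subfolder="transformer")
model.set_attention_backend("flash")  # assumed to be in _supports_context_parallel
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))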