16 changes: 9 additions & 7 deletions src/diffusers/hooks/taylorseer_cache.py
@@ -69,13 +69,15 @@ class TaylorSeerCacheConfig:
     - Patterns are matched using `re.fullmatch` on the module name.
     - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked.
     - If neither is provided, all attention-like modules are hooked by default.
-    - Example of inactive and active usage:
-    ```
-    def forward(x):
-        x = self.module1(x)  # inactive module: returns zeros tensor based on shape recorded during full compute
-        x = self.module2(x)  # active module: caches output here, avoiding recomputation of prior steps
-        return x
-    ```
+
+    Example of inactive and active usage:
+
+    ```py
+    def forward(x):
+        x = self.module1(x)  # inactive module: returns zeros tensor based on shape recorded during full compute
+        x = self.module2(x)  # active module: caches output here, avoiding recomputation of prior steps
+        return x
+    ```
     """
 
     cache_interval: int = 5
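For context on how this config is consumed: by analogy with the other cache hooks in `diffusers.hooks` (e.g. `apply_faster_cache(module, config)`), the new `apply_taylorseer_cache` presumably takes a module and a `TaylorSeerCacheConfig`. A minimal usage sketch, assuming that signature; the checkpoint, the list-of-strings form of `cache_identifiers`, and the regex pattern are illustrative placeholders, not tested values:

```py
import torch
from diffusers import DiffusionPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

# Hypothetical checkpoint; any transformer-based pipeline would do.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,  # dataclass default shown in the diff: full compute every 5 steps
    # Optional regex patterns, matched with `re.fullmatch` against module names;
    # if omitted, all attention-like modules are hooked by default. This pattern
    # is a made-up example, not a verified identifier.
    cache_identifiers=[r"transformer_blocks\.\d+\.attn"],
)

# Assumed signature, mirroring apply_faster_cache(module, config).
apply_taylorseer_cache(pipe.transformer, config)
```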
19 changes: 19 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class TaylorSeerCacheConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 def apply_faster_cache(*args, **kwargs):
     requires_backends(apply_faster_cache, ["torch"])
 
@@ -273,6 +288,10 @@ def apply_pyramid_attention_broadcast(*args, **kwargs):
     requires_backends(apply_pyramid_attention_broadcast, ["torch"])
 
 
+def apply_taylorseer_cache(*args, **kwargs):
+    requires_backends(apply_taylorseer_cache, ["torch"])
+
+
 class AllegroTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]
 
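These dummy definitions exist so that `import diffusers` still succeeds when `torch` is not installed: `DummyObject` and `requires_backends` (both from `diffusers.utils`) turn any use of a torch-only name into an informative `ImportError` only at the point of use. A self-contained, simplified sketch of the pattern; the real helper checks backend availability and raises only when something is missing, whereas this version always raises:

```py
def requires_backends(obj, backends):
    # Simplified: the real diffusers helper verifies each backend (e.g. torch)
    # is importable and raises with install instructions only if one is absent.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass that defers the backend error until the class is actually used."""

    def __getattr__(cls, name):
        requires_backends(cls, cls._backends)


class TaylorSeerCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Importing the name is harmless; instantiating it raises a clear error
# instead of a confusing AttributeError deep inside the library.
try:
    TaylorSeerCacheConfig(cache_interval=5)
except ImportError as err:
    print(err)  # TaylorSeerCacheConfig requires the following backends: torch
```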