Conversation
sayakpaul
left a comment
Left some nits, otherwise looks very nice!
```python
raise NotImplementedError(
    f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
```

```python
if self.ulysses_degree * self.ring_degree > world_size:
```
Should we even hit this line, since both cannot be set, right?
Both can be set technically, but currently both can't be > 1. Also, this is for cases where you have 3 GPUs available and you set something like ulysses_degree=1 and ring_degree=4 (the number of GPUs requested is greater than world_size).
Feels slightly confusing to me but since we're erroring out early for unsupported ulysses_degree and ring_degree value combos, I think it's okay.
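To make the two failure modes discussed here concrete, here is a standalone sketch of the early validation (the function name and messages are hypothetical, not the actual diffusers implementation):

```python
def validate_cp_degrees(ulysses_degree: int, ring_degree: int, world_size: int) -> None:
    # Hypothetical sketch: both degrees may be set, but currently not both > 1.
    if ulysses_degree > 1 and ring_degree > 1:
        raise NotImplementedError(
            "Combining ulysses_degree > 1 and ring_degree > 1 is not supported yet."
        )
    # The case from the thread: e.g. world_size=3 with ulysses_degree=1 and
    # ring_degree=4 requests more GPUs than are available.
    if ulysses_degree * ring_degree > world_size:
        raise ValueError(
            f"Requested {ulysses_degree * ring_degree} devices "
            f"(ulysses_degree={ulysses_degree} * ring_degree={ring_degree}), "
            f"but world_size is {world_size}."
        )
```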
```python
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
```
Can't they be None? Why remove the guards?
They are internal attributes derived from the mesh, which is set through the setup method. The device mesh object is also only created dynamically when enable_parallelism is called.
The guards are redundant; they would always be None unless set explicitly for some custom debugging.
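As a torch-free sketch of how the derived ranks relate (assuming the default row-major mesh layout of shape `(ring_degree, ulysses_degree)`; the function name is hypothetical):

```python
def submesh_local_ranks(flat_rank: int, ring_degree: int, ulysses_degree: int):
    # For a device mesh of shape (ring_degree, ulysses_degree) laid out
    # row-major, the ring coordinate is the row and the ulysses coordinate
    # is the column of the flattened rank.
    ring_local_rank = flat_rank // ulysses_degree
    ulysses_local_rank = flat_rank % ulysses_degree
    return ring_local_rank, ulysses_local_rank
```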
```python
    mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
    mesh_dim_names=("ring", "ulysses"),
)
```

```python
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
```
(nit): Would prefer to keep the import at the beginning of the method implementation, if possible (after `from .attention_processor import Attention, MochiAttention`).
sayakpaul
left a comment
Just some nits. But not merge-blocker.
```python
@property
def mesh_shape(self) -> Tuple[int, int]:
    """Shape of the device mesh (ring_degree, ulysses_degree)."""
    return (self.ring_degree, self.ulysses_degree)
```
Would it be possible to add a small explainer about what it would mean for different values, for example - "(3, 1), (1, 3)", etc.? When both are being set, both cannot be > 1.
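One possible explainer along those lines, as a minimal self-contained sketch (not the actual class; the docstring examples are illustrative):

```python
from typing import Tuple


class ContextParallelConfig:
    # Minimal sketch for illustration only.
    def __init__(self, ring_degree: int = 1, ulysses_degree: int = 1):
        self.ring_degree = ring_degree
        self.ulysses_degree = ulysses_degree

    @property
    def mesh_shape(self) -> Tuple[int, int]:
        """Shape of the device mesh as (ring_degree, ulysses_degree).

        For example, (3, 1) means ring attention across 3 devices with no
        Ulysses parallelism, and (1, 3) means Ulysses (all-to-all) attention
        across 3 devices with no ring parallelism. When both are set, at most
        one of the two may currently be > 1.
        """
        return (self.ring_degree, self.ulysses_degree)
```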
```diff
-supports_context_parallel = backend in cls._supports_context_parallel
-is_degree_greater_than_1 = parallel_config is not None and (
-    parallel_config.context_parallel_config.ring_degree > 1
-    or parallel_config.context_parallel_config.ulysses_degree > 1
-)
-return supports_context_parallel and is_degree_greater_than_1
+supports_context_parallel = backend.value in cls._supports_context_parallel
+return supports_context_parallel
```
```python
cp_mesh = None
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

# Step 1: Validate attention backend supports context parallelism if enabled
```
(nit): Seems like we are not documenting other steps. So, maybe we can remove this comment or add other steps as comments.
@sayakpaul It worked for me. When I set the attn backend to native, it threw an error correctly, but I think the `INFO` line is still printed first:

```
INFO 10-31 07:05:23 [__init__.py:64] Found attention_backend from config, set attention backend to: native
[rank1]: transformer.enable_parallelism(
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/modeling_utils.py", line 1529, in enable_parallelism
[rank1]:     raise ValueError(
[rank1]: ValueError: Context parallelism is enabled but the attention processor 'WanAttnProcessor' is using backend 'native' which does not support context parallelism. Please set a compatible attention backend: ['_native_cudnn', 'flash', 'sage'] using `model.set_attention_backend()` before calling `enable_parallelism()`.
```
What does this PR do?
Currently CP inference will run with split hooks even if the attention backend doesn't support it. This can lead to weird results (#12443).
This PR errors out early in `enable_parallelism` when the attention backend does not support context parallelism.
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.