[Unity][DistIR] Legalize redistribute #16098
Conversation
Instead of adding a new pass, could we add an implementation of FLegalize?
I wouldn't call the behavior of this pass a normal legalize, because it doesn't translate a relax op to call_tir. Depending on the resharding pattern, R.dist.redistribute may be translated to a single relax op or to a list of relax ops. And on the user side, one normally doesn't want to legalize redistribute together with other ops: LegalizeRedistribute is expected to be called as an epilogue every time a redistribute is created in some pass. So I disagree with changing to FLegalize.
As of #15842, this is no longer a requirement of FLegalize.
I agree that the handling of DistIR is different enough that we would like to decouple it, at least for now.
If we're planning on an integration at some point, that makes sense to me. While orthogonal and composable features are easier to use and maintain in the long term, having fully decoupled features is useful for path-clearing and exploration.
@Lunderberg Since this PR doesn't have complex logic, can you do a quick review?
Lunderberg left a comment
Looks good in general. I have a couple of questions regarding whether redistribute_replica_to_shard is overly constrained, and whether we want to just use R.strided_slice instead.
```python
return _ffi_api.redistribute(input, device_mesh, placement)  # type: ignore

# ...

def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr:
```
I think we should change the type of num_workers from int to Expr. That allows the number of workers to be a symbolic variable. It doesn't require any runtime support, as the symbolic variable would be specialized later on, but this is very useful when writing generic implementations.
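A minimal sketch of what the suggested signature could look like (the Union handling, PrimValue wrapping, and FFI name here are illustrative, not the actual patch):

```python
from typing import Union

from tvm import tir
from tvm.relax import Expr, PrimValue

from . import _ffi_api  # module-local FFI bindings, as in the original file


def redistribute_replica_to_shard(
    input: Expr, num_workers: Union[int, tir.PrimExpr], axis: int
) -> Expr:
    # Wrap both literal ints and symbolic PrimExprs as a relax PrimValue,
    # so the worker count travels in call->args rather than call->attrs
    # and can remain a symbolic variable until specialization.
    num_workers = PrimValue(num_workers)
    return _ffi_api.redistribute_replica_to_shard(input, num_workers, axis)  # type: ignore
```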
We do need runtime support. Currently the Disco runtime treats num_workers as a constant, and the legalization of ccl ops throws away num_workers when converting to call_dps_packed(ccl, ...). This behavior needs to be changed if we want to make num_workers an Expr.
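To make the constraint concrete, here is a rough sketch (the function name and packed-func key are assumptions, not the actual TVM source) of a ccl legalization that consumes num_workers without forwarding it, which is why a symbolic value cannot survive this step today:

```python
from tvm import relax


def legalize_allgather(_bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    data = call.args[0]
    # num_workers is consumed here (e.g. to compute the output shape) and is
    # NOT passed to the packed call; the Disco runtime supplies the worker
    # count implicitly as a session-wide constant.
    return relax.op.call_dps_packed(
        relax.ExternFunc("runtime.disco.allgather"),
        [data],
        out_sinfo=call.struct_info,
    )
```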
I would recommend we do a refactor of the whole stack to support dynamic num_workers if a real need arises in the future.
That is correct that we would require runtime support, but only if num_workers is still dynamic after being lowered to either the disco runtime or the ccl op legalization. It is easier to write a single dynamic implementation and then specialize it to a variety of static cases than it is to write several distinct static implementations. However, initially writing the dynamic implementation requires that it be expressible, even though it will be specialized away later in lowering.
I don't understand the case you are talking about. Why do we want to use a symbolic var for a value that is constant after lowering? And why would we have a variety of static cases to specialize, given that the disco runtime regards num_workers as a global constant?
One example is if you are comparing how performance scales with the number of GPUs: then you need a way to specify the number of GPUs, and each data point collected would be for a specialized value of num_workers.
Effectively, having num_workers: Expr de-couples the communication mechanics from the choice of how many workers to use. It can still be explicitly specified at any stage of lowering, and it still must be statically known after lowering, but we gain flexibility before then. This has a number of benefits, for example (see the sketch after this list):
- Predicted memory usage. If the number of workers is stored symbolically, adding up the size of all live values at any point gives the memory footprint as a function of the number of workers. Requiring the number of workers to be static at all points of lowering prevents this analysis.
- Consistent optimization. If an optimization is applicable regardless of the number of workers, the optimization should be applied at a point when the number of workers is unknown. This prevents a developer from accidentally making a less general optimization (e.g. by using a sharded tensor shape in the pattern-matching).
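As a small illustration of the flexibility (a sketch with made-up shapes, not code from this PR): with a symbolic worker count, a per-worker shape can be written once and analyzed or specialized later.

```python
from tvm import relax, tir

# Symbolic worker count; bound to a concrete value late in lowering.
num_workers = tir.Var("num_workers", "int64")

# Per-worker shard of a [4096, 4096] tensor, expressed as a function of
# num_workers. Analyses such as memory-footprint prediction can sum up
# shapes like this one symbolically, before any specialization happens.
x = relax.Var(
    "x",
    relax.TensorStructInfo([tir.floordiv(4096, num_workers), 4096], "float16"),
)
```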
Got it, this makes sense. Since redistribute_R_to_S shares its attrs with scatter_from_worker0, I'd like to open a follow-up PR for this.
@Lunderberg I just came up with an additional question on the attrs->args change: do we need to recompile for each different num_workers after this change, given the specialization? If yes, then what's the difference between the define-symbolic-then-specialize flow and the assume-constant flow?
The main difference occurs when there are additional optimization steps that occur after a module definition but before the module is handed off to relax.build for lowering/compilation.
If sharding (and propagation of sharding) is done early in optimization, then propagation of sharding produces large portions of the compute graph where no communication steps are required. These portions can be optimized as if they were single-GPU modules, with no specific handling of multi-GPU setups required.
If we can write modules with a dynamic number of gpus, we can write the lowering steps as a single optimization pipeline:

```python
# With specialization occurring late in the pipeline.
mod = Sequential([
    pre_sharding_optimizations,
    shard_across_multiple_gpus,
    propagate_sharding,
    convert_to_local_view,
    single_gpu_optimizations,
])(mod)
built_modules = [relax.build(specialize(mod, num_gpus)) for num_gpus in num_gpu_list]
```

If we can only write modules with a static number of gpus, we cannot write a single optimization pipeline, as the optimization pipeline must instead be applied separately to each specialized module:
```python
# With specialization occurring at the start of the pipeline.
mod = pre_sharding_optimizations(mod)
mods = [shard_across_multiple_gpus(mod, num_gpus) for num_gpus in num_gpu_list]
pipeline = Sequential([
    propagate_sharding,
    convert_to_local_view,
    single_gpu_optimizations,
])
mods = [pipeline(mod) for mod in mods]
built_modules = [relax.build(mod) for mod in mods]
```

It's not that this is impossible by any means, but the restricted expressibility in the early step means that a user must leave the world of a single IRModule much earlier.
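For reference, specialize above is not defined in the thread; one plausible sketch, assuming relax's BindSymbolicVars pass and a module whose shapes refer to a symbolic variable named "num_gpus", might look like:

```python
from tvm import relax


def specialize(mod, num_gpus: int):
    # Bind the symbolic shape variable "num_gpus" to a concrete value in
    # every function of the module; downstream lowering then sees a fully
    # static worker count, as the Disco runtime requires.
    return relax.transform.BindSymbolicVars({"num_gpus": num_gpus})(mod)
```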
@Lunderberg I have addressed your comments.
Lunderberg left a comment
Thank you for the changes. One follow-up, and one additional change requested.
tests/python/relax/distributed/test_distributed_transform_legalize_redistribute.py
Lunderberg left a comment
Thank you for making the changes and for the discussion; I think the changes are good to go in. I've run into significant issues in the past from overly-restrictive intermediates (e.g. being unable to use tril or triu when a parameter depended on seq_len), where the only limitation was the use of call->attrs instead of call->args, and so I push for the more flexible representation.
For the follow-up PR you mentioned, it may benefit from use of FNormalize, in order to provide a migration path for any external usage that comes up in the meantime.
This PR adds a pass LegalizeRedistribute that transforms R.dist.redistribute into ccl ops (Sharding->Replica is translated to allgather) or a slicing op (Replica->Sharding).
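A rough before/after sketch of the transformation described above (pseudocode; the placements and the exact lowered op spellings are illustrative, based only on the description in this PR):

```python
# Sharding -> Replica: redistribute becomes an allgather over the workers.
#   before: y = R.dist.redistribute(x, device_mesh, placement)  # "S[0]" -> "R"
#   after:  y = R.ccl.allgather(x, num_workers)

# Replica -> Sharding: redistribute becomes a slice selecting this
# worker's shard along the sharded axis.
#   before: y = R.dist.redistribute(x, device_mesh, placement)  # "R" -> "S[0]"
#   after:  y = R.strided_slice(x, axes=[0],
#                               begin=[worker_id * shard_size],
#                               end=[(worker_id + 1) * shard_size])
```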