
Conversation

@jinhongyii (Contributor)

This PR adds a pass, LegalizeRedistribute, that transforms R.dist.redistribute into CCL ops (Sharding->Replica is translated to an allgather) or a slicing op (Replica->Sharding).
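As a rough illustration of that dispatch (a toy sketch with a hypothetical helper name, not the actual pass implementation), the placement-pair-to-op mapping might look like:

```python
def choose_lowering(src_placement: str, dst_placement: str) -> str:
    """Toy dispatch for a redistribute, per the PR description.

    Placements use "R" for replica and "S<axis>" for sharding along
    an axis; the real pass works on Relax IR, not strings.
    """
    if src_placement.startswith("S") and dst_placement == "R":
        # Sharding -> Replica: every worker needs the full tensor,
        # so gather all shards onto all workers.
        return "ccl.allgather"
    if src_placement == "R" and dst_placement.startswith("S"):
        # Replica -> Sharding: each worker keeps only its own slice.
        return "strided_slice"
    raise NotImplementedError(f"{src_placement} -> {dst_placement}")
```

For example, `choose_lowering("S0", "R")` picks the allgather path, while `choose_lowering("R", "S1")` picks the slicing path.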

@jinhongyii (Contributor, Author)

cc: @Lunderberg @csullivan @tqchen

@Lunderberg (Contributor)

Instead of adding a new pass, could we add an implementation of FLegalize, then allow relax.transform.LegalizeOps to accept a list of operators which should be legalized? That would avoid needing to make these single-purpose transforms in the future.

@jinhongyii (Contributor, Author) commented Nov 9, 2023

I wouldn't call the behavior of this pass a normal legalize, because it doesn't translate a relax op to call_tir. Depending on the resharding pattern, R.dist.redistribute may be translated into a single relax op or a list of relax ops. And on the user side, one normally doesn't want to legalize redistribute together with other ops: LegalizeRedistribute is expected to be called as an epilogue every time a redistribute is created in some pass. So I disagree with changing to FLegalize.

@Lunderberg (Contributor) commented Nov 9, 2023

because this pass doesn't translate relax op to call_tir.

As of #15842, this is no longer a requirement of FLegalize. My goal with that PR was specifically to enable this type of usage, where Relax operator lowering is best expressed in terms of other Relax operators.

@tqchen (Member) commented Nov 9, 2023

I agree that the handling of DistIR is different enough that we would like to decouple -- at least for now.

@Lunderberg (Contributor) commented Nov 10, 2023

If we're planning on an integration at some point, that makes sense to me. While orthogonal and composable features are easier to use and maintain in the long term, having fully de-coupled features is useful for path-clearing and exploration.

@jinhongyii (Contributor, Author)

@Lunderberg Since this PR doesn't have complex logic, can you do a quick review?

@Lunderberg (Contributor) left a comment

Looks good in general. I have a couple of questions regarding whether redistribute_replica_to_shard is overly constrained, and whether we want to just use R.strided_slice instead.

return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore


def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr:
@Lunderberg (Contributor)

I think we should change the type of num_workers from int to Expr. That allows the number of workers to be a symbolic variable. It doesn't require any runtime support, as the symbolic variable would be specialized later on, but this is very useful when writing generic implementations.

@jinhongyii (Contributor, Author)

We do need runtime support. Currently the Disco runtime treats num_workers as a constant, and legalize_ops of the ccl ops will throw away num_workers when converting to call_dps_packed(ccl, ...). This behavior needs to change if we want to make num_workers an Expr.

@jinhongyii (Contributor, Author)

I would recommend refactoring the whole stack to support dynamic num_workers if there is really a need in the future.

@Lunderberg (Contributor)

That is correct that we would require runtime support, but only if num_workers is still dynamic after being lowered to either the Disco runtime or the ccl op legalization. It is easier to write a single dynamic implementation and later specialize it to a variety of static cases than to write several distinct static implementations. However, writing the dynamic implementation in the first place requires that it be expressible, even though it will be specialized away later in lowering.
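The write-once-dynamic, specialize-later idea can be sketched in plain Python, with no TVM machinery (shard_shape is a made-up helper for illustration, not part of the codebase):

```python
from functools import partial

def shard_shape(full_shape, num_workers, axis):
    """One generic (dynamic) implementation: the per-worker shard
    shape, parameterized over the number of workers."""
    assert full_shape[axis] % num_workers == 0, "uneven shard"
    out = list(full_shape)
    out[axis] //= num_workers
    return tuple(out)

# Specialize the single dynamic implementation to several static
# cases, instead of writing one implementation per worker count.
shard_2 = partial(shard_shape, num_workers=2)
shard_4 = partial(shard_shape, num_workers=4)

print(shard_2((8, 16), axis=0))  # (4, 16)
print(shard_4((8, 16), axis=0))  # (2, 16)
```

The analogy to the discussion: the dynamic form must be expressible (here, num_workers is an ordinary parameter) even though every call site ultimately runs with a static value.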

@jinhongyii (Contributor, Author)

I don't understand the case you're describing. Why would we want a symbolic var for a value that is constant after lowering? And why would we have a variety of static cases to specialize, given that the Disco runtime regards num_workers as a global constant?

@Lunderberg (Contributor)

One example is comparing how performance scales with the number of GPUs: you need a way to specify the number of GPUs, and each data point collected would be for a specialized value of num_workers.

Effectively, having num_workers: Expr de-couples the communication mechanics from the choice of how many workers to use. It still can be explicitly specified at any stage of lowering, and it still must be statically-known after lowering, but we gain flexibility before then. This has a number of benefits, for example:

  • Predicted memory usage. If the number of workers is stored symbolically, adding up the size of all live values at any point gives the memory footprint as a function of the number of workers. Requiring the number of workers to be static at all points of lowering prevents this analysis.

  • Consistent optimization. If an optimization is applicable regardless of the number of workers, it should be applied at a point when the number of workers is unknown. This prevents a developer from accidentally making a less general optimization (e.g. by using a sharded tensor shape in the pattern-matching).
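The first bullet can be sketched without any TVM machinery: if shard sizes are kept as functions of the worker count rather than baked-in constants, the total footprint stays analyzable as the worker count varies (a toy example with made-up tensor sizes):

```python
def footprint_bytes(num_workers: int) -> int:
    """Sum of live values at some hypothetical program point, with
    sizes left as a function of the number of workers rather than
    pre-specialized constants."""
    hidden, seq, dtype_bytes = 4096, 2048, 2
    weight_shard = hidden * hidden * dtype_bytes // num_workers  # sharded
    activation = seq * hidden * dtype_bytes                      # replicated
    return weight_shard + activation

# The same expression answers "how does memory scale with workers?"
# without re-deriving the analysis per static configuration.
for n in (1, 2, 4):
    print(n, footprint_bytes(n))
```

If num_workers were specialized before this analysis ran, each worker count would require a separate module and a separate pass over it.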

@jinhongyii (Contributor, Author)

Got it, this makes sense. Since redistribute_R_to_S shares the attribute with scatter_from_worker0, I'd like to open a follow-up PR for this.

@jinhongyii (Contributor, Author) commented Nov 14, 2023

@Lunderberg I just came up with an additional question on the attrs->args change: do we need to recompile for each different num_workers after this change, given specialization? If yes, then what's the difference between the define-symbolic-then-specialize flow and the assume-constant flow?

@Lunderberg (Contributor)

The main difference occurs when there are additional optimization steps that occur after a module definition but before the module is handed off to relax.build for lowering/compilation.

If sharding (and propagation of sharding) is done early in optimization, then propagation of sharding produces large portions of the compute graph where no communication steps are required. These portions can be optimized as if they were single-GPU modules, with no specific handling of multi-GPU setups required.

If we can write modules with a dynamic number of GPUs, we can write the lowering steps as a single optimization pipeline:

# With specialization occurring late in the pipeline.
mod = Sequential([
    pre_sharding_optimizations,
    shard_across_multiple_gpus,
    propagate_sharding,
    convert_to_local_view,
    single_gpu_optimizations,
])(mod)

built_modules = [relax.build(specialize(mod, num_gpus)) for num_gpus in num_gpu_list]

If we can only write modules with a static number of GPUs, we cannot write a single optimization pipeline; everything after the sharding step must instead be applied separately to each specialized module:

# With specialization occurring at the start of the pipeline.
mod = pre_sharding_optimizations(mod)

mods = [shard_across_multiple_gpus(mod, num_gpus) for num_gpus in num_gpu_list]
pipeline = Sequential([
    propagate_sharding,
    convert_to_local_view,
    single_gpu_optimizations,
])
mods = [pipeline(mod) for mod in mods]
built_modules = [relax.build(mod) for mod in mods]

It's not that it's impossible by any means, but the restricted expressibility in an early step means that a user must leave the world of a single IRModule much earlier.

@jinhongyii (Contributor, Author)

@Lunderberg I have addressed your comments

@Lunderberg (Contributor) left a comment

Thank you for the changes, one follow-up, and one additional change requested.

@Lunderberg (Contributor) left a comment

Thank you for making the changes and for the discussion; I think the changes are good to go in. I've run into significant issues in the past from overly restrictive intermediates (e.g. being unable to use tril or triu when a parameter depended on seq_len), where the only limitation was the use of call->attrs instead of call->args, so I push for the more flexible representation.

For the follow-up PR you mentioned, it may benefit from use of FNormalize, in order to provide a migration path for any external usage that comes up in the meantime.

@jinhongyii jinhongyii merged commit 6f650db into apache:unity Nov 14, 2023