
Remove is_root_in_mesh from CommParams and add mesh. #2250

Merged
wujingyue merged 7 commits into main from wjy/comm on Jun 3, 2024

Conversation

@wujingyue
Collaborator

@wujingyue wujingyue commented May 15, 2024

As a follow-up to #2172. `mesh` will eventually go away after the input and output TensorViews of a communication are embedded in the base Expr. Anyhow, `is_root_in_mesh` can be computed from other parameters, so it is removed.

@wujingyue wujingyue marked this pull request as draft May 15, 2024 17:50
@wujingyue wujingyue force-pushed the wjy/comm branch 2 times, most recently from 8b782a3 to 937d050, on May 15, 2024 19:38
@wujingyue wujingyue changed the title Remove some fields from CommsParams. Remove some fields from CommParams. May 15, 2024
@wujingyue wujingyue changed the title Remove some fields from CommParams. Remove is_root_in_mesh and team from CommParams. May 15, 2024
@wujingyue wujingyue marked this pull request as ready for review May 15, 2024 19:49
@wujingyue wujingyue requested review from cowanmeg and samnordmann May 15, 2024 19:49
@wujingyue
Collaborator Author

!build --dist

Comment on lines -115 to -119
std::unordered_set<DeviceIdxType> team_without_duplicates(
    params.team.begin(), params.team.end());
NVF_ERROR(
    team_without_duplicates.size() == params.team.size(),
    "the communication must not involve the same device more than once");
Collaborator Author

I could move this check to Communication::team() but it looks unnecessary because

  1. DeviceMesh already checks against duplicates, and
  2. Communication::team() adds the root only when it's not in the mesh.
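A minimal sketch of the construction described above (the `DeviceIdxType` alias, the `computeTeam` name, and the plain-vector mesh are illustrative assumptions, not nvFuser's actual API):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using DeviceIdxType = int64_t;

// Hypothetical sketch: the mesh is assumed to be duplicate-free already, and
// the root is appended only when it is missing, so the resulting team needs
// no extra duplicate check.
std::vector<DeviceIdxType> computeTeam(
    const std::vector<DeviceIdxType>& mesh,
    DeviceIdxType root) {
  std::vector<DeviceIdxType> team = mesh;  // duplicate-free by assumption
  if (std::find(team.begin(), team.end(), root) == team.end()) {
    team.push_back(root);  // add the root only when it's not in the mesh
  }
  return team;
}
```

Under these assumptions, a root already in the mesh leaves the team unchanged, and a root outside the mesh is appended exactly once.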

Collaborator

This sounds ok to me. I think only the CommunicationTests directly set the team, but the normal use case creates a team from the DeviceMesh, which checks against duplicates.

Collaborator Author

This PR changed the CommunicationTests as well to set mesh instead of team. The bottom line is that CommParams contains all the parameters (e.g. mesh) needed for constructing a Communication, and the other fields (e.g. team) are all computed.

NVF_ERROR(
    team_without_duplicates.size() == params.team.size(),
    "the communication must not involve the same device more than once");
NVF_ERROR(!params.team.empty(), "the team size must be greater than 0");
Collaborator Author

This is moved to Communication's constructor.

NVF_ERROR(!params.team.empty(), "the team size must be greater than 0");
if (hasRoot(params.type)) {
  auto it = std::find(params.team.begin(), params.team.end(), params.root);
  NVF_ERROR(
Collaborator Author

This check could be moved to Communication::team(), but it would look a bit silly:

if (root isn't in team) {
  team.push_back(root);
}
assert root is in team  // obviously true


namespace {

inline void assertBufferCount(
Collaborator Author

Unless proven to be important, I tend to leave the inlining decision to the compiler. Otherwise, the function could organically grow into a giant without people realizing the overhead of inlining.


@wujingyue wujingyue requested a review from cowanmeg May 16, 2024 05:52
@samnordmann
Collaborator

samnordmann commented May 16, 2024

Thanks @wujingyue! I am glad to see this kind of cleanup, because it significantly simplifies the code! However, I am not sure about the design decision and the choice of parameters to encode the communication.

I am talking more precisely about keeping the mesh as a parameter and deducing the Team from it. I think we should do the opposite: provide the Team as a parameter and deduce the mesh. The most natural way to describe a collective is to specify the Team; the mesh is only relevant to nvFuser (or, in general, to tensor sharding, but not to MPI communications).

Actually, the mesh shouldn't be needed at all to describe the collective. I understand that for now it is used for implementing the trick of adding an empty tensor when the root is not in the sender/receiver mesh, but that is

  1. not a very important corner case, and
  2. not a very convincing implementation of this corner case, which I wrote a long time ago and would be happy to revisit. In particular, it makes no sense to me that posting a collective implies allocating a new tensor -- this should be handled by the caller.

Anyway, my point is that this special case should not influence our design decisions.

Given the current stage of development, storing the mesh and deducing the Team could still make sense -- however, when considering 2D and ND shardings (which we are adding right now), it becomes imo even more obvious that storing the Team is more natural.
Let us now consider a simple case of 1D resharding over a 2D mesh:

mesh = {{0,1,2,3}, {4,5,6,7}}
tv0 [DiDx{i1}, DiDy{i2}]
tv1 = copy(tv0) [DiDx{i1}, i2]

To produce tv1, we need two allgathers: allgather 1 among Team (0,1,2,3) and allgather 2 among Team (4,5,6,7). We see that, with no extra info, there is no way to deduce the team from the mesh.
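To make the deduction problem concrete, here is a hedged sketch (the `Mesh2D` representation, the `teamFor` name, and the row-as-team convention are assumptions for illustration, not nvFuser code) of looking up a device's team as its row of the 2-D mesh:

```cpp
#include <cstdint>
#include <vector>

using DeviceIdxType = int64_t;
using Mesh2D = std::vector<std::vector<DeviceIdxType>>;

// Hypothetical lookup: each device's allgather team is the row of the 2-D
// mesh that contains it; a device outside the mesh gets an empty team.
std::vector<DeviceIdxType> teamFor(const Mesh2D& mesh, DeviceIdxType my_device) {
  for (const auto& row : mesh) {
    for (DeviceIdxType d : row) {
      if (d == my_device) {
        return row;  // this device's row is its team
      }
    }
  }
  return {};  // not in the mesh: the device doesn't participate
}
```

Note that this works only because we assumed which mesh axis defines the team; the flat mesh alone carries no such information, which is exactly the point above.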

Another thing that I don't like in the current code is lines like https://github.com/NVIDIA/Fuser/blob/e7fdb714bec23ef51d8f72b81c42046fb204273b/csrc/multidevice/lower_communication.cpp#L184C11-L184C22
where we create a new DeviceMesh, whereas, in general, we should only be creating new Teams by flattening/slicing the meshes and maybe adding the root. DeviceMeshes are supposed to describe tensor sharding, so a mesh is multi-D and should always be attached to some tensors -- but currently we also use it to specify the list of ranks participating in a collective (i.e., the Team), which is not great.

The design I have been thinking about that I would suggest is to have:

  • in CommParams: Team, root, RedOpType. This is enough to fully describe the collective, and we don't need the mesh.
  • (optionally) as Communication's input/output: the input and output TensorView*s, from which we can access both the sender and receiver meshes, but also any other information we can imagine being potentially useful (the axis index, allocation domain, layouts, aliases, etc.). Also, it encapsulates the DeviceMesh in the TV pointers, which is better than just creating and passing meshes around.

In my opinion, this design is cleaner and more sustainable for the future.

@cowanmeg @wujingyue @naoyam please let me know what you think, I'd be happy to have your opinion and to discuss before deciding what to do

@wujingyue
Collaborator Author

The design I have been thinking about that I would suggest is to have:

  • in CommParams: Team, root, RedOpType. This is enough to fully describe the collective, and we don't need the mesh.
  • (optionally) as Communication's input/output: the input and output TensorView*s, from which we can access both the sender and receiver meshes, but also any other information we can imagine being potentially useful (the axis index, allocation domain, layouts, aliases, etc.). Also, it encapsulates the DeviceMesh in the TV pointers, which is better than just creating and passing meshes around.

I'm almost on the same page as you.

I think the second item will be needed to conveniently traverse through Exprs in host IR, e.g., by reusing the existing graph traversal utilities in the code base. That's why I put a TODO here https://github.com/NVIDIA/Fuser/pull/2256/files#diff-8a31ccb86eb9f8d9b346f35addcbd61ba465a418dd23124e61634f1fd5ecc9e7R51-R52.

So, in the end state, Communication will contain neither mesh nor team as data attributes. mesh will be collected from in and out, and team will be computed from type (an attribute), the meshes of in and out, and their IterDomains.

If we agree on that, I'll send out a PR by Monday to add input and output to the op and collect mesh from there instead of passing it in as a parameter. AFAICS, that will be the easiest way for me to reach the end state, given the chain of three other PRs on top of this PR. Wdyt?

@samnordmann
Collaborator

The design I have been thinking about that I would suggest is to have:

  • in CommParams: Team, root, RedOpType. This is enough to fully describe the collective, and we don't need the mesh.
  • (optionally) as Communication's input/output: the input and output TensorView*s, from which we can access both the sender and receiver meshes, but also any other information we can imagine being potentially useful (the axis index, allocation domain, layouts, aliases, etc.). Also, it encapsulates the DeviceMesh in the TV pointers, which is better than just creating and passing meshes around.

I'm almost on the same page as you.

I think the second item will be needed to conveniently traverse through Exprs in host IR, e.g., by reusing the existing graph traversal utilities in the code base. That's why I put a TODO here https://github.com/NVIDIA/Fuser/pull/2256/files#diff-8a31ccb86eb9f8d9b346f35addcbd61ba465a418dd23124e61634f1fd5ecc9e7R51-R52.

So, in the end state, Communication will contain neither mesh nor team as data attributes. mesh will be collected from in and out, and team will be computed from type (an attribute), the meshes of in and out, and their IterDomains.

If we agree on that, I'll send out a PR by Monday to add input and output to the op and collect mesh from there instead of passing it in as a parameter. AFAICS, that will be the easiest way for me to reach the end state, given the chain of three other PRs on top of this PR. Wdyt?

As discussed orally I think we still need to pass the Team -- not the mesh

@wujingyue
Collaborator Author

To produce tv1, we need two allgathers: allgather 1 among Team (0,1,2,3) and allgather 2 among Team (4,5,6,7)

I may have misunderstood what you mean the first time I read this, so let me try again.

For this example,

mesh = {{0,1,2,3}, {4,5,6,7}}
tv0 [DiDx{i1}, DiDy{i2}]
tv1 = copy(tv0) [DiDx{i1}, i2]

Are you proposing to lower/compile the fusion into two Communications, one of which will however be skipped at run time? The first one is an allgather among {0,1,2,3}, and the second is another allgather among {4,5,6,7}. At run time, when my_device_index is known, a device will skip executing the one of the two Communications that doesn't involve my_device_index.

That would make some sense; otherwise, I fail to see how team (just like my_device_index), a piece of runtime information, can/should be passed into the IR at lowering time.

@cowanmeg
Collaborator

To produce tv1, we need two allgathers: allgather 1 among Team (0,1,2,3) and allgather 2 among Team (4,5,6,7)

I may have misunderstood what you mean the first time I read this, so let me try again.

For this example,

mesh = {{0,1,2,3}, {4,5,6,7}}
tv0 [DiDx{i1}, DiDy{i2}]
tv1 = copy(tv0) [DiDx{i1}, i2]

Are you proposing to lower/compile the fusion into two Communications, one of which will however be skipped at run time? The first one is an allgather among {0,1,2,3}, and the second is another allgather among {4,5,6,7}. At run time, when my_device_index is known, a device will skip executing the one of the two Communications that doesn't involve my_device_index.

That would make some sense; otherwise, I fail to see how team (just like my_device_index), a piece of runtime information, can/should be passed into the IR at lowering time.

There will only be one allgather in each program, but each device will look up its team (in this case, the devices along its column (axis y) of the DeviceMesh). Globally, there will be multiple allgathers running in parallel over disjoint teams.

If a device is in the mesh of the consumer/producer tv, it's guaranteed to take part in the communication, but the communication will only involve devices in its team (which will be an axis of the device mesh).

@wujingyue
Collaborator Author

There will only be one allgather in each program, but each device will look up its team

That was my first understanding as well. But, if we choose to implement it that way, team would be better computed (and maybe cached for speed) at run time and therefore not passed into Communication's constructor1. This enables nvFuser to potentially compile the fusion to host IR once and ship the IR to multiple processes to execute. That being said, I admit that the PR as is doesn't compute team from the device ID either -- that needs to be fixed.
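A hypothetical sketch of that idea (the class, names, and row-as-team convention below are illustrative assumptions, not nvFuser code): the lowered object stores only the mesh, and each process derives and caches its own team from its device index at run time, so the same host IR could be compiled once and shipped to every process.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using DeviceIdxType = int64_t;
using Mesh2D = std::vector<std::vector<DeviceIdxType>>;

// Illustrative stand-in for a lowered communication: no team is baked in at
// lowering time; it is computed lazily from the device index and cached.
class CommunicationSketch {
 public:
  explicit CommunicationSketch(Mesh2D mesh) : mesh_(std::move(mesh)) {}

  const std::vector<DeviceIdxType>& team(DeviceIdxType my_device) {
    if (!cached_) {
      for (const auto& row : mesh_) {
        // Assumption: a device's team is its row of the mesh.
        if (std::find(row.begin(), row.end(), my_device) != row.end()) {
          team_ = row;
        }
      }
      cached_ = true;
    }
    return team_;
  }

 private:
  Mesh2D mesh_;
  std::vector<DeviceIdxType> team_;  // computed at run time, then cached
  bool cached_ = false;
};
```

The cache is per process: each rank resolves its own team once and reuses it for every execution of the same compiled IR.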

Footnotes

  1. This is also aligned with Sam's comment in another PR.

@samnordmann
Collaborator

samnordmann commented May 17, 2024

Sorry for the long comment... I have many things to say :) Let's discuss orally if things are not clear.

About the PR status only: as far as I'm concerned, the other open PRs authored by @wujingyue pose no problem and could be merged independently from this one.

The only problematic point imo in the current PR is that I want to be able to pass the Team instead of the DeviceMesh, or both for now if that makes things easier (but I agree that they overlap too much..!). I will need that for implementing 2D resharding. Maybe I can find a way to work around it, but passing the Team will just be more natural.


There will only be one allgather in each program, but each device will look up its team

Imo, we should represent all the communications in the DAG, even the ones not run by the current device 1. In other words, each device should be able to see the complete DAG of operations that is run on the network. If a device is not involved in a communication, we can either (symbolically) predicate out this communication (or, better, a whole segment) or, at run time, simply skip posting any collective at the executor level.

Are you proposing to lower/compile the fusion into two Communications, one of which will however be skipped at run time? The first one is an allgather among {0,1,2,3}, and the second is another allgather among {4,5,6,7}. At run time, when my_device_index is known, a device will skip executing the one of the two Communications that doesn't involve my_device_index.

that's exactly what I mean! Thanks for rephrasing.

That would make some sense; otherwise, I failed to see how team (just as my_device_index), a piece of runtime information, can/should be passed into the IR at lowering time.

I think I see what you mean, and I am on the same page. In particular, imo my_device_index shouldn't be used to create the Communication.

However, it is probably not an important detail, but imo Team is not really runtime information, at least not more than DeviceMesh, since it is deduced from DeviceMesh. For now, Team and DeviceMesh have concrete data (i.e., are non-symbolic), so this distinction could be a bit abstract, but eventually both could be made symbolic (see this doc).

my_device_index should also eventually be made a Val* -- but, for sure, the concrete data that this Val is bound to is runtime info. Having it as a Val is useful, e.g., to express a symbolic predicate like if my_device_index in DeviceMesh {...}


Let me add one more example to illustrate that there is not always a canonical way to deduce the Team from the sender/receiver meshes.
Imagine we want to lower the following resharding op:

DeviceMesh mesh0 = {0,1};
DeviceMesh mesh1 = {2,3};

tv0 = makeConcreteTensor({128});
tv1 = set(tv0);

tv0->setDeviceMesh(mesh0);
tv1->setDeviceMesh(mesh1);

This situation is a type of pipeline parallelism where tv0 is duplicated on {0,1}, and we want to send the whole data to {2,3}. There are many ways to lower this resharding op to MPI communications:

  1. Broadcast(root=0, team={0,2,3})
  2. Broadcast(root=1, team={1,2,3})
  3. Broadcast(root=0, team={0,2}) + Broadcast(root=1, team={1,3})
  4. Scatter(root=0, team={0,2,3}) + Allgather(team={2,3})
    etc.

This dummy example illustrates that there are lowering decisions when mapping a resharding tv op to MPI collectives. We want our Communication class to be flexible enough to express all possible decompositions. This is why I think that Team should be passed in Communication's constructor.
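As a rough sketch of this flexibility argument (all type and function names below are hypothetical, not nvFuser's), each lowering decision above is simply a different set of (type, root, team) parameters, with no mesh consulted:

```cpp
#include <cstdint>
#include <vector>

using DeviceIdxType = int64_t;

enum class CommType { Broadcast, Scatter, Allgather };

// Hypothetical parameter set: with root and team passed explicitly, any of
// the decompositions above can be expressed directly.
struct CommParamsSketch {
  CommType type;
  DeviceIdxType root;
  std::vector<DeviceIdxType> team;
};

// Option 3 from the list: two broadcasts over disjoint teams.
std::vector<CommParamsSketch> lowerAsPairedBroadcasts() {
  return {
      {CommType::Broadcast, 0, {0, 2}},
      {CommType::Broadcast, 1, {1, 3}},
  };
}
```

Options 1, 2, and 4 would be other vectors of `CommParamsSketch`; the point is only that the Team is an explicit input, not something derived from a mesh.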

Let me now go back to what I wrote before:

  • (optionally) as Communication's input/output: the input and output TensorView*s, from which we can access both the sender and receiver meshes, but also any other information we can imagine being potentially useful (the axis index, allocation domain, layouts, aliases, etc.). Also, it encapsulates the DeviceMesh in the TV pointers, which is better than just creating and passing meshes around.

Now that I think about it, we shouldn't pass the I/O tvs as the communication's I/O at construction. Instead, the communication should optionally hold a pointer to the Expr* from which it was lowered.

  1. Using the info contained in the I/O TVs should be purely optional, because it could be tricky to always map a Communication back to a tensor resharding. Since there is no one-to-one mapping between resharding ops and MPI comms, we might need to create new TVs and Exprs each time we create a communication, and those newly created IRs will form small DAGs disconnected from the main DAG.
  2. It is handy (e.g., in tests) to trigger a collective without necessarily associating a symbolic tensor resharding op with it.
  3. We could simply store the Expr* instead of the inputs and outputs.
  4. Because of the way Fusion is implemented, when a Val is given as an Expr's output, it changes its definition, and so we will lose the data dependency in the compute DAG.
  5. @wujingyue about what you wrote:

I think the second item will be needed to conveniently traverse through Exprs in host IR, e.g., by reusing the existing graph traversal utilities in the code base.

This is not a problem, because a Communication always appears in the HostIrContainer embedded inside a PostOnStream -- the latter provides the I/O dependency info at the host level.

Footnotes

  1. (Thanks @wujingyue for this markdown trick :)) This is the same logic as representing TensorViews in the DAG even if the tensor is not sharded on the current device.

@wujingyue
Collaborator Author

Thanks for the verbosity -- it clarified a lot of your thought process, which is critical given that our working hours rarely overlap :)

I'm summarizing the discussion with you today:

I'll add team back in this PR. I may extend it to teams in future PRs, which would be a vector of disjoint teams. This shortens the host IR for the common cases -- it doesn't have to replicate a collective and guard each copy with predicates. This is also analogous to replica_groups in https://openxla.org/xla/operation_semantics#allgather (but be aware of first-principles vs. analogy thinking).

I'll add input/output TensorViews to Communication in future PRs. While I understand how you can live without them using PostOnStream, I plan to try out, in parallel, using them as top-level host IR instructions. I haven't convinced myself that PostOnStream is needed eventually (there are many ways to represent stream assignments), so it makes sense at the moment to keep the flexibility to use Communication without having to encapsulate it in a PostOnStream.

Because of the way Fusion is implemented, when a Val is given as an Expr's output, it changes its definition, and so we will lose the data dependency in the compute DAG.

We are on the same page already -- I'm writing this down for posterity: there's a difference between fusion IR, which requires SSA, and kernel IR, which doesn't and is, I think, more analogous to host IR. The concern you raised is definitely a challenge for SSA (although there's a solution called the Phi node to address it), but it shouldn't be a problem for host IR.

@wujingyue wujingyue changed the title Remove is_root_in_mesh and team from CommParams. Remove is_root_in_mesh from CommParams. May 31, 2024
@wujingyue wujingyue changed the title Remove is_root_in_mesh from CommParams. Remove is_root_in_mesh from CommParams and add mesh. May 31, 2024
}

if (params.team.size() == 1) {
if (communication->team().size() == 1) {
Collaborator Author

This was an intermediate change that turned out to be unnecessary for this PR. However, #2256 will remove CommParams, so I chose to keep this change, which makes the code closer to the end state.

@wujingyue
Collaborator Author

I added team back, so this is now good for another round of reviews. PTAL and thanks!

@wujingyue
Collaborator Author

!build


@wujingyue
Collaborator Author

!build

Collaborator

@samnordmann samnordmann left a comment

LGTM, Thank you!

@wujingyue wujingyue merged commit 6edded1 into main Jun 3, 2024
@wujingyue wujingyue deleted the wjy/comm branch June 3, 2024 05:08
zasdfgbnm pushed a commit that referenced this pull request Jun 5, 2024
As a follow-up to #2172. `mesh` will eventually go away after the input
and output TensorViews of a communication are embedded in the base Expr.
Anyhow, `is_root_in_mesh` can be computed from other parameters, so it is
removed.
wujingyue added a commit that referenced this pull request Jun 9, 2024
…OnStream` (#2328)

Bootstrapping for simplifying the design. With this patch, the
`Communication` class is used to post a collective without the need to
wrap it inside a `PostOnStream` IR.

There is no functional change; this is only a minor cleanup I would like
to upstream before introducing the "collective wait" IR.

To achieve this, we need to add I/O TVs to the `Communication`'s
constructor and use those I/O directly. This patch aligns with the
discussion from #2250 and should lead to further simplification in
`MultiDeviceExecutor`, which I would like to leave to another PR to
avoid conflicts.

Note that the `Communication` class is used in two places: host IRs and
`MultiDeviceExecutor`. Since the former should eventually replace the
latter, we maintain both usages for now. In `MultiDeviceExecutor`, we
don't set a `Communication`'s I/O; that's why those arguments are made
optional in the constructor.

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
