Remove is_root_in_mesh from CommParams and add mesh. #2250
Changes from all commits
@@ -48,9 +48,7 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
 namespace {
 
-inline void assertBufferCount(
-    const std::vector<at::Tensor>& bufs,
-    size_t count) {
+void assertBufferCount(const std::vector<at::Tensor>& bufs, size_t count) {
   NVF_ERROR(
       bufs.size() == count,
       "there must be ",
@@ -60,7 +58,7 @@ inline void assertBufferCount(
       " were given");
 }
 
-inline void assertBuffersHaveSameSize(
+void assertBuffersHaveSameSize(
     const std::vector<at::Tensor>& bufs1,
     const std::vector<at::Tensor>& bufs2) {
   if (bufs1.empty() && bufs2.empty()) {
@@ -76,7 +74,7 @@ inline void assertBuffersHaveSameSize(
   }
 }
 
-inline void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) {
+void doLocalCopy(const at::Tensor& dst, const at::Tensor& src) {
   dst.view_as(src).copy_(src, /*non_blocking=*/true);
 }
@@ -100,7 +98,7 @@ T getInitialValue(c10d::ReduceOp::RedOpType op) {
 bool hasRoot(CommunicationType type) {
   return type == CommunicationType::Gather ||
-      type == CommunicationType::Scatter ||
+      type == CommunicationType::Scatter || type == CommunicationType::Reduce ||
       type == CommunicationType::Broadcast ||
       type == CommunicationType::SendRecv;
 }
@@ -111,36 +109,27 @@ bool isReduction(CommunicationType type) {
       type == CommunicationType::ReduceScatter;
 }
 
-void assertValid(const CommParams& params) {
-  std::unordered_set<DeviceIdxType> team_without_duplicates(
-      params.team.begin(), params.team.end());
-  NVF_ERROR(
-      team_without_duplicates.size() == params.team.size(),
-      "the communication must not involve the same device more than once");
Comment on lines -115 to -119:

Collaborator (author): I could move this check to Communication::team() but it looks unnecessary because

Collaborator: This sounds ok to me. I think only the CommunicationTests directly set the team, but the normal use case creates a team from the DeviceMesh, which checks against duplicates.

Collaborator (author): This PR changed the CommunicationTests as well to set mesh instead of team. The bottom line is that CommParams contains all parameters (e.g. mesh) needed for constructing a Communication, and other fields (e.g. team) are all computed.
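To make the thread concrete, here is a hypothetical before/after sketch of CommParams, reconstructed from the fields the diff references (type, root, team, mesh, is_root_in_mesh, redOp) rather than copied from the source:

```cpp
// Hypothetical sketch, not the actual declarations from the PR.
struct CommParamsBefore {
  CommunicationType type;
  DeviceIdxType root;
  Team team;             // hand-built device list, checked by assertValid()
  bool is_root_in_mesh;  // redundant flag that this PR removes
  c10d::ReduceOp::RedOpType redOp;
};

struct CommParamsAfter {
  CommunicationType type;
  DeviceIdxType root;
  DeviceMesh mesh;  // team and root membership are derived from this
  c10d::ReduceOp::RedOpType redOp;
};
```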
-  NVF_ERROR(!params.team.empty(), "the team size must be greater than 0");
Collaborator (author): This is moved to Communication's constructor.
-  if (hasRoot(params.type)) {
-    auto it = std::find(params.team.begin(), params.team.end(), params.root);
-    NVF_ERROR(
Collaborator (author): This check could be moved to Communication::team(), but it would look a bit silly:
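(The snippet the author pasted here is not preserved in this capture. A hypothetical sketch of the kind of Communication::team() being discussed, assuming Team aliases std::vector<DeviceIdxType> and DeviceMesh exposes vector() and has() accessors:)

```cpp
// Hypothetical sketch, not the author's actual snippet: folding the root
// check into Communication::team() means re-deriving membership on each call.
Team Communication::team() const {
  Team team = params_.mesh.vector();  // assumed accessor returning device ids
  if (hasRoot(params_.type) && !params_.mesh.has(params_.root)) {
    team.push_back(params_.root);  // the root goes last when outside the mesh
  }
  return team;
}
```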
-        it != params.team.end(),
-        "root (device ",
-        params.root,
-        ") must be present in the communication's team");
-  }
-}
-
 // pytorch's process group expects the root to be specified
 // as an integer between 0 and world_size-1. We choose it to be
 // the device's relative index within the team
 int64_t getRootRelativeIndex(const CommParams& params) {
-  auto it = std::find(params.team.begin(), params.team.end(), params.root);
-  return std::distance(params.team.begin(), it);
+  int64_t relative_index = params.mesh.idxOf(params.root);
+  // This assumes the root, if not in mesh, is added to the end of the team
+  // vector. Given the assumption is distantly enforced at
+  // Communication::team(), I slightly prefer making this function a method of
+  // Communication so relative_index can be computed from Communication::team()
+  // directly. Wdyt?
+  if (relative_index == -1) {
+    relative_index = params.mesh.size();
+  }
+  return relative_index;
 }
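(For illustration, and not from the diff: with mesh = {2, 3} and root = 5, the root is outside the mesh, the team is {2, 3, 5}, and the root's relative index falls back to mesh.size() = 2. With root = 3 instead, idxOf(3) = 1 is used directly.)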
 } // namespace
 
 Communication::Communication(IrBuilderPasskey passkey, CommParams params)
     : Expr(passkey), params_(std::move(params)) {
-  assertValid(params_);
+  NVF_ERROR(params_.mesh.size() > 0, "The mesh size must be greater than 0.");
 }
 
 Communication::Communication(const Communication* src, IrCloner* ir_cloner)
|
|
@@ -153,10 +142,10 @@ std::string Communication::toString(const int indent_size) const {
   indent(ss, indent_size) << "Communication " << params_.type << ": {"
                           << std::endl;
 
   if (hasRoot(params_.type)) {
     indent(ss, indent_size + 1) << "root: " << params_.root << "," << std::endl;
   }
+  indent(ss, indent_size + 1) << "mesh: " << params_.mesh << "," << std::endl;
   indent(ss, indent_size + 1) << "team: " << params_.team << "," << std::endl;
   indent(ss, indent_size) << "}";
|
|
@@ -178,9 +167,10 @@ bool Communication::sameAs(const Statement* other) const {
   const auto& p1 = this->params();
   const auto& p2 = other->as<Communication>()->params();
 
-  return (
-      p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) &&
-      p1.team == p2.team && (!isReduction(p1.type) || p1.redOp == p2.redOp));
+  return (p1.type == p2.type && (!hasRoot(p1.type) || p1.root == p2.root) &&
+          p1.mesh == p2.mesh &&
+          (!isReduction(p1.type) || p1.redOp == p2.redOp)) &&
+      p1.team == p2.team;
 }
 
 namespace {
|
|
@@ -192,7 +182,7 @@ c10::intrusive_ptr<c10d::Work> postBroadcast(
     at::Tensor output_tensor) {
   const CommParams& params = communication->params();
   if (my_device_index == params.root) {
-    if (params.is_root_in_mesh) {
+    if (communication->isRootInMesh()) {
       // Do a local copy and the subsequent broadcast will be in place. Consider
       // ProcessGroupNCCL::_broadcast_oop so ncclBroadcast doesn't wait for the
      // local copy to complete.

@@ -203,7 +193,7 @@ c10::intrusive_ptr<c10d::Work> postBroadcast(
     }
   }
 
-  if (params.team.size() == 1) {
+  if (communication->team().size() == 1) {

Collaborator (author): This was an intermediate change that turns out to be unnecessary for this PR. However, #2256 will remove CommParams, so I chose to keep this change, which makes the code closer to the end state.

     return nullptr;
   }
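(A clarifying note, not from the PR thread: when the team has a single member, that member is the root, the broadcast degenerates to the local copy already performed above, and there is no collective to post, hence the null work handle returned above.)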
|
|
@@ -219,7 +209,7 @@ c10::intrusive_ptr<c10d::Work> postGather(
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
   const CommParams& params = communication->params();
-  if (my_device_index == params.root && !params.is_root_in_mesh) {
+  if (my_device_index == params.root && !communication->isRootInMesh()) {
     // This is likely a suboptimal way to allocate tensors for nccl. To benefit
     // from zero copy
     // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html),
|
|
@@ -235,17 +225,17 @@ c10::intrusive_ptr<c10d::Work> postGather(
   if (my_device_index == params.root) {
     output_tensors.resize(1);
     int64_t j = 0;
-    for (auto i : c10::irange(params.team.size())) {
+    for (auto i : c10::irange(communication->team().size())) {
       if (root_relative_index == static_cast<DeviceIdxType>(i) &&
-          !params.is_root_in_mesh) {
+          !communication->isRootInMesh()) {
         output_tensors[0].push_back(input_tensor);
         continue;
       }
       output_tensors[0].push_back(output_tensor.slice(0, j, j + 1));
       j++;
     }
 
-    assertBufferCount(output_tensors[0], params.team.size());
+    assertBufferCount(output_tensors[0], communication->team().size());
     assertBuffersHaveSameSize(input_tensors, output_tensors[0]);
   }
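(For illustration, an example of the buffer layout this loop builds, not taken from the PR: with mesh = {1, 2} and root = 0 outside the mesh, the team is {1, 2, 0} and root_relative_index is 2. On the root, output_tensors[0] then becomes {output.slice(0, 0, 1), output.slice(0, 1, 2), input_tensor}; the root's own contribution is taken directly from its input rather than from a slice of the output.)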
|
|
@@ -259,13 +249,11 @@ c10::intrusive_ptr<c10d::Work> postAllgather(
     c10::intrusive_ptr<c10d::Backend> backend,
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
-  const CommParams& params = communication->params();
-
   std::vector<at::Tensor> input_tensors({input_tensor});
   std::vector<std::vector<at::Tensor>> output_tensors(1);
   output_tensors[0] = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);
 
-  assertBufferCount(output_tensors[0], params.team.size());
+  assertBufferCount(output_tensors[0], communication->team().size());
   assertBuffersHaveSameSize(input_tensors, output_tensors[0]);
   return backend->allgather(output_tensors, input_tensors, {});
 }
|
|
@@ -278,7 +266,7 @@ c10::intrusive_ptr<c10d::Work> postScatter(
     at::Tensor output_tensor) {
   const CommParams& params = communication->params();
 
-  if (my_device_index == params.root && !params.is_root_in_mesh) {
+  if (my_device_index == params.root && !communication->isRootInMesh()) {
     output_tensor = at::empty_like(input_tensor.slice(0, 0, 1));
   }
   std::vector<at::Tensor> output_tensors({output_tensor});
|
|
@@ -288,17 +276,17 @@ c10::intrusive_ptr<c10d::Work> postScatter(
   if (my_device_index == params.root) {
     input_tensors.resize(1);
     int64_t j = 0;
-    for (auto i : c10::irange(params.team.size())) {
+    for (auto i : c10::irange(communication->team().size())) {
       if (root_relative_index == static_cast<DeviceIdxType>(i) &&
-          !params.is_root_in_mesh) {
+          !communication->isRootInMesh()) {
         input_tensors.front().push_back(output_tensor);
         continue;
       }
       input_tensors.front().push_back(input_tensor.slice(0, j, j + 1));
       j++;
     }
 
-    assertBufferCount(input_tensors[0], params.team.size());
+    assertBufferCount(input_tensors[0], communication->team().size());
     assertBuffersHaveSameSize(input_tensors[0], output_tensors);
   }
|
|
@@ -316,7 +304,7 @@ c10::intrusive_ptr<c10d::Work> postReduce(
   at::Tensor tensor;
   if (my_device_index == params.root) {
-    if (params.is_root_in_mesh) {
+    if (communication->isRootInMesh()) {
       doLocalCopy(output_tensor, input_tensor);
       tensor = output_tensor;
     } else {
|
|
@@ -369,7 +357,7 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
   std::vector<at::Tensor> output_tensors({output_tensor});
 
-  assertBufferCount(input_tensors[0], params.team.size());
+  assertBufferCount(input_tensors[0], communication->team().size());
   return backend->reduce_scatter(
       output_tensors, input_tensors, {.reduceOp = params.redOp});
 }
|
|
@@ -382,18 +370,15 @@ c10::intrusive_ptr<c10d::Work> postSendRecv(
     at::Tensor output_tensor) {
   const CommParams& params = communication->params();
 
-  NVF_ERROR(
-      params.team.size() == 1 || params.team.size() == 2,
-      "the team size should be 1 or 2");
+  NVF_ERROR(params.mesh.size() == 1, "The mesh size should be 1.");
 
-  if (params.team.size() == 1) {
+  if (communication->isRootInMesh()) {
     doLocalCopy(output_tensor, input_tensor);
     return nullptr;
   }
 
   const DeviceIdxType sender = params.root;
-  const DeviceIdxType receiver =
-      params.team.at(0) == sender ? params.team.at(1) : params.team.at(0);
+  const DeviceIdxType receiver = params.mesh.at(0);
 
   std::vector<at::Tensor> tensors;
   if (my_device_index == sender) {
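(For illustration, and not taken from the PR: under the new convention a SendRecv is described by a single-device mesh holding the receiver plus a root holding the sender. With mesh = {1} and root = 0, device 0 sends and device 1 receives. With mesh = {0} and root = 0, the root is in the mesh, the team collapses to one device, and the transfer degenerates to the local copy above, so no send/recv is posted.)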
|
|
@@ -414,9 +399,9 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
   const CommParams& params = communication->params();
+  const Team& team = communication->team();
   NVF_ERROR(
-      std::find(params.team.begin(), params.team.end(), my_device_index) !=
-          params.team.end(),
+      std::find(team.begin(), team.end(), my_device_index) != team.end(),
       "current device index ",
       my_device_index,
       " must be present in the communication's team");
Comment on the removed inline qualifiers:

Unless proven to be important, I tend to leave the inline decision to the compiler. Otherwise, the function could organically grow into a giant without people realizing the overhead of inlining.