From 9a019e8e3b7be892b5fef15fc2f36bbd626a22d3 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 26 Jul 2024 10:13:10 -0500
Subject: [PATCH 1/3] [Transform][Relax] Handle `in_group` argument in IPC
 AllReduce

The `relax.transform.IPCAllReduceRewrite` pass rewrites calls to
`"runtime.disco.allreduce"` to instead call the optimized
`"runtime.disco.cuda_ipc.custom_allreduce"` implementation.  When the
legalization of `R.ccl.allreduce` was updated in
https://github.com/apache/tvm/pull/17180 to provide an `in_group`
argument, the `IPCAllReduceRewrite` pass was not updated to match.

This commit updates `IPCAllReduceRewrite` to handle the additional
`in_group` argument.
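
For reference, a sketch of the two call signatures involved, using the
tensor names from the updated unit tests (`R.shape([0])` encodes the
summation reduce kind):

    # Three-argument form that the pass expected before this commit:
    _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)

    # Four-argument form emitted since apache/tvm#17180, with the
    # `in_group` flag inserted before the output tensor:
    _: R.Object = R.call_packed(
        "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(True), alloc1
    )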
---
 .../tvm/relax/transform/ipc_allreduce_rewrite.py | 10 +++++++---
 .../test_transform_ipc_allreduce_rewrite.py      | 16 ++++++++++------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/transform/ipc_allreduce_rewrite.py b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
index df40181cb981..dae48026c532 100644
--- a/python/tvm/relax/transform/ipc_allreduce_rewrite.py
+++ b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
@@ -97,8 +97,8 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             # Return if the call is not a summation all-reduce.
             return
 
-        assert len(call.args) == 3
-        allreduce_input = call.args[0]
+        assert len(call.args) == 4
+        allreduce_input, strategy, in_group, allreduce_output = call.args
         alloc_tensor = self.alloc_map.get(allreduce_input, None)
         if alloc_tensor is None or alloc_tensor.args[3].value != "global":
             # Return if the allocation of all-reduce input is not recorded,
@@ -113,9 +113,13 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             alloc_tensor.args[2],
             relax.StringImm("ipc_memory"),
         )
+
         self.binding_replacement_map[call] = relax.Call(
             relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"),
-            args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]],
+            # The "cuda_ipc.custom_allreduce" implementation does not
+            # yet support num_groups>1, and therefore does not use the
+            # `in_group` argument.
+            [allreduce_input, relax.PrimValue(self.allreduce_strategy), allreduce_output],
         )
 
 

diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
index f14953122ee3..da85423aafd7 100644
--- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
+++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
@@ -37,7 +37,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
         alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
             R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
         )
-        _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
+        _: R.Object = R.call_packed(
+            "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(True), alloc1
+        )
         return alloc1
 
     @I.ir_module
@@ -85,7 +87,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
         alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
             R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global")
         )
-        _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
+        _: R.Object = R.call_packed(
+            "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(False), alloc1
+        )
         return alloc1
 
     @I.ir_module
@@ -137,7 +141,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
         alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
             R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
         )
-        _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([1]), alloc1)
+        _: R.Object = R.call_packed(
+            "runtime.disco.allreduce", lv1, R.shape([1]), R.prim_value(True), alloc1
+        )
         return alloc1
 
     allreduce_strategy = 1
@@ -146,6 +152,4 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
 
 
 if __name__ == "__main__":
-    test_ipc_allreduce_rewrite()
-    test_ipc_allreduce_spread_along_reshape()
-    test_ipc_allreduce_skip_reducer_other_than_sum()
+    tvm.testing.main()

From 397a839c0090ae62e39d690eab463106cd947200 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 26 Jul 2024 11:45:39 -0500
Subject: [PATCH 2/3] lint fix

---
 python/tvm/relax/transform/ipc_allreduce_rewrite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/ipc_allreduce_rewrite.py b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
index dae48026c532..0266f222a3cf 100644
--- a/python/tvm/relax/transform/ipc_allreduce_rewrite.py
+++ b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
@@ -98,7 +98,7 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             return
 
         assert len(call.args) == 4
-        allreduce_input, strategy, in_group, allreduce_output = call.args
+        allreduce_input, _strategy, _in_group, allreduce_output = call.args
         alloc_tensor = self.alloc_map.get(allreduce_input, None)
         if alloc_tensor is None or alloc_tensor.args[3].value != "global":
             # Return if the allocation of all-reduce input is not recorded,

From 15149f61a2a8221a4c566d68c5901c47623e3e2b Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 26 Jul 2024 13:22:48 -0500
Subject: [PATCH 3/3] lint fix

---
 python/tvm/relax/transform/ipc_allreduce_rewrite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/ipc_allreduce_rewrite.py b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
index 0266f222a3cf..de5c22863403 100644
--- a/python/tvm/relax/transform/ipc_allreduce_rewrite.py
+++ b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
@@ -98,7 +98,7 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             return
 
         assert len(call.args) == 4
-        allreduce_input, _strategy, _in_group, allreduce_output = call.args
+        allreduce_input, _strategy, _ingroup, allreduce_output = call.args
         alloc_tensor = self.alloc_map.get(allreduce_input, None)
         if alloc_tensor is None or alloc_tensor.args[3].value != "global":
             # Return if the allocation of all-reduce input is not recorded,
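
A minimal usage sketch, following the updated test file (the constructor
signature is assumed from the tests, which pass the all-reduce strategy
as the only argument):

    import tvm
    from tvm import relax

    allreduce_strategy = 1
    # `Module` stands for an IRModule containing the four-argument
    # "runtime.disco.allreduce" calls shown in the tests above.
    mod = relax.transform.IPCAllReduceRewrite(allreduce_strategy)(Module)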