diff --git a/python/tvm/relax/transform/ipc_allreduce_rewrite.py b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
index df40181cb981..de5c22863403 100644
--- a/python/tvm/relax/transform/ipc_allreduce_rewrite.py
+++ b/python/tvm/relax/transform/ipc_allreduce_rewrite.py
@@ -97,8 +97,8 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             # Return if the call is not a summation all-reduce.
             return
 
-        assert len(call.args) == 3
-        allreduce_input = call.args[0]
+        assert len(call.args) == 4
+        allreduce_input, _strategy, _ingroup, allreduce_output = call.args
         alloc_tensor = self.alloc_map.get(allreduce_input, None)
         if alloc_tensor is None or alloc_tensor.args[3].value != "global":
             # Return if the allocation of all-reduce input is not recorded,
@@ -113,9 +113,13 @@ def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-re
             alloc_tensor.args[2],
             relax.StringImm("ipc_memory"),
         )
+
         self.binding_replacement_map[call] = relax.Call(
             relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"),
-            args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]],
+            # The "cuda_ipc.custom_allreduce" implementation does not
+            # yet support num_groups>1, and therefore does not use the
+            # `in_group` argument.
+            [allreduce_input, relax.PrimValue(self.allreduce_strategy), allreduce_output],
         )
 
 
diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
index f14953122ee3..da85423aafd7 100644
--- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
+++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
@@ -37,7 +37,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
             alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
             )
-            _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
+            _: R.Object = R.call_packed(
+                "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(True), alloc1
+            )
             return alloc1
 
     @I.ir_module
@@ -85,7 +87,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
             alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global")
             )
-            _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
+            _: R.Object = R.call_packed(
+                "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(False), alloc1
+            )
             return alloc1
 
     @I.ir_module
@@ -137,7 +141,9 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
             alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
             )
-            _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([1]), alloc1)
+            _: R.Object = R.call_packed(
+                "runtime.disco.allreduce", lv1, R.shape([1]), R.prim_value(True), alloc1
+            )
             return alloc1
 
     allreduce_strategy = 1
@@ -146,6 +152,4 @@
 
 
 if __name__ == "__main__":
-    test_ipc_allreduce_rewrite()
-    test_ipc_allreduce_spread_along_reshape()
-    test_ipc_allreduce_skip_reducer_other_than_sum()
+    tvm.testing.main()