diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 04c030bea6fa..4664ec549388 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1719,6 +1719,28 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name=" return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name) +def ccl_allgather(x: Tensor, num_workers: int, name="ccl_allgather"): + """CCL Allgather operator + + Parameters + ---------- + x : relax.Expr + The input tensor. + + num_workers : int + Number of workers. + + name : str + Name hint for this operation. + + Returns + ------- + result : Tensor + The result tensor of allgather. + """ + return wrap_nested(_op.ccl.allgather(x._expr, num_workers), name) + + def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): """Broadcast data from worker-0 to all other workers.