From 6b51f251103390751dc54adea9ae4bf13511e41e Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Thu, 27 Feb 2025 13:41:45 +0100 Subject: [PATCH] Fix relax.ccl.scatter_from_worker0 assert The current code asserts that floormod(input_shape[0], num_workers) cannot be proven to be true, which has two problems: - It is unclear what it means to prove floormod(...). Prove that floormod's return value is truthy, i.e. non-zero? - It always checks the 0th dimension of the input shape, but the dimension index to be sharded is given by the "axis" parameter. This commit fixes both of the above by moving the "!= 0" comparison inside the CanProve call and by indexing the input shape with attrs->axis. --- src/relax/op/ccl/ccl.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 092727cb5115..c32cdc3aacb3 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -133,12 +133,12 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { auto input_shape = input_sinfo->GetShape(); CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape."; - if (analyzer->CanProve(floormod(input_shape.value()[0], PrimExpr(num_workers))) != 0) { + if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers)) != 0)) { ctx->ReportFatal(Diagnostic::Error(call) - << "scatter_from_worker0 expects the size of axis 0 of input tensor to be " - "divisible by the " - "num_workers. However, the axis 0 of input tensor is " - << input_shape.value() << " while num_workers is " << num_workers); + << "scatter_from_worker0 expects the size of axis " << attrs->axis + << " of input tensor to be divisible by the num_workers. However, axis " + << attrs->axis << " of input tensor is " << input_shape.value() + << " while num_workers is " << num_workers); } Array output_shape = input_shape.value();