diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index c2bf27393173..510093810249 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -107,7 +107,12 @@ class DataTypeVisitor final : public StmtExprVisitor {
       IterVar iv = Downcast<IterVar>(op->node);
       ICHECK_NE(iv->thread_tag.length(), 0U);
       analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
-      vextent_[iv->var.as<VarNode>()] = op->value.dtype();
+      if (op->attr_key == attr::thread_extent) {
+        // Narrow extents to 32 bits on GPU.
+        vextent_[iv->var.as<VarNode>()] = DataType::Int(32);
+      } else {
+        vextent_[iv->var.as<VarNode>()] = op->value.dtype();
+      }
       StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 0486ef40017b..85a3dd5636f1 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -229,6 +229,23 @@ def test_broadcast_to():
         tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
+@tvm.testing.uses_gpu
+def test_broadcast_to_const_shape_int64():
+    shape_like = relay.const(np.array([1, 5]), dtype="int64")
+    x = relay.var("x", shape=(1,), dtype="int64")
+    z = relay.broadcast_to(x, shape=shape_like)
+    z = relay.sum(z, axis=0)
+
+    f = relay.Function([x], z)
+
+    x = np.random.randint(10, size=(1,), dtype="int64")
+    ref_res = np.broadcast_to(x, (5,))
+    for target, dev in tvm.testing.enabled_targets():
+        for kind in ["graph", "debug"]:
+            op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x)
+            tvm.testing.assert_allclose(op_res.numpy(), ref_res)
+
+
 @tvm.testing.uses_gpu
 def test_broadcast_to_like():
     shape = (4, 1, 6)
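
For context, a minimal sketch of the TIR-level effect of this patch, assuming a TVM build of the same era as the change; the schedule, shapes, and `print` inspection below are illustrative assumptions, not part of the patch. The loop bounds are deliberately `int64`: previously `DataTypeVisitor` recorded the extent's own dtype for thread-bound vars, so a 64-bit extent leaked into `blockIdx.x`/`threadIdx.x`; with the change, `thread_extent` vars are pinned to `int32` during `NarrowDataType`.

```python
# Hypothetical sketch (te schedule API); shows that thread extents end up
# narrowed to int32 even when the original loop bounds are int64.
import tvm
from tvm import te

n = tvm.tir.const(5, "int64")
A = te.placeholder((n,), dtype="int64", name="A")
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, "int64"), name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=1)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

# The default lowering pipeline runs NarrowDataType; with this patch the
# printed TIR binds blockIdx.x/threadIdx.x with int32 extents rather than
# carrying int64 thread indices into GPU codegen.
print(tvm.lower(s, [A, B]))
```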