From 617f40d90544b3ce3237e79c397fbff5237ce1c1 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 11 Apr 2022 16:16:55 -0700 Subject: [PATCH 1/4] hotfix edge case for broadcast_to const shape int64 --- python/tvm/relay/op/transform.py | 2 +- tests/python/relay/test_op_level10.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 27dfefbb7890..663e0f0c60c0 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -836,7 +836,7 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = list(shape.data.numpy()) + shape = [int(i) for i in shape.data.numpy()] if isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) if isinstance(shape, int): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 0486ef40017b..85a3dd5636f1 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -229,6 +229,23 @@ def test_broadcast_to(): tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu +def test_broadcast_to_const_shape_int64(): + shape_like = relay.const(np.array([1, 5]), dtype="int64") + x = relay.var("x", shape=(1,), dtype="int64") + z = relay.broadcast_to(x, shape=shape_like) + z = relay.sum(z, axis=0) + + f = relay.Function([x], z) + + x = np.random.randint(10, size=(1,), dtype="int64") + ref_res = np.broadcast_to(x, (5,)) + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x) + tvm.testing.assert_allclose(op_res.numpy(), ref_res) + + @tvm.testing.uses_gpu def test_broadcast_to_like(): shape = (4, 1, 6) From 27db291aeb27ce5a1d697f508efdf0b4d2023607 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 12 Apr 2022 12:05:42 -0700 Subject: [PATCH 2/4] narrow extents to int32 on gpu --- python/tvm/relay/op/transform.py | 2 +- src/tir/transforms/narrow_datatype.cc | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 663e0f0c60c0..27dfefbb7890 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -836,7 +836,7 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = [int(i) for i in shape.data.numpy()] + shape = list(shape.data.numpy()) if isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) if isinstance(shape, int): diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index c2bf27393173..008f00103434 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -107,7 +107,12 @@ class DataTypeVisitor final : public StmtExprVisitor { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); - vextent_[iv->var.as()] = op->value.dtype(); + if (op->attr_key == attr::thread_extent) { + // narrow extents to 32 bits on GPU + vextent_[iv->var.as()] = DataType::Int(32); + } else { + vextent_[iv->var.as()] = op->value.dtype(); + } StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); From 0ad6d9848aed0c1086ebee8d7618b5fa3e4c4edd Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 12 Apr 2022 12:39:35 -0700 Subject: [PATCH 3/4] check that narrowing can be done safely --- src/tir/transforms/narrow_datatype.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 008f00103434..486bf9c28f68 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -108,7 +108,9 @@ class DataTypeVisitor final : public StmtExprVisitor { ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); if (op->attr_key == attr::thread_extent) { - // narrow extents to 32 bits on GPU + // Narrow extents to 32 bits on GPU. + ICHECK(analyzer_.CanProveLess(op->value, static_cast(INT32_MAX) + 1)) + << "cannot prove thread extent <= INT32_MAX, which is required for GPU"; vextent_[iv->var.as()] = DataType::Int(32); } else { vextent_[iv->var.as()] = op->value.dtype(); From 38f27110c5c36d40cb8848730ffc916ae6740416 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 12 Apr 2022 14:41:54 -0700 Subject: [PATCH 4/4] remove CanProveLess check --- src/tir/transforms/narrow_datatype.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 486bf9c28f68..510093810249 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -109,8 +109,6 @@ class DataTypeVisitor final : public StmtExprVisitor { analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); if (op->attr_key == attr::thread_extent) { // Narrow extents to 32 bits on GPU. - ICHECK(analyzer_.CanProveLess(op->value, static_cast(INT32_MAX) + 1)) - << "cannot prove thread extent <= INT32_MAX, which is required for GPU"; vextent_[iv->var.as()] = DataType::Int(32); } else { vextent_[iv->var.as()] = op->value.dtype();