From 842b092964c9c3fee99534b44901e09462178530 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 3 Sep 2024 09:48:05 -0500
Subject: [PATCH 1/3] [Relax] Validate StructInfo annotations in well-formed
 check

Prior to this commit, the Relax well-formed checker verified that each
expression had a non-null `StructInfo` annotation, but did not perform
any validation on the contents of the `StructInfo` annotation.

This commit updates the Relax well-formed check to verify that the
`StructInfo` annotations are accurate by comparing against the
`StructInfo` that would be inferred for an expression. (This only
requires that the information is accurate, not that it is complete.
For example, an expression that is inferred to be
`R.Tensor(shape=[128,8], dtype="float32")` may have an annotation of
`R.Tensor(ndim=2, dtype="float32")`, but may not have an annotation of
`R.Tensor(shape=[128,8], dtype="int32")`.)
---
 src/relax/analysis/well_formed.cc             | 43 ++++++++++
 src/relax/op/op.cc                            | 21 +++--
 .../python/relax/test_analysis_well_formed.py | 85 +++++++++++++++++++
 tests/python/relax/test_ast_printer.py        |  4 +-
 tests/python/relax/test_frontend_from_fx.py   | 10 +--
 .../relax/test_transform_decompose_ops.py     |  4 +-
 .../test_transform_ipc_allreduce_rewrite.py   |  4 +-
 .../relax/test_transform_legalize_ops_ccl.py  |  4 +-
 ..._transform_legalize_ops_create_datatype.py | 34 ++++----
 ...sform_legalize_ops_index_linear_algebra.py |  2 +-
 .../test_transform_legalize_ops_manipulate.py | 51 ++++++-----
 .../relax/test_transform_legalize_ops_nn.py   | 38 ++++++---
 ...ansform_legalize_ops_search_statistical.py |  4 +-
 .../relax/test_transform_realize_vdevice.py   | 16 ++--
 ...test_transform_static_plan_block_memory.py |  8 +-
 .../test_transform_to_mixed_precision.py      | 12 +--
 tests/python/relax/test_tvmscript_parser.py   | 11 ++-
 tests/python/relax/test_vm_cuda_graph.py      |  8 +-
 tests/python/relax/test_vm_multi_device.py    | 16 ++--
 19 files changed, 268 insertions(+), 107 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 7688c4a64291..7873d5ce2022 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
                   << err.what());
       }
     }
+
+    if (check_struct_info_ && call->struct_info_.defined()) {
+      // The `InferStructInfo` method isn't currently exposed by the
+      // Normalizer, and can only be called indirectly by normalizing
+      // an expression that does not yet have `StructInfo`.
+      auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
+      Call copied(call->op, call->args, call->attrs, call->sinfo_args);
+      Optional<Expr> normalized = NullOpt;
+      try {
+        normalized = dummy_builder->Normalize(copied);
+      } catch (std::exception& err) {
+        Malformed(Diagnostic::Error(call)
+                  << "Each Relax expression must be able to have its StructInfo inferred. "
+                  << "However, inferring the struct info of expression " << GetRef<Expr>(call)
+                  << " resulted in the error: \n"
+                  << err.what());
+      }
+      if (normalized.defined()) {
+        auto inferred_struct_info = GetStructInfo(normalized.value());
+        auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
+
+        // An error should be raised if the annotated StructInfo is
+        // provably incorrect. This check is done using
+        // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1`
+        // represents cases that are neither provably correct nor
+        // provably incorrect. If this check were replaced with
+        // `!IsBaseOf(...)`, cases that are correct but not provably
+        // so would raise an exception.
+        //
+        // For example, if a dynamic size in the inferred StructInfo
+        // is equivalent to the expression used in the annotated
+        // StructInfo, but the TIR simplifications are not sufficient
+        // to prove that the two expressions are equivalent, we should
+        // not raise an error.
+        if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
+            BaseCheckResult::kFailL1) {
+          Malformed(Diagnostic::Error(call)
+                    << "All information in StructInfo annotations must be correct. "
+                    << "However, while the expression " << GetRef<Expr>(call) << " is annotated as "
+                    << current_struct_info << ", the expression outputs " << inferred_struct_info);
+        }
+      }
+    }
   }
 
   void VisitExpr_(const IfNode* op) final {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 3e0f0eba313a..a7d97a59a100 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c
   ICHECK(call->args.size() == 1);
   ICHECK(call->args[0]->struct_info_.defined());
   const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
-  ICHECK(tsinfo && tsinfo->shape.defined());
-  ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
-  ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected argument to be 1-d, "
-                                         << "but " << call << " has argument " << call->args[0]
-                                         << " with struct info " << call->args[0]->struct_info_;
-  const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
-  ICHECK(ndim);
-  return ShapeStructInfo(ndim->value);
+  ICHECK(tsinfo);
+  ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, "
+                             << "but " << call << " has argument " << call->args[0]
+                             << " with struct info " << call->args[0]->struct_info_;
+
+  if (tsinfo->shape.defined()) {
+    ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+    const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+    if (ndim) {
+      return ShapeStructInfo(ndim->value);
+    }
+  }
+  return ShapeStructInfo(kUnknownNDim);
 }
 
 RELAY_REGISTER_OP("relax.tensor_to_shape")
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 3db3efee1afc..d9eefcfd0ef2 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -1295,5 +1295,90 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent():
     assert not rx.analysis.well_formed(main)
 
 
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_annotations_must_be_correct():
+    """StructInfo annotations must be correct
+
+    To be well-formed, the inferred struct info must not conflict with
+    the StructInfo annotations.
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B) + return C + + assert not rx.analysis.well_formed(Module) + + +def test_struct_info_may_be_incomplete(): + """StructInfo annotations may be less specific + + The StructInfo annotations are not required to be an exact match + to the inferred StructInfo, and may provide less specific + information than the inference would provide. + + """ + + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Object = R.add(A, B) + return C + + assert rx.analysis.well_formed(Module) + + +def test_incomplete_struct_info_must_be_consistent(): + """StructInfo annotations must be accurate + + Even though StructInfo annotation may be less specific, the + information that they do contain must be correct. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Tensor(ndim=3) = R.add(A, B) + return C + + assert not rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 6005ecb0fa58..1df7dcf36f79 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -366,8 +366,8 @@ def f( ) -> R.Object: m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) - w: R.Tensor = R.multiply(z, z) - q: R.Tensor(ndim=2) = R.add(w, w) + w: R.Tensor(ndim=2) = R.multiply(z, z) + q: R.Tensor = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.shape_of(t) o: R.Object = R.call_packed( diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 78fc7abdf748..191ea4da5e56 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -79,7 +79,7 @@ def main( out_layout="NCW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4), dtype="float32") = lv3 R.output(gv) @@ -171,7 +171,7 @@ def main( out_layout="NCW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 6), dtype="float32") = lv3 R.output(gv) @@ -263,7 +263,7 @@ def main( out_layout="NCHW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1]) lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 R.output(gv) @@ -355,7 +355,7 @@ def main( out_layout="NCHW", out_dtype="float32", ) - lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1]) lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3 R.output(gv) @@ -447,7 +447,7 @@ def main( out_layout="NCDHW", 
out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1]) lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3 R.output(gv) diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py index 4e5bcb82e979..2564913d79ae 100644 --- a/tests/python/relax/test_transform_decompose_ops.py +++ b/tests/python/relax/test_transform_decompose_ops.py @@ -360,14 +360,14 @@ def test_op_tensor_to_shape(): @I.ir_module class Before: @R.function - def main(t: R.Tensor(ndim=1, dtype="int64")): + def main(t: R.Tensor([3], dtype="int64")): gv: R.Shape(ndim=3) = R.tensor_to_shape(t) return gv @I.ir_module class Expected: @R.function - def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3): + def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3): x = T.int64() x_1 = T.int64() x_2 = T.int64() diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py index da85423aafd7..fa68c16e691d 100644 --- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py +++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py @@ -83,7 +83,7 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) - lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore + lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) @@ -103,7 +103,7 @@ def main( alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("ipc_memory") ) - lv1: R.Tensor((m, n), dtype="float16") = R.reshape( # type: ignore + lv1: R.Tensor((m * n,), dtype="float16") = R.reshape( # type: ignore alloc, R.shape([m * n]) ) alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 9ea4d21d610d..923a8e8d9739 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -101,8 +101,8 @@ def test_scatter_from_worker0(): @tvm.script.ir_module class ScatterFromWorker0: @R.function - def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10), "float32"): - gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1) + def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5), "float32"): + gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1) return gv0 @I.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 7b2b2d2e7644..a8af295ac3b9 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -160,19 +160,19 @@ def test_full_like(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: 
R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.full_like(x, v) + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full_like(x, v) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32")) + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32")) return gv @T.prim_func(private=True) - def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_full"): @@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32")) + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5, "float32")) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32")) + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="int32")) return gv @T.prim_func(private=True) - def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) - T_full[ax0, ax1] = T.float32(-5) + T_full[ax0, ax1] = T.int32(-5) # fmt: on mod = LegalizeOps()(FullLike) @@ -253,19 +253,19 @@ def test_full_like_symbolic(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"): m = T.int64() n = T.int64() - gv: R.Tensor((m, n), "float32") = R.full_like(x, v) + gv: R.Tensor((m, n), "int32") = R.full_like(x, v) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"): m = T.int64() n = T.int64() - gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32")) + gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="int32")) return gv @T.prim_func(private=True) @@ -273,13 +273,13 @@ def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): T.func_attr({"tir.noalias": True}) m = T.int64() n = T.int64() - T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): with T.block("T_full"): ax0, ax1 = 
T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) - T_full[ax0, ax1] = rxplaceholder[()] + T_full[ax0, ax1] = T.int32(rxplaceholder[()]) # fmt: on mod = LegalizeOps()(FullLike) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index d0aaddb1ca52..2f4da5cf0653 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -230,7 +230,7 @@ def test_strided_slice_no_strides(): class StridedSlice: @R.function def main(x: R.Tensor((8, 9, 10, 10), "float32")) : - gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) + gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) return gv @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index ba5d4d7d1219..a0ecd3c73dc9 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -691,9 +691,12 @@ def test_data_dependent_reshape(): @tvm.script.ir_module class DDReshape: @R.function - def main(x: R.Tensor((3, ), dtype="int64")): - lv: R.Shape([3,]) = R.tensor_to_shape(x) - gv = R.reshape(x, lv) + def main( + x: R.Tensor([2], dtype="int64"), + y: R.Tensor([16],dtype='float32'), + ): + lv: R.Shape(ndim=2) = R.tensor_to_shape(x) + gv = R.reshape(y, lv) return gv # fmt: on @@ -704,29 +707,35 @@ def main(x: R.Tensor((3, ), dtype="int64")): # fmt: off @I.ir_module class Expected: + @R.function + def main( + x: R.Tensor([2], dtype="int64"), + y: R.Tensor([16],dtype="float32"), + ) -> R.Tensor(ndim=2, dtype="float32"): + M = T.int64() + N = T.int64() + gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=2),)) + _ = R.match_cast(gv, R.Shape([M,N])) + _ = R.shape([M,N]) + gv_1 = R.call_tir(Expected.reshape, (y,), out_sinfo=R.Tensor([M,N], dtype="float32")) + return gv_1 + @T.prim_func(private=True) def reshape( - rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: T.handle + rxplaceholder: T.Buffer(T.int64(16), "float32"), + var_T_reshape: T.handle, ): T.func_attr({"tir.noalias": True}) - x = T.int64() - T_reshape = T.match_buffer(var_T_reshape, (x,), "int64") - # with T.block("root"): - for ax0 in range(x): + M = T.int64() + N = T.int64() + T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32") + for i,j in T.grid(M,N): with T.block("T_reshape"): - v_ax0 = T.axis.spatial(x, ax0) - T.reads(rxplaceholder[v_ax0 % T.int64(3)]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] + vi,vj = T.axis.remap('SS',[i,j]) + T.reads(rxplaceholder[(vi*N + vj) % 16]) + T.writes(T_reshape[vi,vj]) + T_reshape[vi,vj] = rxplaceholder[(vi*N + vj) % 16] - @R.function - def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"): - x_1 = T.int64() - gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) - y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) - lv: R.Shape([x_1]) = R.shape([x_1]) - gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) - return gv_1 # fmt: on tvm.ir.assert_structural_equal(out_mod, Expected) @@ -914,7 +923,7 @@ def test_squeeze_no_axis(): class Squeeze: 
@R.function def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) : - gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x) + gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x) return gv @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 92d139d23b5d..d03d48968d90 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -33,7 +33,7 @@ def test_conv1d(): class Conv1d: @R.function def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"): - gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) + gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) return gv @tvm.script.ir_module @@ -210,7 +210,7 @@ def test_conv2d(): class Conv2d: @R.function def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): - gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) + gv: R.Tensor((2, 64, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) return gv @tvm.script.ir_module @@ -3298,20 +3298,32 @@ def test_nll_loss(): @tvm.script.ir_module class NLLLoss: @R.function - def main(predictions: R.Tensor((2, 3, 4, 5), "float32"), targets: R.Tensor((2, 4, 5), "int64"), weights: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): - gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) + def main( + predictions: R.Tensor((2, 3, 4, 5), "float32"), + targets: R.Tensor((2, 4, 5), "int64"), + weights: R.Tensor((3,), "float32"), + ) -> R.Tensor((), "float32"): + gv = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) return gv @tvm.script.ir_module class Expected: @R.function - def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5), dtype="int64"), weights: R.Tensor((4,), dtype="float32"),) -> R.Tensor((), dtype="float32"): - # block 0 + def main( + predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), + targets: R.Tensor((2, 4, 5), dtype="int64"), + weights: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), R.Tensor((), dtype="float32")) return gv @T.prim_func(private=True) - def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), rxplaceholder_2: T.Buffer(T.int64(4), "float32"), T_divide: T.Buffer((), "float32"),): + def nll_loss( + predictions: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), + targets: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), + weights: T.Buffer(T.int64(3), "float32"), + output: T.Buffer((), "float32"), + ): # function attr dict T.func_attr({"tir.noalias": True}) # body @@ -3323,9 +3335,9 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], 
v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) + T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss[v_ax0, v_ax1, v_ax2]) - nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) + nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) @@ -3337,9 +3349,9 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) + T.reads(targets[v_ax0, v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2]) - nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) + nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_red_1"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) @@ -3351,8 +3363,8 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 with T.block("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss_red[()], nll_loss_red_1[()]) - T.writes(T_divide[()]) - T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()] + T.writes(output[()]) + output[()] = nll_loss_red[()] / nll_loss_red_1[()] # fmt: on mod = LegalizeOps()(NLLLoss) tvm.ir.assert_structural_equal(mod, Expected) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 2a28151dbe7e..f8dab8981552 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -999,8 +999,8 @@ def test_variance_no_keepdims(): @tvm.script.ir_module class Variance: @R.function - def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4, 1), "float32"): - gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3], keepdims=False) + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv: R.Tensor((3, 4), "float32") = R.variance(x, [0, 3], keepdims=False) return gv @I.ir_module diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py index 4c530d5e4931..fa642821842d 100644 --- a/tests/python/relax/test_transform_realize_vdevice.py +++ b/tests/python/relax/test_transform_realize_vdevice.py @@ -61,8 +61,9 @@ def foo( y1 = y x2 = x1 y2 = y1 - lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) - gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z) + x2 = R.hint_on_device(x2, tvm.cpu()) + lv0 = 
R.add(x2, y2) + gv = R.multiply(lv0, z) R.output(gv) return gv @@ -91,6 +92,7 @@ def foo( y1: R.Tensor((2, 3), "float32", "llvm") = y x2: R.Tensor((2, 3), "float32", "llvm") = x1 y2: R.Tensor((2, 3), "float32", "llvm") = y1 + x2: R.Tensor((2, 3), "float32", "llvm") = x2 lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z) R.output(gv) @@ -121,7 +123,8 @@ def foo( y1 = y x2 = x1 y2 = y1 - s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) + x2 = R.hint_on_device(x2, tvm.cpu()) + s = R.add(x2, y2) m = R.multiply(s, z) return m @@ -146,6 +149,7 @@ def foo( y1: R.Tensor((2, 3), "float32", "llvm") = y x2: R.Tensor((2, 3), "float32", "llvm") = x1 y2: R.Tensor((2, 3), "float32", "llvm") = y1 + x2: R.Tensor((2, 3), "float32", "llvm") = x2 s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) m: R.Tensor((2, 3), "float32", "llvm") = R.multiply(s, z) return m @@ -275,10 +279,11 @@ def foo( z: R.Tensor((2, 3), "float32"), ) -> R.Tensor((2, 3), "float32", "cuda"): with R.dataflow(): - lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y) + lv0 = R.add(x, y) + lv0 = R.hint_on_device(lv0, tvm.cpu()) lv1 = R.to_vdevice(lv0, "cuda") lv2 = R.add(z, z) - gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2) + gv = R.multiply(lv1, lv2) R.output(gv) return gv @@ -304,6 +309,7 @@ def foo( ) -> R.Tensor((2, 3), "float32", "cuda"): with R.dataflow(): lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y) + lv0: R.Tensor((2, 3), "float32", "llvm") = lv0 lv1: R.Tensor((2, 3), "float32", "cuda") = R.to_vdevice(lv0, "cuda") lv2: R.Tensor((2, 3), "float32", "cuda") = R.add(z, z) gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index f9e632d34897..1150827b19f9 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1386,11 +1386,11 @@ def main( ) cls.cumsum(probs, lv1, alloc1) cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1 - lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed( + lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_packed( "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), - sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float"),), + sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), ) return lv1_1 @@ -1403,7 +1403,7 @@ def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.hand @R.function def main( probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32") - ) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"): + ) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"): batch_size = T.int64() vocab_size = T.int64() R.func_attr( @@ -1437,7 +1437,7 @@ def main( ) cls.cumsum(probs, lv1, alloc1) cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1 - lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed( + lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_packed( "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index ed10fc95c723..658f80a06ec5 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ 
b/tests/python/relax/test_transform_to_mixed_precision.py @@ -906,15 +906,15 @@ def main( ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): # block 0 with R.dataflow(): - lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = R.nn.conv2d( x, w, strides=[1, 1], padding=[0, 0, 0, 0], out_dtype="float32", ) - lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) - lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143) + lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) + lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = R.add(lv142, lv143) R.output(lv144) return lv144 @@ -1001,15 +1001,15 @@ def main( ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): # block 0 with R.dataflow(): - lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = R.nn.conv2d( x, w, strides=[1, 1], padding=[0, 0, 0, 0], out_dtype="float32", ) - lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) - lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143) + lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) + lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = R.add(lv142, lv143) R.output(lv144) return lv144 diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64f2efd4af9e..ac11c6d2cb99 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -873,8 +873,8 @@ def foo( ) -> R.Object: m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) - w: R.Tensor = R.multiply(z, z) - q: R.Tensor(ndim=2) = R.add(w, w) + w: R.Tensor(ndim=2) = R.multiply(z, z) + q: R.Tensor = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) @@ -893,9 +893,9 @@ def _check_struct_info(binding, expected_sinfo): sh = bindings[4].var _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) - _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) - _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) - _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2)) _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) _check_struct_info(bindings[6], relax.ObjectStructInfo()) @@ -1045,7 +1045,6 @@ def main( def test_call_tir_inplace_with_tuple_var_raises_error(): - with pytest.raises(tvm.error.DiagnosticError): @tvm.script.ir_module diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 49ebcc1d05b2..b6c8cdfdeea4 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -36,13 +36,13 @@ def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl R.func_attr({"global_symbol": "main"}) gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] - alloc: 
R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _: R.Tuple = cls.add(x, alloc) storage1: R.Object = gv[1] gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage) gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),)) storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8")) - alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) lv4: R.Tensor((16, 16), dtype="float32") = gv2[0] _3: R.Tuple = cls.add(lv4, alloc3) lv5: R.Tensor(dtype="float32") = alloc3 @@ -71,12 +71,12 @@ def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.O cls = Module R.func_attr({"global_symbol": "cuda_graph_capture"}) lv0: R.Tensor((16, 16), dtype="float32") = alloc - alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _1: R.Tuple = cls.add(lv0, alloc1) lv1: R.Tensor(dtype="float32") = alloc1 lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,) lv3: R.Tensor(dtype="float32") = lv2[0] - alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _2: R.Tuple = cls.add(lv3, alloc2) lv4: R.Tensor(dtype="float32") = alloc2 gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,) diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index ec2fbd1cdf60..1481d3f760d1 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Test eliminate common subexpr pass""" + from typing import List import tvm from tvm import relax @@ -61,11 +62,10 @@ def foo( z: R.Tensor((4, 5), "float32"), ) -> R.Tensor((2, 5), "float32"): with R.dataflow(): - lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y) # noqa: F722 - lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( # noqa: F722 - lv0, "llvm:1" # noqa: F722 - ) - gv = R.matmul(lv1, z) # noqa: F722 + lv0 = R.matmul(x, y) + lv0 = R.hint_on_device(lv0, tvm.cpu(0)) + lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice(lv0, "llvm:1") # noqa: F722 + gv = R.matmul(lv1, z) R.output(gv) return gv @@ -109,11 +109,13 @@ def foo( with R.dataflow(): lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b) # noqa: F722 lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice( # noqa: F722 - lv0, "cuda:1" # noqa: F722 + lv0, + "cuda:1", # noqa: F722 ) lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c) # noqa: F722 lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice( # noqa: F722 - lv2, "cuda:2" # noqa: F722 + lv2, + "cuda:2", # noqa: F722 ) gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d) # noqa: F722 R.output(gv) From bcc040bcba5f5408773a72bfc1a26f2092185840 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 16 Sep 2024 14:34:22 -0500 Subject: [PATCH 2/3] lint fix --- tests/python/relax/test_vm_multi_device.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index 1481d3f760d1..c87d0b417c23 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -64,7 +64,9 @@ def foo( with R.dataflow(): lv0 = R.matmul(x, y) lv0 = R.hint_on_device(lv0, tvm.cpu(0)) - lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice(lv0, "llvm:1") # noqa: F722 + lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( + lv0, "llvm:1" + ) # noqa: F722 gv = R.matmul(lv1, z) R.output(gv) return gv From be57e09bf92546f9d30ba2ec88f33e228db1ca3e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 16 Sep 2024 15:36:14 -0500 Subject: [PATCH 3/3] lint fix --- tests/python/relax/test_vm_multi_device.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index c87d0b417c23..73c78d70f042 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -64,9 +64,9 @@ def foo( with R.dataflow(): lv0 = R.matmul(x, y) lv0 = R.hint_on_device(lv0, tvm.cpu(0)) - lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( + lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( # noqa: F722 lv0, "llvm:1" - ) # noqa: F722 + ) gv = R.matmul(lv1, z) R.output(gv) return gv