From e8be5cf6017884825982e1adcdab64692e5c891e Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 29 Jul 2024 10:11:30 -0500
Subject: [PATCH 1/2] [Relax] Refactor RealizeVDevice to remove in-place
 mutation

Prior to this commit, the `relax.transform.RealizeVDevice` pass
performed an in-place update on expressions appearing in its input
`IRModule`, overwriting their struct info.  In-place mutation of TVM's
IR types is only legal when the scope has sole ownership of the IR
object, such as through the `CopyOnWrite` functionality, and is not
allowed when the object is shared.  As a result, applying
`RealizeVDevice` would cause unexpected updates to unrelated
expressions.  Most noticeably, the `IRModule` used as input to
`RealizeVDevice` would have its variables erroneously updated.

This commit refactors the `RealizeVDevice` transform to remove all
in-place mutation.  The same propagation rules are followed, with
known `VDevice` annotations propagated forward from the output of
`R.hint_on_device`, and propagated backwards from the input of
`R.hint_on_device` if no such annotation already exists.

Closes https://github.com/apache/tvm/issues/17205.
---
 src/relax/transform/realize_vdevice.cc        | 492 +++++++++++-------
 .../relax/test_transform_realize_vdevice.py   |  90 ++++
 2 files changed, 399 insertions(+), 183 deletions(-)

diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc
index ec02efa996e6..0df86515dbcc 100644
--- a/src/relax/transform/realize_vdevice.cc
+++ b/src/relax/transform/realize_vdevice.cc
@@ -29,259 +29,385 @@
 namespace tvm {
 namespace relax {
 
-void UpdateTensorStructInfo(Expr expr, StructInfo struct_info) {
-  if (auto* tensor_sinfo = expr->struct_info_.as<TensorStructInfoNode>()) {
-    auto* new_tensor_sinfo = struct_info.as<TensorStructInfoNode>();
-    if (new_tensor_sinfo != nullptr && new_tensor_sinfo->vdevice.defined() &&
-        !tensor_sinfo->vdevice.defined()) {
-      expr->struct_info_ = struct_info;
-      expr->checked_type_ = GetStaticType(struct_info);
-    }
+namespace {
+
+class VDeviceLookup {
+ public:
+  explicit VDeviceLookup(IRModule mod) {
+    auto opt_global_info = mod->global_infos.Get("vdevice");
+    if (!opt_global_info) return;
+
+    auto downcast_vdevice = [](GlobalInfo info) -> VDevice {
+      if (auto vdevice = info.as<VDevice>()) {
+        return vdevice.value();
+      } else {
+        LOG(FATAL) << "TypeError: "
+                   << "Each item in an IRModule's \"vdevice\" annotation must be a VDevice, "
+                   << "but instead found item of type " << info->GetTypeKey();
+      }
+    };
+
+    opt_vdevices_ = opt_global_info.value().Map(downcast_vdevice);
   }
-}
 
-void AddVDeviceToStuctInfo(Expr expr, VDevice vdevice) {
-  auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(expr);
-  if (tinfo != nullptr) {
-    if (tinfo->shape.defined()) {
-      UpdateTensorStructInfo(
-          expr, TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span));
-    } else {
-      UpdateTensorStructInfo(expr,
-                             TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span));
+  VDevice operator()(Attrs hint_on_device_attrs) {
+    auto attrs = hint_on_device_attrs.as<HintOnDeviceAttrs>();
+    ICHECK(attrs);
+    int32_t device_type = attrs->dev_type;
+    int32_t device_id = attrs->dev_id;
+
+    CHECK(opt_vdevices_.defined())
+        << "ValueError: The target VDevice in the GlobalInfos was not found.";
+
+    auto vdevices = opt_vdevices_.value();
+    CHECK_GE(device_id, 0) << "ValueError: "
+                           << "The device id in R.hint_on_device must not be negative";
+
+    for (auto vdevice : vdevices) {
+      int dev_type = vdevice->target->GetTargetDeviceType();
+      if (dev_type == device_type && vdevice->vdevice_id == device_id) {
+        return vdevice;
+      }
    }
+
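+    // Reaching this point means that no VDevice registered in the
+    // IRModule's "vdevice" global info matched the requested device
+    // type and id, so fail loudly rather than return an undefined
+    // VDevice.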
LOG(FATAL) << "ValueError: " + << "Expected to find device with type " << device_id << " and id " << device_id + << ", but no such device was found in the IRModule's \"vdevice\" annotation"; } -} -class VDeviceRealizer : public ExprMutator { + private: + Optional> opt_vdevices_ = NullOpt; +}; + +class DeviceHintCollector : ExprVisitor { public: - explicit VDeviceRealizer(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {} + static std::tuple, Map> Collect(IRModule mod) { + DeviceHintCollector visitor{VDeviceLookup(mod)}; - IRModule Run() { - for (const auto& [gv, func] : mod_->functions) { - if (func->IsInstance()) { - auto updated_func = Downcast(this->VisitExpr(func)); - builder_->UpdateFunction(gv, Downcast(updated_func)); + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + visitor(func.value()); } } - return builder_->GetContextIRModule(); + + return {visitor.known_vdevice_, visitor.hint_on_device_inputs_}; } private: - using ExprMutator::VisitExpr_; + explicit DeviceHintCollector(VDeviceLookup vdevice_lookup) : vdevice_lookup_(vdevice_lookup) {} + + void VisitExpr_(const FunctionNode* func) override { + ExprVisitor::VisitExpr_(func); + + std::function check_ret_sinfo = [this, &check_ret_sinfo]( + Expr expr, StructInfo sinfo) { + // If the function is annotated as returning a tensor on a + // specific device, then that annotation may be propagated into + // the returned variable. + if (auto tensor_info = sinfo.as(); + tensor_info && tensor_info->vdevice.defined()) { + if (auto opt_var = expr.as()) { + auto var = opt_var.value(); + if (!known_vdevice_.count(var)) { + known_vdevice_.Set(var, tensor_info->vdevice.value()); + } + } + } - void AddToVDeviceMap(Expr expr, VDevice vdevice) { - ICHECK((vdevice_map_.count(expr) == 0) || (vdevice_map_[expr] == vdevice)) - << "Conflicted vdevice found."; - vdevice_map_.Set(expr, vdevice); + // If the function is annotated as returning a tuple of tensors, + // where some elements of the tuple are tensors that exist on a + // specific device, then those annotations may be propagated + // into the corresponding tensor annotations. + if (auto tuple_info = sinfo.as()) { + // The returned tuple is not necessarily an in-line tuple. In + // order to find the variables that are bound to the + // individual tuple elements, we may need to unwrap the + // variable bindings in order to find the tuple itself. This + // unwrapping is not required for the tensor case, as it would + // already be handled when propagating VDevice across variable + // definitions. + while (auto bound_value = LookupBinding(expr)) { + expr = bound_value.value(); + } + + // Even after unwrapping variable bindings, the resulting + // expression is not required to be a tuple literal. For + // example, the function may return one of its arguments as an + // output, or may return the result of a `relax::Call` that + // produces a tuple of outputs. 
+        if (auto tuple = expr.as<TupleNode>()) {
+          CHECK_EQ(tuple_info->fields.size(), tuple->fields.size())
+              << "ValueError: "
+              << "Function returns a tuple with " << tuple->fields.size() << " elements, "
+              << "but is annotated as returning a tuple with " << tuple_info->fields.size()
+              << " elements";
+          for (size_t i = 0; i < tuple->fields.size(); i++) {
+            check_ret_sinfo(tuple->fields[i], tuple_info->fields[i]);
+          }
+        }
+      }
+    };
+
+    check_ret_sinfo(func->body->body, func->ret_struct_info);
   }
 
-  Expr VisitExpr(const Expr& expr) {
-    auto visited_expr = ExprMutator::VisitExpr(expr);
-    if (vdevice_map_.count(visited_expr)) {
-      AddVDeviceToStuctInfo(visited_expr, vdevice_map_[visited_expr]);
+  void VisitVarDef(const Var& var) override {
+    if (auto tinfo = var->struct_info_.as<TensorStructInfoNode>();
+        tinfo && tinfo->vdevice.defined()) {
+      known_vdevice_.Set(var, tinfo->vdevice.value());
     }
-    return visited_expr;
+    ExprVisitor::VisitVarDef(var);
   }
 
-  Expr VisitExpr_(const FunctionNode* op) final {
-    Function func = GetRef<Function>(op);
-    auto* finfo = GetStructInfoAs<FuncStructInfoNode>(func);
-    if (finfo != nullptr) {
-      StructInfo ret = finfo->ret;
-      auto* tinfo = finfo->ret.as<TensorStructInfoNode>();
-      if (tinfo != nullptr && tinfo->vdevice.defined()) {
-        AddToVDeviceMap(op->body, tinfo->vdevice.value());
-      }
-    }
-    Function visited_func = Downcast<Function>(this->VisitExprPostOrder_(op));
-    return visited_func;
+  void VisitBinding(const Binding& binding) override {
+    ExprVisitor::VisitBinding(binding);
+    binding_lookup_.Set(binding->var, GetBoundValue(binding));
   }
 
-  Expr VisitExpr_(const SeqExprNode* op) final {
-    SeqExpr seq_expr = GetRef<SeqExpr>(op);
-    if (vdevice_map_.count(seq_expr)) {
-      AddToVDeviceMap(seq_expr->body, vdevice_map_[seq_expr]);
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) override {
+    ExprVisitor::VisitBinding_(binding, call);
+    if (call->op == hint_on_device_op_) {
+      auto vdevice = vdevice_lookup_(call->attrs);
+      known_vdevice_.Set(binding->var, vdevice);
+
+      ICHECK_EQ(call->args.size(), 1);
+      if (auto arg_var = call->args[0].as<Var>()) {
+        hint_on_device_inputs_.Set(arg_var.value(), vdevice);
+      }
     }
-    SeqExpr visited_seqexpr = Downcast<SeqExpr>(this->VisitExprPostOrder_(op));
-    return visited_seqexpr;
   }
 
-  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) {
-    builder_->BeginBindingBlock();
-    for (size_t i = block->bindings.size(); i > 0; --i) {
-      this->VisitBinding(block->bindings[i - 1]);
-    }
-    for (size_t i = bindings_.size(); i > 0; --i) {
-      builder_->EmitNormalized(bindings_[i - 1]);
+  Optional<Expr> LookupBinding(const Expr& expr) const {
+    if (auto var = expr.as<Var>()) {
+      if (auto bound = binding_lookup_.Get(var.value())) {
+        return bound.value();
+      }
     }
-    bindings_.clear();
-    return builder_->EndBlock();
+    return NullOpt;
   }
 
-  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
-    builder_->BeginDataflowBlock();
-    for (size_t i = block->bindings.size(); i > 0; --i) {
-      this->VisitBinding(block->bindings[i - 1]);
-    }
-    for (size_t i = bindings_.size(); i > 0; --i) {
-      builder_->EmitNormalized(bindings_[i - 1]);
+  // A lookup to identify the VDevice from the IRModule attributes,
+  // given the device type and device id from the R.hint_on_device
+  // attributes.
+  VDeviceLookup vdevice_lookup_;
+
+  // A lookup of variable bindings, used to unwrap the variable
+  // bindings in functions that return a tuple.
+  Map<Var, Expr> binding_lookup_;
+
+  // A map from Var to the VDevice they are known to occur on.  This
+  // only contains variables whose location is explicitly known
+  // (e.g. output of `R.hint_on_device`, variables with explicit
+  // `VDevice` in their struct info), and does not include variables
+  // whose location is only hinted (e.g. input of `R.hint_on_device`).
+  Map<Var, VDevice> known_vdevice_;
+
+  // A map from Var to the VDevice they are expected to occur on.  If
+  // a variable appears in both `known_vdevice_` and
+  // `hint_on_device_inputs_`, then `known_vdevice_` takes priority.
+  //
+  // For example, `B = R.hint_on_device(A, tvm.cuda(0))` implies that
+  // `B` must be located on "cuda:0".  However, `A` may already have a
+  // `VDevice` annotation, or may be the output of `R.to_vdevice`.
+  // Therefore, we only determine that `A` is located on "cuda:0" if
+  // no other annotation has already provided a known location for
+  // `A`.
+  Map<Var, VDevice> hint_on_device_inputs_;
+
+  // The `R.hint_on_device` operator.
+  const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
+};
+
+// Utility to determine which Var instances must be located on the
+// same VDevice.
+class VDeviceSetCollector : ExprVisitor {
+ public:
+  static Map<Var, Array<Var>> Collect(IRModule mod) {
+    VDeviceSetCollector visitor;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        visitor(func.value());
+      }
     }
-    bindings_.clear();
-    return builder_->EndBlock();
+    return visitor.var_to_co_located_vars_;
   }
 
-  void VisitBinding_(const VarBindingNode* binding) {
-    if (vdevice_map_.count(binding->var)) {
-      AddToVDeviceMap(binding->value, vdevice_map_[binding->var]);
-      AddVDeviceToStuctInfo(binding->var, vdevice_map_[binding->var]);
-    }
-    auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(binding->var);
-    if (tinfo != nullptr && tinfo->vdevice.defined()) {
-      AddToVDeviceMap(binding->value, tinfo->vdevice.value());
-    }
-    UpdateTensorStructInfo(binding->value, GetStructInfo(binding->var));
-    Expr new_value = this->VisitExpr(binding->value);
-    if (!binding->var->struct_info_.defined()) {
-      UpdateTensorStructInfo(binding->var, GetStructInfo(new_value));
-    }
+ private:
+  void VisitBinding(const Binding& binding) override {
+    auto cached = current_binding_;
+    current_binding_ = binding->var;
+    ExprVisitor::VisitBinding(binding);
+    current_binding_ = cached;
+  }
 
-    if (new_value.same_as(binding->value)) {
-      bindings_.push_back(GetRef<VarBinding>(binding));
-    } else {
-      bindings_.push_back(VarBinding(binding->var, new_value));
+  void VisitExpr_(const CallNode* call) override {
+    if (call->op != to_vdevice_op_ && call->op != hint_on_device_op_) {
+      ExprVisitor::VisitExpr_(call);
     }
   }
 
-  Expr VisitExpr_(const CallNode* call) final {
-    // Record the vdevice information of each arguments of call
-    if (auto* sinfo = call->struct_info_.as<TensorStructInfoNode>()) {
-      if (sinfo->vdevice.defined() && call->op != to_vdevice_op_) {
-        Array<Expr> call_args;
-        for (Expr arg : call->args) {
-          AddToVDeviceMap(arg, sinfo->vdevice.value());
-        }
-      }
+  void VisitExpr_(const VarNode* op) override {
+    if (current_binding_) {
+      auto var = GetRef<Var>(op);
+      var_to_co_located_vars_[current_binding_.value()].push_back(var);
+      var_to_co_located_vars_[var].push_back(current_binding_.value());
     }
-    return Downcast<Call>(ExprMutator::VisitExpr_(call));
   }
 
-  /*! \brief The context IRModule. */
-  IRModule mod_;
-  /*! \brief The bindings in reverse ordering. */
-  Array<Binding> bindings_;
-  /*! \brief The virtual device map. */
-  Map<Expr, VDevice> vdevice_map_;
+  Optional<Var> current_binding_ = NullOpt;
+
+  // Lookup from relax variable to the set of relax variables which
+  // must be located on the same device.  For example, a trivial
+  // binding `B = A` implies that both `B` and `A` are on the same
+  // device.  Similarly, `C = R.add(A,B)` implies that `A`, `B`, and
+  // `C` are all on the same device.
+  //
+  // In general, variables that are used as part of the same
+  // `relax::Call` operation must be located on the same device, with
+  // the exception of `R.hint_on_device` and `R.to_vdevice`, which may
+  // introduce a transfer across devices.
+  std::unordered_map<Var, Array<Var>, ObjectPtrHash, ObjectPtrEqual> var_to_co_located_vars_;
 
   const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
 };
 
-class HintOnDeviceRemover : public ExprMutator {
- public:
-  explicit HintOnDeviceRemover(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {}
+Map<Var, VDevice> InferVDevice(IRModule mod) {
+  auto [explicit_annotations, hint_on_device_args] = DeviceHintCollector::Collect(mod);
+
+  auto co_located_var_lookup = VDeviceSetCollector::Collect(mod);
+
+  Map<Var, VDevice> known_vdevice;
+  std::vector<Var> to_visit;
+
+  // A helper function to propagate all `known_vdevice` entries based
+  // on the connections in `co_located_var_lookup`.
+  auto propagate = [&]() {
+    while (to_visit.size()) {
+      Var visiting = to_visit.back();
+      to_visit.pop_back();
 
-  IRModule Run() {
-    for (const auto& [gv, func] : mod_->functions) {
-      if (func->IsInstance<FunctionNode>()) {
-        auto updated_func = Downcast<Function>(this->VisitExpr(func));
-        builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
+      if (auto upstream_vars = co_located_var_lookup.Get(visiting)) {
+        auto vdevice = known_vdevice.at(visiting);
+        for (Var upstream_var : upstream_vars.value()) {
+          if (!known_vdevice.count(upstream_var)) {
+            known_vdevice.Set(upstream_var, vdevice);
+            to_visit.push_back(upstream_var);
+          }
+        }
       }
     }
-    return builder_->GetContextIRModule();
+  };
+
+  // First round, mark variables whose vdevice is explicitly known
+  // (e.g. the output of R.hint_on_device), and propagate.
+  for (const auto& [var, vdevice] : explicit_annotations) {
+    to_visit.push_back(var);
+    known_vdevice.Set(var, vdevice);
+  }
+  propagate();
+
+  // Second round, mark variables whose vdevice is hinted at (e.g. the
+  // input of R.hint_on_device), and propagate.
+  for (const auto& [var, vdevice] : hint_on_device_args) {
+    if (!known_vdevice.count(var)) {
+      to_visit.push_back(var);
+      known_vdevice.Set(var, vdevice);
+    }
   }
+  propagate();
 
- private:
-  using ExprMutator::VisitExpr_;
+  return known_vdevice;
+}
 
-  void AddToVDeviceMap(Expr expr, VDevice vdevice) {
-    ICHECK((vdevice_map_.count(expr) == 0) || (vdevice_map_[expr] == vdevice))
-        << "Conflicted vdevice found.";
-    vdevice_map_.Set(expr, vdevice);
-  }
+// Update the module to include the inferred VDevice annotations.
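+//
+// Rather than mutating struct info in place, each variable whose
+// VDevice was inferred is re-bound with updated struct info, and
+// `R.hint_on_device` calls are rewritten into either a no-op or an
+// explicit `R.to_vdevice` transfer.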
+class VDeviceStructInfoUpdater : ExprMutator {
+ public:
+  static IRModule Apply(IRModule mod, Map<Var, VDevice> vdevice_map) {
+    VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map);
 
-  VDevice LookupVDevice(int32_t device_type, int32_t device_id) {
-    Array<GlobalInfo> vdevices = mod_->global_infos["vdevice"];
-    if (vdevices.empty() || device_id < 0 || static_cast<size_t>(device_id) >= vdevices.size()) {
-      LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found.";
-    }
-    for (auto vdev : vdevices) {
-      auto vdevice = Downcast<VDevice>(vdev);
-      int dev_type = vdevice->target->GetTargetDeviceType();
-      if (dev_type == device_type && vdevice->vdevice_id == device_id) {
-        return vdevice;
+    IRModule updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto updated = Downcast<Function>(mutator(func.value()));
+        if (!updated.same_as(base_func)) {
+          updates->Add(gvar, updated);
+        }
       }
     }
-    LOG(WARNING) << "The specified device was not found in the global_infos";
-    return VDevice();
-  }
 
-  Expr VisitExpr(const Expr& expr) {
-    auto visited_expr = ExprMutator::VisitExpr(expr);
-    if (vdevice_map_.count(visited_expr)) {
-      AddVDeviceToStuctInfo(visited_expr, vdevice_map_[visited_expr]);
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
     }
-    return visited_expr;
-  }
 
-  void VisitBinding_(const VarBindingNode* binding) {
-    Expr new_value = this->VisitExpr(binding->value);
-    UpdateTensorStructInfo(binding->var, GetStructInfo(new_value));
-    if (new_value.same_as(binding->value)) {
-      builder_->EmitNormalized(GetRef<VarBinding>(binding));
-    } else {
-      builder_->EmitNormalized(VarBinding(binding->var, new_value));
-    }
+    return mod;
   }
 
-  Expr VisitExpr_(const CallNode* call) final {
-    // Replace hint_on_device with to_vdevice
-    if (call->op == hint_on_device_op_) {
-      // Find out the vdevice from global_infos
-      Expr data = call->args[0];
-      auto attrs = call->attrs.as<HintOnDeviceAttrs>();
-      int32_t device_type = attrs->dev_type;
-      int32_t device_id = attrs->dev_id;
-      VDevice dst_vdev = LookupVDevice(device_type, device_id);
-      // Insert to_vdevice if input are on different device
-      auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(data);
-      if (tinfo != nullptr) {
-        if (!tinfo->vdevice.defined()) {
-          // Remove hint_on_device
-          AddVDeviceToStuctInfo(data, dst_vdev);
-          AddToVDeviceMap(data, dst_vdev);
-          return data;
-        } else if (tinfo->vdevice.value() != dst_vdev) {
-          // Call to_vdevice
-          ObjectPtr<ToVDeviceAttrs> attrs = make_object<ToVDeviceAttrs>();
-          attrs->dst_vdevice = dst_vdev;
-          auto new_call = Call(to_vdevice_op_, {data}, Attrs(attrs), {});
-          AddToVDeviceMap(new_call, dst_vdev);
-          return new_call;
+ private:
+  VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, Map<Var, VDevice> vdevice_map)
+      : vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {}
+
+  Var VisitVarDef(const Var& old_var) override {
+    auto var = ExprMutator::VisitVarDef(old_var);
+    if (auto tinfo = var->struct_info_.as<TensorStructInfoNode>()) {
+      if (auto opt = vdevice_map_.Get(old_var)) {
+        auto vdevice = opt.value();
+        TensorStructInfo new_sinfo = [&]() {
+          if (tinfo->shape.defined()) {
+            return TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span);
+          } else {
+            return TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span);
+          }
+        }();
+
+        if (var->IsInstance<DataflowVarNode>()) {
+          var = DataflowVar(var->vid, new_sinfo, var->span);
+        } else {
+          var = Var(var->vid, new_sinfo, var->span);
         }
       }
     }
-    auto visited_call = ExprMutator::VisitExpr_(call);
-    visited_call->struct_info_ = NullOpt;
-    return builder_->Normalize(visited_call);
+    return var;
   }
 
-  /*! \brief The context IRModule. */
-  IRModule mod_;
-  /*! \brief The virtual device map. */
-  Map<Expr, VDevice> vdevice_map_;
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (call->op != hint_on_device_op_) {
+      return call;
+    }
+
+    ICHECK_EQ(call->args.size(), 1);
+    auto arg = call->args[0];
+    auto input_vdevice = Downcast<TensorStructInfo>(arg->struct_info_)->vdevice;
+    auto output_vdevice = vdevice_lookup_(call->attrs);
+
+    if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) {
+      return arg;
+    } else {
+      ObjectPtr<ToVDeviceAttrs> attrs = make_object<ToVDeviceAttrs>();
+      attrs->dst_vdevice = output_vdevice;
+      return Call(to_vdevice_op_, {arg}, Attrs(attrs), {});
+    }
+  }
 
+  VDeviceLookup vdevice_lookup_;
+  Map<Var, VDevice> vdevice_map_;
   const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
 };
+}  // namespace
 
 namespace transform {
 
 Pass RealizeVDevice() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule mod, PassContext pc) {
-        IRModule new_mod = HintOnDeviceRemover(mod).Run();
-        return VDeviceRealizer(new_mod).Run();
+        auto known_vdevices = InferVDevice(mod);
+        return VDeviceStructInfoUpdater::Apply(mod, known_vdevices);
       };
   return CreateModulePass(/*pass_function=*/pass_func,
                           /*opt_level=*/0,
diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py
index f8d99eb3b59f..59c910d78865 100644
--- a/tests/python/relax/test_transform_realize_vdevice.py
+++ b/tests/python/relax/test_transform_realize_vdevice.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test eliminate common subexpr pass"""
+
 import tvm
 import tvm.testing
 from tvm.ir import VDevice
@@ -202,6 +203,66 @@ def foo(
     verify(Input, Expect)
 
 
+def test_tuple_func_ret():
+    @I.ir_module
+    class Input:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("cuda"),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            x: R.Tensor((2, 3), "float32"),
+            y: R.Tensor((2, 3), "float32"),
+            z: R.Tensor((2, 3), "float32"),
+        ) -> R.Tuple(
+            [
+                R.Tensor((2, 3), "float32", "cuda"),
+                R.Tensor((2, 3), "float32", "cuda"),
+            ]
+        ):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                gv = R.multiply(lv0, z)
+                R.output(gv)
+            return (gv, gv)
+
+    @I.ir_module
+    class Expect:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("cuda"),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            x: R.Tensor((2, 3), "float32", "cuda"),
+            y: R.Tensor((2, 3), "float32", "cuda"),
+            z: R.Tensor((2, 3), "float32", "cuda"),
+        ) -> R.Tuple(
+            [
+                R.Tensor((2, 3), "float32", "cuda"),
+                R.Tensor((2, 3), "float32", "cuda"),
+            ]
+        ):
+            with R.dataflow():
+                lv0: R.Tensor((2, 3), "float32", "cuda") = R.add(x, y)
+                gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv0, z)
+                R.output(gv)
+            return (gv, gv)
+
+    verify(Input, Expect)
+
+
 def test_multi_device():
     @I.ir_module
     class Input:
@@ -326,5 +387,34 @@ def foo(
     verify(Input, Expect)
 
 
+def test_input_module_is_unmodified():
+    def make_module():
+        @I.ir_module
+        class Module:
+            I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
+
+            @R.function
+            def foo(
+                x: R.Tensor((2, 3), "float32"),
+                y: R.Tensor((2, 3), "float32"),
+                z: R.Tensor((2, 3), "float32"),
+            ) -> R.Tensor((2, 3), "float32"):
+                x1 = x
+                y1 = y
+                x2 = x1
+                y2 = y1
+                s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
+                m = R.multiply(s, z)
+                return m
+
+        return Module
+
+    original = make_module()
+    expected = make_module()
+
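+    # Applying the pass must not modify its input module in place;
+    # `original` should remain structurally equal to `expected`.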
+    RealizeVDevice()(original)
+    tvm.ir.assert_structural_equal(original, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 43093e7fc58f17fc5b7f943098ad5a157f57dfcb Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 29 Jul 2024 12:20:03 -0500
Subject: [PATCH 2/2] lint fixes

---
 .../python/relax/test_transform_realize_vdevice.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py
index 59c910d78865..4c530d5e4931 100644
--- a/tests/python/relax/test_transform_realize_vdevice.py
+++ b/tests/python/relax/test_transform_realize_vdevice.py
@@ -220,12 +220,7 @@ def foo(
         x: R.Tensor((2, 3), "float32"),
         y: R.Tensor((2, 3), "float32"),
         z: R.Tensor((2, 3), "float32"),
-    ) -> R.Tuple(
-        [
-            R.Tensor((2, 3), "float32", "cuda"),
-            R.Tensor((2, 3), "float32", "cuda"),
-        ]
-    ):
+    ) -> R.Tuple([R.Tensor((2, 3), "float32", "cuda"), R.Tensor((2, 3), "float32", "cuda")]):
         with R.dataflow():
             lv0 = R.add(x, y)
             gv = R.multiply(lv0, z)
@@ -248,12 +243,7 @@ def foo(
         x: R.Tensor((2, 3), "float32", "cuda"),
         y: R.Tensor((2, 3), "float32", "cuda"),
         z: R.Tensor((2, 3), "float32", "cuda"),
-    ) -> R.Tuple(
-        [
-            R.Tensor((2, 3), "float32", "cuda"),
-            R.Tensor((2, 3), "float32", "cuda"),
-        ]
-    ):
+    ) -> R.Tuple([R.Tensor((2, 3), "float32", "cuda"), R.Tensor((2, 3), "float32", "cuda")]):
         with R.dataflow():
             lv0: R.Tensor((2, 3), "float32", "cuda") = R.add(x, y)
             gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv0, z)