Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/tvm/relax/testing/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,8 @@ def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any
res2 = vm[saved_name]()
tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7)
return res1


@tvm.register_func("test.vm.check_if_defined")
def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm:
return tvm.runtime.convert(obj is not None)
20 changes: 13 additions & 7 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,20 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Instruction::Arg value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
Expr expr = GetBoundValue(binding);

Instruction::Arg value = VisitExpr(expr);
if (expr.as<VarNode>()) {
// For a normalized relax module, there should be one
// register for each relax::Binding. This makes the Relax
// semantics of R.vm.kill_* operate the same as the Python
// "del" operator. These bindings may be removable by using
// relax.transform.CanonicalizeBindings earlier in lowering.
RegName new_reg = NewRegister();
builder_->EmitCall("vm.builtin.copy", {value}, new_reg);
value = Instruction::Arg::Register(new_reg);
}

this->var_arg_map_.insert({binding->var, value});
}
}
Expand Down
20 changes: 13 additions & 7 deletions src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,20 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Optional<PrimExpr> value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
Expr expr = GetBoundValue(binding);
Optional<PrimExpr> value = VisitExpr(expr);

if (expr.as<Var>() && value.defined()) {
// For a normalized relax module, there should be one
// register for each relax::Binding. This makes the Relax
// semantics of R.vm.kill_* operate the same as the Python
// "del" operator. These bindings may be removable by using
// relax.transform.CanonicalizeBindings earlier in lowering.
auto new_reg = NewRegister();
EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg);
value = RegListGet(new_reg);
}

this->var_map_.insert({binding->var, value});
}
}
Expand Down
37 changes: 2 additions & 35 deletions src/relax/transform/kill_after_last_use.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class CollectLastUsage : public ExprVisitor {
// completes, last_usage_of_ contains the last usage point. If
// this occurs in an output, then current_binding_ will be
// nullptr.
last_usage_of_[UnwrapTrivialBindings(op)] = current_binding_;
last_usage_of_[op] = current_binding_;
}

void VisitBinding_(const VarBindingNode* binding, const CallNode* val) override {
Expand All @@ -169,51 +169,18 @@ class CollectLastUsage : public ExprVisitor {
<< "but instead found " << val->args.size() << " arguments: " << val->args;
auto killed_object = val->args[0].as<VarNode>();
ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef<Call>(val);
killed_objects_.insert(UnwrapTrivialBindings(killed_object));
killed_objects_.insert(killed_object);
} else {
// Only recursively visit if it isn't one of the special cases.
ExprVisitor::VisitBinding_(binding, val);
}
}

void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override {
// Because the VM re-uses the same register for variable
// re-binding, we need to de-duplicate across trivial bindings in
// order to avoid calling `vm.kill_object` multiple times on the
// same register. In the future, this can be simplified by
// replacing the de-duplication in CodeGenVM with a call to
// CanonicalizeBindings.
trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(val)});

// Do not call ExprVisitor::VisitBinding_ here, as the trivial
// rebinding should not be treated as a point of use.
}

void VisitBinding_(const MatchCastNode* binding) override {
if (auto rebound = binding->value.as<VarNode>()) {
// Because CodeGenVM treats MatchCast nodes identically to
// VarBinding nodes, we must also de-duplicate at this level.
trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(rebound)});
} else {
ExprVisitor::VisitBinding_(binding);
}
}

void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) override {
constant_tensors_.insert(binding->var.get());
}

private:
const VarNode* UnwrapTrivialBindings(const VarNode* var) const {
while (true) {
if (auto it = trivial_bindings_.find(var); it != trivial_bindings_.end()) {
var = it->second;
} else {
return var;
}
}
}

// The current binding being visited, or nullptr if no binding is
// being visited.
const VarNode* current_binding_{nullptr};
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relax/test_kill_after_last_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ class Expected:
def main(w: R.Tensor([16, 32], "float32")):
x = R.add(w, R.const(1, "float32"))
y = R.match_cast(x, R.Tensor([16, 32]))
z = R.add(y, R.const(1, "float32"))
_ = R.memory.kill_tensor(x)
z = R.add(y, R.const(1, "float32"))
_ = R.memory.kill_tensor(y)
return z

After = KillAfterLastUse()(Before)
Expand Down
59 changes: 58 additions & 1 deletion tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import tvm.testing
import pytest

from tvm import relax
from tvm.script import relax as R, tir as T
from tvm.script import ir as I
from tvm.relax.transform import LazyTransformParams
Expand Down Expand Up @@ -319,5 +322,59 @@ def main_transform_params() -> R.Tuple:
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_output():
target = "llvm"
dev = tvm.device(target)

@I.ir_module
class TransformModule:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((3, "ic", 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
)
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
):
R.func_attr({"relax.force_pure": True})
param0 = params[0]
param1 = params[1]
transformed0 = R.permute_dims(param0, [1, 0, 2, 3])
transformed = (transformed0, param1)
return transformed

mod = TransformModule
mod = relax.transform.LazyTransformParams()(mod)
mod = relax.transform.LegalizeOps()(mod)
built = relax.build(mod, target=target)

params = [
np.random.random(size=(3, 64, 3, 3)).astype("float32"),
np.random.random(size=(16, 16, 3, 3)).astype("float32"),
]
transformed = {}
expected = [params[0].transpose(1, 0, 2, 3), params[1]]

@tvm.register_func("get_item", override=True)
def get_item(i):
return tvm.nd.array(params[i], dev)

@tvm.register_func("set_item", override=True)
def set_item(i, value):
assert i not in transformed, f"Set item called multiple times for index {i}"
transformed[i] = value.numpy()

vm = relax.VirtualMachine(built, dev)
vm["transform_params"]()

assert sorted(transformed) == list(range(len(transformed)))
transformed = [value for i, value in sorted(transformed.items())]
assert len(transformed) == len(expected)

for expected_i, transformed_i in zip(expected, transformed):
tvm.testing.assert_allclose(expected_i, transformed_i)


if __name__ == "__main__":
tvm.testing.main()
66 changes: 66 additions & 0 deletions tests/python/relax/test_vm_codegen_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,5 +421,71 @@ def main() -> R.Tensor((4,), dtype="float32"):
tvm.testing.assert_allclose(res.numpy(), np.ones((4,), "float32"))


@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_preserve_trivial_bindings(exec_mode):
@I.ir_module
class mod:
@R.function
def main():
callback = R.ExternFunc("test.vm.check_if_defined")

storage = R.vm.alloc_storage(R.shape([16]), R.prim_value(0), R.dtype("uint8"))
alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([4]), R.dtype("float32"))
storage_alias = storage
alloc_alias = alloc

storage_before = callback(storage)
alloc_before = callback(alloc)
storage_alias_before = callback(storage_alias)
alloc_alias_before = callback(alloc_alias)

_ = R.vm.kill_object(storage)
_ = R.vm.kill_object(alloc)

storage_after = callback(storage)
alloc_after = callback(alloc)
storage_alias_after = callback(storage_alias)
alloc_alias_after = callback(alloc_alias)

return (
storage_before,
alloc_before,
storage_alias_before,
alloc_alias_before,
storage_after,
alloc_after,
storage_alias_after,
alloc_alias_after,
)

target = tvm.target.Target("llvm", host="llvm")
ex = codegen(mod, target, exec_mode)
dev = tvm.cpu()
vm = relax.VirtualMachine(ex, dev)

result_list = vm["main"]()

# Making a dictionary of expected results is purely to improve
# readability of test failures. This is equivalent to asserting
# on each element of the result array, but lets pytest give us a
# diff of the dictionaries in case of failure.
expected_results = {
"storage_before": True,
"alloc_before": True,
"storage_alias_before": True,
"alloc_alias_before": True,
"storage_after": False,
"alloc_after": False,
"storage_alias_after": True,
"alloc_alias_after": True,
}

observed_results = {
name: bool(tir_bool) for name, tir_bool in zip(expected_results.keys(), result_list)
}

assert observed_results == expected_results


if __name__ == "__main__":
tvm.testing.main()