Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,12 @@ template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) final {
detail::AttrNormalVisitor vis(v);
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}

void VisitNonDefaultAttrs(AttrVisitor* v) final {
detail::AttrNonDefaultVisitor vis(v);
::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}

Expand All @@ -761,7 +761,7 @@ class AttrsNode : public BaseAttrsNode {
}
return false;
};
auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
self()->__VisitAttrs__(vis);
hit_count = vis.hit_count_;
} else {
Expand All @@ -779,14 +779,14 @@ class AttrsNode : public BaseAttrsNode {
}
return false;
};
auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
self()->__VisitAttrs__(vis);
hit_count = vis.hit_count_;
}
// error handling, slow path
if (hit_count * 2 != args.size() && !allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
detail::AttrExistVisitor visitor;
::tvm::detail::AttrExistVisitor visitor;
visitor.key_ = args[i].operator std::string();
self()->__VisitAttrs__(visitor);
if (!visitor.exist_) {
Expand All @@ -803,7 +803,7 @@ class AttrsNode : public BaseAttrsNode {
}

Array<AttrFieldInfo> ListFieldInfo() const final {
detail::AttrDocVisitor visitor;
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
return visitor.fields_;
}
Expand All @@ -813,13 +813,13 @@ class AttrsNode : public BaseAttrsNode {
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
detail::AttrsEqualVisitor visitor(pself, other, equal);
::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

size_t ContentHash(AttrsHash hasher) const final {
detail::AttrsHashVisitor visitor(hasher);
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key());
self()->__VisitAttrs__(visitor);
return visitor.result_;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
runtime::detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
Expand Down
21 changes: 18 additions & 3 deletions python/tvm/relay/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,26 @@ def evaluate(self, expr, params=None):
"""
if params:
scope_builder = ScopeBuilder()
for key, value in params:
for key in params:
value = params[key]
scope_builder.let(key, value)
scope_builder.ret(expr)
expr = scope_builder.get()

if isinstance(expr, Function):
assert not ir_pass.free_vars(expr)

return self._make_executor(expr)
executor = self._make_executor(expr)

# If we are evaluating a function or top-level defintion
# the user must call the function themselves.
#
# If we are evaluating an open term with parameters we will
# just return them the result.
if isinstance(expr, (Function, GlobalVar)):
return executor
else:
return executor()


class Interpreter(Executor):
Expand All @@ -168,10 +179,14 @@ def _interp_wrapper(*args):
self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.mod, opt_expr)
else:
elif isinstance(expr, Function):
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return _interpreter.evaluate(self.mod, opt_expr)
else:
assert not args
opt_expr = self.optimize(expr)
return _interpreter.evaluate(self.mod, opt_expr)

return _interp_wrapper

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
from .op import get, register, Op
from .op import get, register, register_schedule, register_compute, Op

# Operators
from .reduce import *
Expand Down
Loading