12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/annotation.h
@@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief The metadata attached to the call node. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
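
Note: a minimal Python-side sketch of what constructing these attrs could look like through TVM's reflection helper `tvm.ir.make_node`, assuming the node type is registered for reflection-based construction (the metadata key below is purely illustrative, not part of this change):

    import tvm

    # Construct the attrs node via reflection; the payload is an ordinary
    # Map<String, ObjectRef>, and the "source" key here is made up.
    attrs = tvm.ir.make_node(
        "relay.attrs.TIRCallAttrs",
        metadata={"source": "example"},
    )
    print(attrs.metadata)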
10 changes: 10 additions & 0 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -318,6 +318,7 @@ def auto_schedule_topi(func_name, outs):
A tuned schedule or none (if not tuned) in the final build mode;
None in the tracing mode so that the fallback topi schedule will be used.
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
prepare_input_map,
@@ -376,6 +377,15 @@ def auto_schedule_topi(func_name, outs):
return schedule


@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights")
def te_compiler_update_weights(function_weights):
"""A callback for updating the weights of extracted tasks."""
env = TracingEnvironment.current
if env is not None:
for key in env.wkl_key_to_weight:
env.wkl_key_to_weight[key] = function_weights[key[0]]
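
Note: because the callback is registered as a global packed function, the C++ TE compiler can look it up by name, and the same lookup works from Python. A minimal sketch (the weights map is illustrative; outside a TracingEnvironment the call is a no-op):

    import tvm
    from tvm import auto_scheduler  # importing triggers the registration

    update = tvm.get_global_func(
        "auto_scheduler.relay_integration.te_compiler_update_weights"
    )
    # Map from lowered-function name to use count; the name is made up.
    update({"fused_nn_conv2d": 2})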


def tensor_no_check_call(self, *indices):
"""An indexing function without any check.
This is the same as `tvm.te.Tensor::__call__` except that the safety
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/task_scheduler.py
@@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id):

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
-total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
+total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3)
else:
total_latency_str = "-"
print(
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -429,7 +429,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
-res += "func_name={}\n".format(v.cached_func.func_name)
+res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
@@ -444,7 +444,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
-res += "func_name={}\n".format(v.cached_func.func_name)
+res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
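
Note: the renamed field can be exercised directly through the Python compile engine, mirroring how the engine is driven in TVM's own tests. A minimal sketch (shapes and ops are arbitrary):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10,), dtype="float32")
    f = relay.Function([x], relay.add(x, x))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))

    engine = relay.backend.compile_engine.get()
    cached = engine.lower(mod["main"], "llvm")
    # The lowered function's name now hangs off the GlobalVar of the
    # generated PrimFunc instead of a separate func_name field.
    print(cached.prim_fn_var.name_hint)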
24 changes: 23 additions & 1 deletion python/tvm/relay/expr.py
@@ -23,7 +23,7 @@
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar
+from tvm.ir import RelayExpr, GlobalVar, Node

from .base import RelayNode
from . import _ffi_api
@@ -538,3 +538,25 @@ def bind(expr, binds):
The expression or function after binding.
"""
return _ffi_api.Bind(expr, binds)


@tvm._ffi.register_object("relay.StorageInfo")
class StorageInfo(Node):
"""StorageInfo

The static storage information produced by memory planning.
Contains the storage ids where expressions are stored, the
type of the "virtual devices" the expressions are stored on,
and the sizes of each storage element."""

@property
def storage_ids(self):
return _ffi_api.StorageInfoStorageIds(self)

@property
def device_types(self):
return _ffi_api.StorageInfoDeviceTypes(self)

@property
def storage_sizes(self):
return _ffi_api.StorageInfoStorageSizes(self)
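
Note: a sketch of reading the three properties together, assuming `sinfo` is a relay.StorageInfo instance obtained from the memory planner (the entry point that surfaces it is not shown in this diff):

    # sinfo: relay.StorageInfo from memory planning (assumed available).
    for sid, dev_type, size in zip(
        sinfo.storage_ids, sinfo.device_types, sinfo.storage_sizes
    ):
        print("storage id %d on device_type %d: %d bytes" % (sid, dev_type, size))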
10 changes: 7 additions & 3 deletions src/driver/driver_api.cc
@@ -437,14 +437,18 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
}

if (target->kind->device_type == kDLCPU && target_host == target) {
-ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
-                                   << "and host_target are both llvm target."
-                                   << "\n";
+// TODO(@jroesch): This check is no longer true; we need to figure out if we care about this.
+// We need to relax this check for just TIR functions.
+// ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
+//                                    << "and host_target are both llvm target."
+//                                    << "\n";
}

return {mhost, mdevice};
}

// Can we make this take one annotated IRModule?
//
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
auto pass_ctx = transform::PassContext::Current();
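
Note: on the Python side this entry point corresponds to handing `tvm.build` a dict mapping targets to IRModules. A minimal single-target sketch (the heterogeneous case simply passes multiple entries in the same dict):

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # tvm.build also accepts {target: IRModule}, which routes through the
    # heterogeneous build shown above.
    mod = tvm.lower(s, [A, B], name="add_one")
    rt_mod = tvm.build({"llvm": mod})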
18 changes: 9 additions & 9 deletions src/relay/backend/aot_executor_codegen.cc
@@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor {
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
}
-function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
+function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
}

void VisitExpr_(const CallNode* op) override {
@@ -465,20 +465,18 @@
<< "(i.e functions composed of fusable operator invocations)";
}

-auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
-CCacheKey key = (*pf0)(func, target);
-CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
+CCacheKey key = CCacheKey(func, target);
+CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

// Generate the TIR function call
-CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
+CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
return;
}

@@ -503,8 +501,10 @@
}
target = targets_[call_dev_type];
}
-CCacheKey key = (*pf0)(func, target);
-CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);
+
+CCacheKey key = CCacheKey(func, target);
+CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);

if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
@@ -513,7 +513,7 @@
UpdateFunctionMetadata(lowered_func, func, target);

// Generate the TIR function call
-CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
+CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
}

void VisitExpr_(const VarNode* op) override {