Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/runtime/c_backend_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
*/
TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes);

/*! \brief Generate a pointer usable as kTVMStr in a TVMRetValue
*
* While TVMArgValue uses `const char*` to represent string arguments,
* TVMRetValue represents string return values as a heap-allocated C++
* `std::string` objects.
*
* \param c_str A null-terminated C-style string
*
* \return A pointer to the heap-allocated C++ `std::string`.
*/
void* TVMBackendStringRetValue(const char* c_str);

#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ constexpr const char* tvm_module_main = "__tvm_main__";
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_get_tir_function_metadata = "__tvm_get_tir_function_metadata__";
/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */
constexpr const char* tvm_entrypoint_suffix = "run";
} // namespace symbol
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,8 @@ class TVMRetValue : public TVMPODValue_ {
*/
static TVMRetValue MoveFromCHost(TVMValue value, int type_code) {
// Can move POD and everything under the object system.
ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle);
ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle ||
type_code == kTVMStr);
TVMRetValue ret;
ret.value_ = value;
ret.type_code_ = type_code;
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,13 @@ TVM_DLL Pass LowerCustomDatatypes();
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Generate module metadata describing function signatures
*
* \return The pass.
*/
TVM_DLL Pass GenerateFunctionSignatureMetadata();

/*!
* \brief Annotate locations that should be run on the device
*
Expand Down
120 changes: 111 additions & 9 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@

# pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements
"""Runtime Module namespace."""
import os

import ctypes
import difflib
import json
import os
import struct
from typing import Sequence

from typing import Sequence, Iterable, Iterator, Optional

import numpy as np

from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
Expand Down Expand Up @@ -100,7 +105,20 @@ class ModulePropertyMask(object):
class Module(object):
"""Runtime Module."""

__slots__ = ["handle", "_entry", "entry_name"]
__slots__ = ["handle", "_entry", "entry_name", "_cached_metadata"]

_GET_TIR_FUNCTION_METADATA = "__tvm_get_tir_function_metadata__"
"""The function to retrieve TIR metadata, if any

This is for internal use only. If present in the module, this
should be a function that returns a JSON string describing the
functions provided by the module, along with the signatures of
those functions.

This parameter corresponds to the C++ value
`tvm::runtime::symbol::tvm_get_tir_function_metadata`, located in
`#include <tvm/runtime/module.h>`.
"""

def __init__(self, handle):
self.handle = handle
Expand Down Expand Up @@ -163,18 +181,54 @@ def get_function(self, name, query_imports=False):

Returns
-------
f : tvm.runtime.PackedFunc
func : tvm.runtime.PackedFunc
The result function.
"""
func = self._get_function(name, query_imports=query_imports)
if func is None:
nearby_names = difflib.get_close_matches(name, self.keys())
if nearby_names:
message = (
f"Module has no function '{name}'. "
f"The module does contain functions with similar names: "
f"{nearby_names}."
)
else:
message = (
f"Module has no function '{name}'. "
f"The module does not contain any function with a similar name."
)
raise KeyError(message)

return func

def _get_function(self, name, query_imports=False) -> Optional[PackedFunc]:
"""Get function from the module.

Parameters
----------
name : str
The name of the function

query_imports : bool
Whether also query modules imported by this module.

Returns
-------
func : Optional[tvm.runtime.PackedFunc]
The result function, or None if it cannot be found
"""
ret_handle = PackedFuncHandle()
check_call(
_LIB.TVMModGetFunction(
self.handle, c_str(name), ctypes.c_int(query_imports), ctypes.byref(ret_handle)
)
)
if not ret_handle.value:
raise AttributeError(f"Module has no function '{name}'")
return PackedFunc(ret_handle, False)
# pylint: disable=using-constant-test
if ret_handle.value:
return PackedFunc(ret_handle, False)
else:
return None

def import_module(self, module):
"""Add module to the import list of current one.
Expand All @@ -186,11 +240,59 @@ def import_module(self, module):
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))

def __getitem__(self, name):
def __getitem__(self, name: str) -> PackedFunc:
"""Return the PackedFunc associated with the gven name

Parameters
----------
name: str
The name of the function to be returned

Returns
-------
PackedFunc

"""

if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
raise TypeError(f"Module.__getitem__ expects a string, but received {type(name)}")
return self.get_function(name)

def __contains__(self, key: str) -> bool:
return key in self.keys()

@property
def _metadata(self):
if hasattr(self, "_cached_metadata"):
# pylint: disable=access-member-before-definition
return self._cached_metadata

metadata_func = self._get_function(Module._GET_TIR_FUNCTION_METADATA)
if metadata_func is None:
raise RuntimeError("Cannot find function metadata in runtime.Module")

self._cached_metadata = json.loads(metadata_func())
return self._cached_metadata

def keys(self) -> Sequence[str]:
"""Return a list of functions in the module

Returns
-------
Sequence[str]
The functions in the module
"""
for function in self._metadata["functions"]:
yield function

def values(self) -> Iterator[PackedFunc]:
for key in self.keys():
yield self[key]

def items(self) -> Iterator[PackedFunc]:
for key in self.keys():
yield key, self[key]

def __eq__(self, other):
return self.handle.value == other.handle.value

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,9 @@ def ret(val):
"""

val = convert(val)
if isinstance(val, tvm.runtime.String):
val = StringImm(val)

return call_intrin(val.dtype, "tir.ret", val)


Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,22 @@ def MakeUnpackedAPI():
return _ffi_api.MakeUnpackedAPI() # type: ignore


def GenerateFunctionSignatureMetadata():
"""Generate metadata describing the function signatures

Generate a metadata function that returns a JSON-formatted string,
describing the functions available within the IRModule and their
signatures.

Returns
-------
fpass : tvm.transform.Pass
The result pass

"""
return _ffi_api.GenerateFunctionSignatureMetadata() # type: ignore


def AnnotateDeviceRegions():
"""Annotate locations that should be run on the device

Expand Down
18 changes: 15 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,22 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

// FPComputeLegalize uses the target attrs added by BindTarget, so it must come first
// AnnotateEntryFunc inspects user-defined functions to provide a
// default function to call. Therefore, it should appear before any
// passes that would generate additional PrimFuncs.
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());

// GenerateFunctionSignatureMetadata produces a function with
// `tir.is_host_func`, relying on `BindTarget` to generate the
// correct target annotation. Therefore, it must come before
// BindTarget.
mixed_pass_list.push_back(tir::transform::GenerateFunctionSignatureMetadata());

// Many later passes, such as FP8ComputeLegalize, use the target
// attrs added by BindTarget. Therefore, BindTarget should occur as
// early as possible.
mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());

// VerifyVTCMLimit must occur before LowerVtcmAlloc
Expand All @@ -580,8 +594,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::VerifyMemory());

mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());

bool detect_global_barrier =
pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
if (detect_global_barrier) {
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,12 @@ int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
return 0;
}

void* TVMBackendStringRetValue(const char* c_str) {
// The `std::string` allocated here is immediately stored in a
// `TVMRetValue`, which takes ownership of the object.
return new std::string(c_str);
}

int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); }

int TVMByteArrayFree(TVMByteArray* arr) {
Expand Down
23 changes: 22 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,28 @@ void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
os << "\"" << op->value << "\"";
const auto& str = op->value;
os << '"';
for (size_t i = 0; i < str.size(); i++) {
char c = str.c_str()[i];
switch (c) {
case '\n':
os << "\\n";
break;

case '\\':
case '"':
case '?':
os << '\\' << c;
break;

default:
os << c;
break;
}
}

os << '"';
}

template <typename T>
Expand Down
Loading