Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/uma/_template/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ def conv2d_pattern():
pattern = is_op("nn.conv2d")(wildcard(), wildcard())
pattern = pattern.has_attr({"strides": [1, 1], "groups": 1})
return pattern


def dense_pattern():
pattern = is_op("nn.dense")(wildcard(), wildcard())
return pattern
22 changes: 1 addition & 21 deletions python/tvm/relay/backend/contrib/uma/api/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,27 +60,7 @@ def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFu
"""

def _get_tensors(te_cached_func):
outputs = list(te_cached_func.outputs)
stack = []
visited = set()
for output_ in outputs:
if output_ not in visited:
visited.add(output_)
stack.append(output_)

args = []
while len(stack) != 0:
tensor = stack.pop()
if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp):
args.append(tensor)
elif isinstance(tensor.op, tvm.te.tensor.ComputeOp):
inputs = tensor.op.input_tensors
for input_ in inputs:
if input_ not in visited:
visited.add(input_)
stack.append(input_)

return args + outputs
return list(te_cached_func.inputs) + list(te_cached_func.outputs)

lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
te_cached_func = lower_to_te(relay_prim_func)
Expand Down
25 changes: 17 additions & 8 deletions python/tvm/testing/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,7 @@ def _create_header_file(tensor_name, npy_data, output_path, data_linkage):
header_file.write("};\n\n")


def convert_to_relay(
tflite_model_buf,
):
def convert_to_relay(tflite_model_buf, bind_params_by_name=True):
"""Convert a tflite model buffer in a Relay module"""
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
Expand All @@ -588,7 +586,8 @@ def convert_to_relay(
raise ImportError("The tflite package must be installed")

mod, params = relay.frontend.from_tflite(tflite_model)
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
if bind_params_by_name:
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
return mod, params


Expand Down Expand Up @@ -931,20 +930,30 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
return dict(zip(output_tensor_names, out))


def create_relay_module_and_inputs_from_tflite_file(tflite_model_file):
def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_params_by_name=True):
"""A helper function to create a Relay IRModule with inputs
and params from a tflite file"""
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
mod, params = convert_to_relay(tflite_model_buf)
mod, params = convert_to_relay(tflite_model_buf, bind_params_by_name)

inputs = dict()
for param in mod["main"].params:
name = str(param.name_hint)
data_shape = [int(i) for i in param.type_annotation.shape]
dtype = str(param.type_annotation.dtype)
in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
if np.issubdtype(dtype, np.floating):
# Since np.random.uniform only allows the ranges of float32,
# at first float16 is used and scaled afterwards, if necessary.
in_min, in_max = (np.finfo("float16").min, np.finfo("float16").max)
data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype)
scale = np.finfo(dtype).min / np.finfo("float16").min
data *= scale
elif np.issubdtype(dtype, np.integer):
in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
else:
raise TypeError(f"Type {dtype} not supported")
inputs[name] = data

return mod, inputs, params
82 changes: 82 additions & 0 deletions tests/python/contrib/test_uma/test_uma_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@
# under the License.

import pytest

pytest.importorskip("tflite")
pytest.importorskip("tensorflow")

import os
import tensorflow as tf
from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER
from tvm.relay import transform, testing
from tvm.testing.aot import (
AOTTestModel,
AOTTestRunner,
generate_ref_data,
compile_and_run,
create_relay_module_and_inputs_from_tflite_file,
)

import tvm
Expand Down Expand Up @@ -132,5 +139,80 @@ def test_mobilenet():
)


def test_tflite_model():
"""
End-to-end test of TF-Lite file using UMA
"""
tflite_file = "/tmp/model.tflite"
if os.path.exists(tflite_file):
os.remove(tflite_file)
generate_tflite_file(tflite_file)

pytest.importorskip("tflite")

interpreter = tf.lite.Interpreter(model_path=tflite_file)
tf_model_details = interpreter.get_input_details()
mod, _, params = create_relay_module_and_inputs_from_tflite_file(
tflite_file, bind_params_by_name=False
)

uma_backend = VanillaAcceleratorBackend()
uma_backend.register()
target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c"))
target_c = tvm.target.Target("c")

# Generation of test input and output
data_shape = [int(x) for x in mod["main"].params[0].type_annotation.shape]
data = np.random.uniform(size=data_shape).astype("float32")
input_list = {str(tf_model_details[0]["name"]): data}
output_list = generate_ref_data(mod, input_list, params)

# UMA partitioning (needs to be done after generate_ref_data)
mod = uma_backend.partition(mod)

aot_test_model = AOTTestModel(module=mod, inputs=input_list, outputs=output_list, params=params)
test_runner = AOTTestRunner(
pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": "greedy_by_size"}
)

compile_and_run(
aot_test_model,
test_runner,
interface_api="c",
use_unpacked_api=True,
workspace_byte_alignment=1,
debug_calculated_workspaces=False,
target=[target_c, target],
)


def generate_tflite_file(tflite_filename):
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.reshape(-1, 28, 28, 1), x_test.reshape(-1, 28, 28, 1)
tf_model = tf.keras.models.Sequential(
[
tf.keras.Input(shape=(28, 28, 1)),
tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"),
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(32, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10),
]
)
output = tf_model(x_train[:1])
output = output.numpy()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss(y_train[:1], output).numpy()
tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss)
tf_model.fit(x_train, y_train, epochs=1)

tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
tflite_model = tflite_converter.convert()
with open(tflite_filename, "wb") as f:
f.write(tflite_model)


if __name__ == "__main__":
tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from apps.uma._template.codegen import gen_includes

from apps.uma._template.patterns import conv2d_pattern
from apps.uma._template.patterns import conv2d_pattern, dense_pattern
from tvm.relay.backend.contrib.uma import uma_available

pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available")
Expand All @@ -40,6 +40,7 @@ def __init__(self):
# Relay to Relay function registration
#######################################################################
self._register_pattern("conv2d", conv2d_pattern())
self._register_pattern("dense", dense_pattern())

#######################################################################
# Relay to TIR function registration
Expand Down