From af1e2b3c31d1b394eb0b25d4a76fea322135f9ae Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Fri, 12 Aug 2022 20:59:08 +0200 Subject: [PATCH 1/9] [UMA] Added def test_tflite_model unit test case --- apps/uma/_template/patterns.py | 5 ++ python/tvm/testing/aot.py | 19 +++-- .../contrib/test_uma/test_uma_pipeline.py | 82 ++++++++++++++++++- .../test_uma/test_uma_vanilla_accelerator.py | 3 +- 4 files changed, 99 insertions(+), 10 deletions(-) diff --git a/apps/uma/_template/patterns.py b/apps/uma/_template/patterns.py index ce25fe4dff8e..1c841f2dbf1d 100644 --- a/apps/uma/_template/patterns.py +++ b/apps/uma/_template/patterns.py @@ -23,3 +23,8 @@ def conv2d_pattern(): pattern = is_op("nn.conv2d")(wildcard(), wildcard()) pattern = pattern.has_attr({"strides": [1, 1], "groups": 1}) return pattern + + +def dense_pattern(): + pattern = is_op("nn.dense")(wildcard(), wildcard()) + return pattern diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 5d7fb62cd204..46b377b3aefb 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -572,7 +572,7 @@ def _create_header_file(tensor_name, npy_data, output_path, data_linkage): def convert_to_relay( - tflite_model_buf, + tflite_model_buf, bind_params_by_name=True ): """Convert a tflite model buffer in a Relay module""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 @@ -588,7 +588,8 @@ def convert_to_relay( raise ImportError("The tflite package must be installed") mod, params = relay.frontend.from_tflite(tflite_model) - mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) + if bind_params_by_name: + mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) return mod, params @@ -836,7 +837,8 @@ def run_and_check_body(base_path): assert AOT_SUCCESS_TOKEN in run_log.read() if test_dir is None: - tmpdir = utils.tempdir() + tmpdir = utils.tempdir(keep_for_debug=True) + print(tmpdir.path) run_and_check_body(os.path.join(tmpdir.path, "test")) else: run_and_check_body(test_dir) @@ -931,20 +933,23 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): return dict(zip(output_tensor_names, out)) -def create_relay_module_and_inputs_from_tflite_file(tflite_model_file): +def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_params_by_name=True): """A helper function to create a Relay IRModule with inputs and params from a tflite file""" with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - mod, params = convert_to_relay(tflite_model_buf) + mod, params = convert_to_relay(tflite_model_buf, bind_params_by_name) inputs = dict() for param in mod["main"].params: name = str(param.name_hint) data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) - in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) - data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) + if dtype == "float32": + data = np.random.uniform(size=data_shape).astype("float32") + else: + in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) + data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) inputs[name] = data return mod, inputs, params diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index 49b4a196bbd4..2ee6c20d8bba 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -22,7 +22,7 @@ AOTTestModel, AOTTestRunner, generate_ref_data, - compile_and_run, + compile_and_run, create_relay_module_and_inputs_from_tflite_file, ) import tvm @@ -132,5 +132,83 @@ def test_mobilenet(): ) +def test_tflite_model(): + import os + import tensorflow as tf + + tflite_file = "/tmp/model0.tflite" + if not os.path.exists(tflite_file): + generate_tflite_file(tflite_file) + + pytest.importorskip("tflite") + + interpreter = tf.lite.Interpreter(model_path=tflite_file) + tf_model_details = interpreter.get_input_details() + mod, _, params = create_relay_module_and_inputs_from_tflite_file( + tflite_file, bind_params_by_name=False + ) + + uma_backend = VanillaAcceleratorBackend() + uma_backend.register() + target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c")) + target_c = tvm.target.Target("c") + + # Generation of test input and output + data_shape = [int(x) for x in mod["main"].params[0].type_annotation.shape] + data = np.random.uniform(size=data_shape).astype("float32") + input_list = {str(tf_model_details[0]["name"]): data} + output_list = generate_ref_data(mod, input_list, params) + + # UMA partitioning (needs to be done after generate_ref_data) + mod = uma_backend.partition(mod) + + aot_test_model = AOTTestModel(module=mod, inputs=input_list, outputs=output_list, params=params) + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": "greedy_by_size"} + ) + + compile_and_run( + aot_test_model, + test_runner, + interface_api="c", + use_unpacked_api=True, + workspace_byte_alignment=1, + debug_calculated_workspaces=False, + target=[target_c, target], + ) + + +def generate_tflite_file(tflite_filename): + import tensorflow as tf + + mnist = tf.keras.datasets.mnist + (x_train, y_train), (x_test, y_test) = mnist.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + tf_model = tf.keras.models.Sequential( + [ + tf.keras.Input(shape=(28, 28, 1)), + # tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"), + tf.keras.layers.Flatten(input_shape=(28, 28)), + #tf.keras.layers.Dense(32, activation="relu"), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10), + ] + ) + output = tf_model(x_train[:1]) + output = output.numpy() + tf.nn.softmax(output).numpy() + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss_fn(y_train[:1], output).numpy() + tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss_fn) + tf_model.fit(x_train, y_train, epochs=3) + tf_model.evaluate(x_test, y_test, verbose=2) + + tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model) + tflite_model = tflite_converter.convert() + with open(tflite_filename, "wb") as f: + f.write(tflite_model) + + if __name__ == "__main__": - tvm.testing.main() + test_tflite_model() + #tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py index e7a6b21d4ab5..043203e22a99 100644 --- a/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py +++ b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py @@ -24,7 +24,7 @@ ) from apps.uma._template.codegen import gen_includes -from apps.uma._template.patterns import conv2d_pattern +from apps.uma._template.patterns import conv2d_pattern, dense_pattern from tvm.relay.backend.contrib.uma import uma_available pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") @@ -40,6 +40,7 @@ def __init__(self): # Relay to Relay function registration ####################################################################### self._register_pattern("conv2d", conv2d_pattern()) + self._register_pattern("dense", dense_pattern()) ####################################################################### # Relay to TIR function registration From 653ad7e84e25e40ec48f0800936011e7c15f628c Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Tue, 16 Aug 2022 08:44:06 +0200 Subject: [PATCH 2/9] [UMA] Fix for wrong order of arguments in TIR-Primfuncs: Issue #12410 --- .../relay/backend/contrib/uma/api/lower.py | 22 +------------- python/tvm/testing/aot.py | 5 +--- .../contrib/test_uma/test_uma_pipeline.py | 29 ++++++++++--------- 3 files changed, 17 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/backend/contrib/uma/api/lower.py b/python/tvm/relay/backend/contrib/uma/api/lower.py index 34630949a151..334b6d101f82 100644 --- a/python/tvm/relay/backend/contrib/uma/api/lower.py +++ b/python/tvm/relay/backend/contrib/uma/api/lower.py @@ -60,27 +60,7 @@ def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFu """ def _get_tensors(te_cached_func): - outputs = list(te_cached_func.outputs) - stack = [] - visited = set() - for output_ in outputs: - if output_ not in visited: - visited.add(output_) - stack.append(output_) - - args = [] - while len(stack) != 0: - tensor = stack.pop() - if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp): - args.append(tensor) - elif isinstance(tensor.op, tvm.te.tensor.ComputeOp): - inputs = tensor.op.input_tensors - for input_ in inputs: - if input_ not in visited: - visited.add(input_) - stack.append(input_) - - return args + outputs + return list(te_cached_func.inputs) + list(te_cached_func.outputs) lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE") te_cached_func = lower_to_te(relay_prim_func) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 46b377b3aefb..4d19633f365f 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -571,9 +571,7 @@ def _create_header_file(tensor_name, npy_data, output_path, data_linkage): header_file.write("};\n\n") -def convert_to_relay( - tflite_model_buf, bind_params_by_name=True -): +def convert_to_relay(tflite_model_buf, bind_params_by_name=True): """Convert a tflite model buffer in a Relay module""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: @@ -838,7 +836,6 @@ def run_and_check_body(base_path): if test_dir is None: tmpdir = utils.tempdir(keep_for_debug=True) - print(tmpdir.path) run_and_check_body(os.path.join(tmpdir.path, "test")) else: run_and_check_body(test_dir) diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index 2ee6c20d8bba..77f0d1b268c1 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -16,13 +16,16 @@ # under the License. import pytest +import os +import tensorflow as tf from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER from tvm.relay import transform, testing from tvm.testing.aot import ( AOTTestModel, AOTTestRunner, generate_ref_data, - compile_and_run, create_relay_module_and_inputs_from_tflite_file, + compile_and_run, + create_relay_module_and_inputs_from_tflite_file, ) import tvm @@ -133,12 +136,13 @@ def test_mobilenet(): def test_tflite_model(): - import os - import tensorflow as tf - - tflite_file = "/tmp/model0.tflite" - if not os.path.exists(tflite_file): - generate_tflite_file(tflite_file) + """ + End-to-end test of TF-Lite file using UMA + """ + tflite_file = "/tmp/model.tflite" + if os.path.exists(tflite_file): + os.remove(tflite_file) + generate_tflite_file(tflite_file) pytest.importorskip("tflite") @@ -179,17 +183,15 @@ def test_tflite_model(): def generate_tflite_file(tflite_filename): - import tensorflow as tf - mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 tf_model = tf.keras.models.Sequential( [ tf.keras.Input(shape=(28, 28, 1)), - # tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"), + tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"), tf.keras.layers.Flatten(input_shape=(28, 28)), - #tf.keras.layers.Dense(32, activation="relu"), + tf.keras.layers.Dense(32, activation="relu"), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10), ] @@ -200,7 +202,7 @@ def generate_tflite_file(tflite_filename): loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) loss_fn(y_train[:1], output).numpy() tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss_fn) - tf_model.fit(x_train, y_train, epochs=3) + tf_model.fit(x_train, y_train, epochs=1) tf_model.evaluate(x_test, y_test, verbose=2) tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model) @@ -210,5 +212,4 @@ def generate_tflite_file(tflite_filename): if __name__ == "__main__": - test_tflite_model() - #tvm.testing.main() + tvm.testing.main() From 2d73d493ad526cfd4a38add3297a885c35a99a4d Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Tue, 16 Aug 2022 15:25:13 +0200 Subject: [PATCH 3/9] [UMA] update of test_tflite_model syntax to match TF version --- python/tvm/testing/aot.py | 2 +- tests/python/contrib/test_uma/test_uma_pipeline.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 4d19633f365f..1aace4fe5e5e 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -835,7 +835,7 @@ def run_and_check_body(base_path): assert AOT_SUCCESS_TOKEN in run_log.read() if test_dir is None: - tmpdir = utils.tempdir(keep_for_debug=True) + tmpdir = utils.tempdir() run_and_check_body(os.path.join(tmpdir.path, "test")) else: run_and_check_body(test_dir) diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index 77f0d1b268c1..59f1f0731516 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -186,6 +186,7 @@ def generate_tflite_file(tflite_filename): mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 + x_train, x_test = x_train.reshape(-1, 28, 28, 1), x_test.reshape(-1, 28, 28, 1) tf_model = tf.keras.models.Sequential( [ tf.keras.Input(shape=(28, 28, 1)), @@ -198,12 +199,10 @@ def generate_tflite_file(tflite_filename): ) output = tf_model(x_train[:1]) output = output.numpy() - tf.nn.softmax(output).numpy() - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - loss_fn(y_train[:1], output).numpy() - tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss_fn) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss(y_train[:1], output).numpy() + tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss) tf_model.fit(x_train, y_train, epochs=1) - tf_model.evaluate(x_test, y_test, verbose=2) tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model) tflite_model = tflite_converter.convert() From de1eb649780a3f0f9617389d8091c4c3efb9a851 Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Tue, 16 Aug 2022 19:27:22 +0200 Subject: [PATCH 4/9] [UMA] test_tflite_model removing import of tflite and tensorflow from i386 import --- tests/python/contrib/test_uma/test_uma_pipeline.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index 59f1f0731516..c5120e0c916e 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -16,6 +16,9 @@ # under the License. import pytest +pytest.importorskip("tflite") +pytest.importorskip("tensorflow") + import os import tensorflow as tf from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER From c71a454a80b8b264495906f6db46d5474614a3cc Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Wed, 17 Aug 2022 10:38:22 +0200 Subject: [PATCH 5/9] [UMA] test_tflite_model removing import of tflite and tensorflow from i386 import --- tests/python/contrib/test_uma/test_uma_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py index c5120e0c916e..0d7ed3ab9587 100644 --- a/tests/python/contrib/test_uma/test_uma_pipeline.py +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -16,6 +16,7 @@ # under the License. import pytest + pytest.importorskip("tflite") pytest.importorskip("tensorflow") From b54889dd5bc26215d74d1f5a9ab58fb1396338b1 Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Sat, 20 Aug 2022 14:12:40 +0200 Subject: [PATCH 6/9] [AOT] added support for float32 to create_relay_module_and_inputs_from_tflite_file --- python/tvm/testing/aot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 1aace4fe5e5e..91bb2f7cdbe7 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -943,7 +943,8 @@ def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_para data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) if dtype == "float32": - data = np.random.uniform(size=data_shape).astype("float32") + in_min, in_max = (np.finfo(dtype).min, np.finfo(dtype).max) + data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype("float32") else: in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) From 7f3b4d09757bd58ca728baed84d21ae9ac885871 Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Sun, 21 Aug 2022 21:02:32 +0200 Subject: [PATCH 7/9] [AOT] Added support for float64 and int16, int64 to create_relay_module_and_inputs_from_tflite_file --- python/tvm/testing/aot.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 91bb2f7cdbe7..7967a5fd0f32 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -942,12 +942,18 @@ def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_para name = str(param.name_hint) data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) - if dtype == "float32": - in_min, in_max = (np.finfo(dtype).min, np.finfo(dtype).max) - data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype("float32") - else: + if dtype == "float32" or dtype == "float64": + # Since np.random.uniform only allows the ranges of float32, + # at first float32 is used and scaled afterwards, if necessary + in_min, in_max = (np.finfo("float32").min, np.finfo("float32").max) + data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype) + scale = np.finfo(dtype).min / np.finfo("float32").min + data *= scale + elif dtype == "int32" or dtype == "int64" or dtype == "int16": in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) + else: + raise TypeError("Unsupported type used") inputs[name] = data return mod, inputs, params From 07b02268a0a12767449ca10dca1c8a690fceb08b Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Sun, 21 Aug 2022 21:23:04 +0200 Subject: [PATCH 8/9] [AOT] Added support for float64 and int16, int64 to create_relay_module_and_inputs_from_tflite_file --- python/tvm/testing/aot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 7967a5fd0f32..1301b861170a 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -942,14 +942,14 @@ def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_para name = str(param.name_hint) data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) - if dtype == "float32" or dtype == "float64": + if dtype in ("float32", "float64"): # Since np.random.uniform only allows the ranges of float32, # at first float32 is used and scaled afterwards, if necessary in_min, in_max = (np.finfo("float32").min, np.finfo("float32").max) data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype) scale = np.finfo(dtype).min / np.finfo("float32").min data *= scale - elif dtype == "int32" or dtype == "int64" or dtype == "int16": + elif dtype in ("int16", "int32", "int64"): in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) else: From f816a96418e1cdab7958051f2ecc43160d9cc12b Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Mon, 22 Aug 2022 15:19:18 +0200 Subject: [PATCH 9/9] [AOT] Added support for np float and int types to create_relay_module_and_inputs_from_tflite_file --- python/tvm/testing/aot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 1301b861170a..563a7dff4a50 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -942,18 +942,18 @@ def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_para name = str(param.name_hint) data_shape = [int(i) for i in param.type_annotation.shape] dtype = str(param.type_annotation.dtype) - if dtype in ("float32", "float64"): + if np.issubdtype(dtype, np.floating): # Since np.random.uniform only allows the ranges of float32, - # at first float32 is used and scaled afterwards, if necessary - in_min, in_max = (np.finfo("float32").min, np.finfo("float32").max) + # at first float16 is used and scaled afterwards, if necessary. + in_min, in_max = (np.finfo("float16").min, np.finfo("float16").max) data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype) - scale = np.finfo(dtype).min / np.finfo("float32").min + scale = np.finfo(dtype).min / np.finfo("float16").min data *= scale - elif dtype in ("int16", "int32", "int64"): + elif np.issubdtype(dtype, np.integer): in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) else: - raise TypeError("Unsupported type used") + raise TypeError(f"Type {dtype} not supported") inputs[name] = data return mod, inputs, params