From 10236cd37d9a4d7c09718bd706f1717014b717cc Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Fri, 17 Apr 2020 21:08:17 +0000 Subject: [PATCH 1/3] [Tutorial - QNN] Prequantized MXNet model compilation. --- .../frontend/deploy_prequantized_mxnet.py | 220 ++++++++++++++++++ ...ized.py => deploy_prequantized_pytorch.py} | 9 +- 2 files changed, 226 insertions(+), 3 deletions(-) create mode 100644 tutorials/frontend/deploy_prequantized_mxnet.py rename tutorials/frontend/{deploy_prequantized.py => deploy_prequantized_pytorch.py} (96%) diff --git a/tutorials/frontend/deploy_prequantized_mxnet.py b/tutorials/frontend/deploy_prequantized_mxnet.py new file mode 100644 index 000000000000..8b1b2fe2dba4 --- /dev/null +++ b/tutorials/frontend/deploy_prequantized_mxnet.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Deploy a Framework-prequantized Model with TVM - Part 2 (MXNet) +============================================== +**Author**: `Animesh Jain `_ + +Welcome to Part 2 of Deploy Framework-Prequantized Model with TVM. In this tutorial, we will start +with a FP32 MXNet graph, quantize it using MXNet, and then compile and execute it via TVM. + +For more details on quantizing the model using MXNet, readers are encouraged to go through `Model +Quantization with Calibration Examples +`_. + +Pre-requisites + pip3 install mxnet-mkl --user + pip3 install gluoncv --user +""" + + +############################################################################### +# Necessary imports +# ----------------- +import os + +import mxnet as mx +from gluoncv.model_zoo import get_model +from mxnet.contrib.quantization import * + +import tvm +from tvm import relay + + +############################################################################### +# Helper functions +# ---------------- +def download_calib_dataset(dataset_url, calib_dataset): + """ Download calibration dataset. """ + print('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset)) + mx.test_utils.download(dataset_url, calib_dataset) + + +def prepare_calib_dataset(data_shape, label_name): + """ Preprocess the dataset and set up the data iterator. """ + mean_args = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939} + std_args = {'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375} + combine_mean_std = {} + combine_mean_std.update(mean_args) + combine_mean_std.update(std_args) + data = mx.io.ImageRecordIter(path_imgrec='data/val_256_q90.rec', + label_width=1, + preprocess_threads=60, + batch_size=1, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=True, + **combine_mean_std) + return data + + +def get_mxnet_fp32_model(): + """ Read the MXNet symbol. 
""" + model_name = 'resnet50_v1' + dir_path = os.path.dirname(os.path.realpath(__file__)) + block = get_model(name=model_name, pretrained=True) + + # Convert the model to symbol format. + block.hybridize() + data = mx.sym.Variable('data') + sym = block(data) + sym = mx.sym.SoftmaxOutput(data=sym, name='softmax') + params = block.collect_params() + args = {} + auxs = {} + for param in params.values(): + v = param._reduce() + k = param.name + if 'running' in k: + auxs[k] = v + else: + args[k] = v + return sym, args, auxs + + +def quantize_model(sym, arg_params, aux_params, data, ctx, label_name): + """ Quantize the model using MXNet. """ + return quantize_model_mkldnn(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=ctx, + calib_mode='naive', calib_data=data, + num_calib_examples=5, + quantized_dtype='auto', + label_names=(label_name,)) + + +def run_mxnet(qsym, data, batch, ctx, label_name): + """ Run MXNet pre-quantized model inference. """ + mod = mx.mod.Module(symbol=qsym, context=[ctx], label_names=[label_name, ]) + mod.bind(for_training=False, data_shapes=data.provide_data, label_shapes=data.provide_label) + mod.set_params(qarg_params, qaux_params) + mod.forward(batch, is_train=False) + mxnet_res = mod.get_outputs()[0].asnumpy() + mxnet_pred = np.squeeze(mxnet_res).argsort()[-5:][::-1] + return mxnet_pred + + +def run_tvm(graph, lib, params, batch): + """ Run TVM compiler model inference. """ + from tvm.contrib import graph_runtime + rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + rt_mod.set_input(**params) + rt_mod.set_input('data', batch.data[0].asnumpy()) + rt_mod.run() + tvm_res = rt_mod.get_output(0).asnumpy() + tvm_pred = np.squeeze(tvm_res).argsort()[-5:][::-1] + return tvm_pred, rt_mod + + +# Initialize variables. +label_name = 'softmax_label' +data_shape = (3, 224, 224) +ctx = mx.cpu(0) + + +############################################################################### +# MXNet quantization and inference. +# --------------------------------- + +# Download and prepare calibrarion dataset. +download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec') +data = prepare_calib_dataset(data_shape, label_name) + +# Get a FP32 Resnet 50 MXNet model. +sym, arg_params, aux_params = get_mxnet_fp32_model() + +# Quantize the MXNet model using MXNet quantizer. +qsym, qarg_params, qaux_params = quantize_model(sym, arg_params, aux_params, + data, ctx, label_name) + +# Get the testing image from the MXNet data iterator. +batch = data.next() + +# Run MXNet inference on the quantized model. +mxnet_pred = run_mxnet(qsym, data, batch, ctx, label_name) + + +############################################################################### +# TVM compilation of pre-quantized model and inference. +# --------------------------------- + +# Use MXNet-Relay parser. Note that the frontend parser call is exactly same as frontend parser call +# for a FP32 model. +input_shape = [1] + list(data_shape) +input_dict = {'data': input_shape} +mod, params = relay.frontend.from_mxnet(qsym, + dtype={}, + shape=input_dict, + arg_params=qarg_params, + aux_params=qaux_params) + +# Please inspect the module. You will have QNN operators like requantize, quantize, conv2d. +# print(mod) + +# Compile Relay module. Set the target platform. Replace the target with the your target type. 

# Compile Relay module. Set the target platform. Replace the target with your target type.
target = 'llvm -mcpu=cascadelake'
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build_module.build(mod, target=target, params=params)

# Call inference on the compiled module.
tvm_pred, rt_mod = run_tvm(graph, lib, params, batch)


###############################################################################
# Accuracy comparison.
# --------------------

# Print the top-5 labels for MXNet and TVM inference. Note that final tensors can slightly differ
# between MXNet and TVM quantized inference, but the classification accuracy is not significantly
# affected. Output of the following code is as follows
#
# TVM Top-5 labels: [236 211 178 165 168]
# MXNet Top-5 labels: [236 211 178 165 168]
print("TVM Top-5 labels:", tvm_pred)
print("MXNet Top-5 labels:", mxnet_pred)


##########################################################################
# Measure performance.
# --------------------
# Here we give an example of how to measure performance of TVM compiled models.
n_repeat = 100 # should be bigger to make the measurement more accurate
ctx = tvm.cpu(0)
ftimer = rt_mod.module.time_evaluator("run", ctx, number=1, repeat=n_repeat)
prof_res = np.array(ftimer().results) * 1e3
print("Elapsed average ms:", np.mean(prof_res))

##########################################################################
# Notes
# -----
# 1) On Intel Cascadelake server, the performance is 2.01 ms.
# 2) Auto-tuning can potentially improve this performance. Please follow the tutorial at
# `Auto-tuning a convolution network for x86 CPU
# `_.
diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized_pytorch.py
similarity index 96%
rename from tutorials/frontend/deploy_prequantized.py
rename to tutorials/frontend/deploy_prequantized_pytorch.py
index 40279778c045..51c52f8e58b8 100644
--- a/tutorials/frontend/deploy_prequantized.py
+++ b/tutorials/frontend/deploy_prequantized_pytorch.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Deploy a Framework-prequantized Model with TVM
+Deploy a Framework-prequantized Model with TVM - Part 1 (PyTorch)
 ==============================================
 **Author**: `Masahiro Masuda `_
 
@@ -24,8 +24,11 @@
 the quantization story in TVM can be found
 `here `_.
 
-Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite.
-Once loaded, we can run compiled, quantized models on any hardware TVM supports.
+In this series of tutorials, we demonstrate how to load and run models quantized by PyTorch (Part
+1), MXNet (Part 2), and TFLite (Part 3). Once loaded, we can run compiled, quantized models on any
+hardware TVM supports.
+
+This is part 1 of the tutorial, where we will focus on PyTorch-prequantized models.
 """
 
 #################################################################################

From 0520b36d3c4e81ef0bfef8120371152ee362645a Mon Sep 17 00:00:00 2001
From: anijain2305 
Date: Sat, 18 Apr 2020 00:51:15 +0000
Subject: [PATCH 2/3] Sphinx fixes.

---
 .../frontend/deploy_prequantized_mxnet.py     | 32 +++++++++++++------
 .../frontend/deploy_prequantized_pytorch.py   | 15 ++-------
 2 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/tutorials/frontend/deploy_prequantized_mxnet.py b/tutorials/frontend/deploy_prequantized_mxnet.py
index 8b1b2fe2dba4..6170cc266104 100644
--- a/tutorials/frontend/deploy_prequantized_mxnet.py
+++ b/tutorials/frontend/deploy_prequantized_mxnet.py
@@ -16,7 +16,7 @@
 # under the License.
""" Deploy a Framework-prequantized Model with TVM - Part 2 (MXNet) -============================================== +=============================================================== **Author**: `Animesh Jain `_ Welcome to Part 2 of Deploy Framework-Prequantized Model with TVM. In this tutorial, we will start @@ -163,7 +163,7 @@ def run_tvm(graph, lib, params, batch): ############################################################################### # TVM compilation of pre-quantized model and inference. -# --------------------------------- +# ----------------------------------------------------- # Use MXNet-Relay parser. Note that the frontend parser call is exactly same as frontend parser call # for a FP32 model. @@ -179,7 +179,7 @@ def run_tvm(graph, lib, params, batch): # print(mod) # Compile Relay module. Set the target platform. Replace the target with the your target type. -target = 'llvm -mcpu=cascadelake' +target = 'llvm' with relay.build_config(opt_level=3): graph, lib, params = relay.build_module.build(mod, target=target, params=params) @@ -211,10 +211,22 @@ def run_tvm(graph, lib, params, batch): prof_res = np.array(ftimer().results) * 1e3 print("Elapsed average ms:", np.mean(prof_res)) -########################################################################## -# Notes -# ----- -# 1) On Intel Cascadelake server, the performance is 2.01 ms. -# 2) Auto-tuning can potentially improve this performance. Please follow the tutorial at -# `Auto-tuning a convolution network for x86 CPU -# `_. +###################################################################### +# .. note:: +# +# Unless the hardware has special support for fast 8 bit instructions, quantized models are +# not expected to be any faster than FP32 models. Without fast 8 bit instructions, TVM does +# quantized convolution in 16 bit, even if the model itself is 8 bit. +# +# For x86, the best performance can be achieved on CPUs with AVX512 instructions set. +# In this case, TVM utilizes the fastest available 8 bit instructions for the given target. +# This includes support for the VNNI 8 bit dot product instruction (CascadeLake or newer). +# For EC2 C5.12x large instance, TVM latency for this tutorial is ~2 ms. +# +# Moreover, the following general tips for CPU performance equally applies: +# +# * Set the environment variable TVM_NUM_THREADS to the number of physical cores +# * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or +# "llvm -mcpu=cascadelake" (more CPUs with AVX512 would come in the future) +# * Perform autotuning - `Auto-tuning a convolution network for x86 CPU +# `_. diff --git a/tutorials/frontend/deploy_prequantized_pytorch.py b/tutorials/frontend/deploy_prequantized_pytorch.py index 51c52f8e58b8..b9e2cb79e315 100644 --- a/tutorials/frontend/deploy_prequantized_pytorch.py +++ b/tutorials/frontend/deploy_prequantized_pytorch.py @@ -16,7 +16,7 @@ # under the License. """ Deploy a Framework-prequantized Model with TVM - Part 1 (PyTorch) -============================================== +================================================================= **Author**: `Masahiro Masuda `_ This is a tutorial on loading models quantized by deep learning frameworks into TVM. 
@@ -227,14 +227,5 @@ def quantize_model(model, inp):
 #   * Set the environment variable TVM_NUM_THREADS to the number of physical cores
 #   * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or
 #     "llvm -mcpu=cascadelake" (more CPUs with AVX512 will come in the future)
-
-
-###############################################################################
-# Deploy a quantized MXNet Model
-# ------------------------------
-# TODO
-
-###############################################################################
-# Deploy a quantized TFLite Model
-# -------------------------------
-# TODO
+#   * Perform autotuning - `Auto-tuning a convolution network for x86 CPU
+#     `_.

From 4aa193f697c906c696a6acbd7eea4aac365a18a5 Mon Sep 17 00:00:00 2001
From: anijain2305 
Date: Sat, 18 Apr 2020 22:58:02 +0000
Subject: [PATCH 3/3] Address review comments.

---
 .../frontend/deploy_prequantized_mxnet.py     | 117 +++++++++++------
 .../frontend/deploy_prequantized_pytorch.py   |   7 +-
 2 files changed, 77 insertions(+), 47 deletions(-)

diff --git a/tutorials/frontend/deploy_prequantized_mxnet.py b/tutorials/frontend/deploy_prequantized_mxnet.py
index 6170cc266104..58873b1dc28c 100644
--- a/tutorials/frontend/deploy_prequantized_mxnet.py
+++ b/tutorials/frontend/deploy_prequantized_mxnet.py
@@ -19,14 +19,18 @@
 ===============================================================
 **Author**: `Animesh Jain `_
 
-Welcome to Part 2 of Deploy Framework-Prequantized Model with TVM. In this tutorial, we will start
-with a FP32 MXNet graph, quantize it using MXNet, and then compile and execute it via TVM.
+Welcome to part 2 of the Deploy Framework-Prequantized Model with TVM tutorial.
+In this part, we will start with an FP32 MXNet graph, quantize it using
+MXNet, and then compile and execute it via TVM.
 
-For more details on quantizing the model using MXNet, readers are encouraged to go through `Model
-Quantization with Calibration Examples
-`_.
+For more details on quantizing the model using MXNet, readers are encouraged to
+go through `Model Quantization with Calibration Examples
+`_.
 
-Pre-requisites
+To get started, we need the mxnet-mkl and gluoncv packages. They can be installed as follows.
+
+.. code-block:: bash
+
     pip3 install mxnet-mkl --user
     pip3 install gluoncv --user
 """
@@ -39,7 +43,9 @@
 
 import mxnet as mx
 from gluoncv.model_zoo import get_model
-from mxnet.contrib.quantization import *
+from mxnet.contrib.quantization import quantize_model_mkldnn
+
+import numpy as np
 
 import tvm
 from tvm import relay
@@ -48,20 +54,30 @@
 ###############################################################################
 # Helper functions
 # ----------------
-def download_calib_dataset(dataset_url, calib_dataset):
-    """ Download calibration dataset. """
-    print('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset))
-    mx.test_utils.download(dataset_url, calib_dataset)
+###############################################################################
+# We need to download the calibration dataset. This dataset is used to find the minimum and maximum
+# values of intermediate tensors during post-training MXNet quantization. The MXNet quantizer, using
+# these min/max values, finds a suitable scale for the quantized tensors.
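+#
+# As a rough illustration only (this is not necessarily the exact formula that
+# MXNet uses internally), a symmetric int8 scale could be derived from the
+# calibrated range of a tensor as:
+#
+# .. code-block:: python
+#
+#     # min_val and max_val stand for the calibrated bounds of the tensor
+#     scale = max(abs(min_val), abs(max_val)) / 127.0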
+def download_calib_dataset(dataset_url, calib_dataset_fname):
+    print('Downloading calibration dataset from %s to %s' % \
+          (dataset_url, calib_dataset_fname))
+    mx.test_utils.download(dataset_url, calib_dataset_fname)
 
 
-def prepare_calib_dataset(data_shape, label_name):
+###############################################################################
+# Let's prepare the calibration dataset by pre-processing it. In this tutorial, we follow the
+# pre-processing used in the MXNet quantization `tutorial
+# `_. Please replace it
+# with your pre-processing if needed.
+def prepare_calib_dataset(data_shape, label_name, calib_dataset_fname):
     """ Preprocess the dataset and set up the data iterator. """
     mean_args = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939}
     std_args = {'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375}
     combine_mean_std = {}
     combine_mean_std.update(mean_args)
     combine_mean_std.update(std_args)
-    data = mx.io.ImageRecordIter(path_imgrec='data/val_256_q90.rec',
+    data = mx.io.ImageRecordIter(path_imgrec=calib_dataset_fname,
                                  label_width=1,
                                  preprocess_threads=60,
                                  batch_size=1,
@@ -74,10 +90,13 @@
     return data
 
 
+###############################################################################
+# The following function reads the FP32 MXNet model. In this example, we use the resnet50_v1 model.
+# The readers are encouraged to go through the MXNet quantization tutorial to get more models. We
+# convert the MXNet model to its symbol format.
 def get_mxnet_fp32_model():
     """ Read the MXNet symbol. """
     model_name = 'resnet50_v1'
-    dir_path = os.path.dirname(os.path.realpath(__file__))
     block = get_model(name=model_name, pretrained=True)
 
     # Convert the model to symbol format.
@@ -98,8 +117,11 @@ def get_mxnet_fp32_model():
     return sym, args, auxs
 
 
+###############################################################################
+# Let's now quantize the model using MXNet. MXNet works in concert with MKLDNN to quantize the
+# model. Note that MKLDNN is used only for quantizing. Once we get a quantized model, we can compile
+# and execute it on any supported hardware platform in TVM.
 def quantize_model(sym, arg_params, aux_params, data, ctx, label_name):
-    """ Quantize the model using MXNet. """
     return quantize_model_mkldnn(sym=sym,
                                  arg_params=arg_params,
                                  aux_params=aux_params,
@@ -110,10 +132,12 @@ def quantize_model(sym, arg_params, aux_params, data, ctx, label_name):
                                  label_names=(label_name,))
 
 
+###############################################################################
+# Let's run MXNet pre-quantized model inference and get the MXNet prediction.
 def run_mxnet(qsym, data, batch, ctx, label_name):
-    """ Run MXNet pre-quantized model inference. """
     mod = mx.mod.Module(symbol=qsym, context=[ctx], label_names=[label_name, ])
-    mod.bind(for_training=False, data_shapes=data.provide_data, label_shapes=data.provide_label)
+    mod.bind(for_training=False, data_shapes=data.provide_data,
+             label_shapes=data.provide_label)
     mod.set_params(qarg_params, qaux_params)
     mod.forward(batch, is_train=False)
     mxnet_res = mod.get_outputs()[0].asnumpy()
@@ -121,8 +145,9 @@ def run_mxnet(qsym, data, batch, ctx, label_name):
     return mxnet_pred
 
 
+###############################################################################
+# Let's run TVM-compiled pre-quantized model inference and get the TVM prediction.
 def run_tvm(graph, lib, params, batch):
-    """ Run TVM compiler model inference. """
""" from tvm.contrib import graph_runtime rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod.set_input(**params) @@ -133,40 +158,49 @@ def run_tvm(graph, lib, params, batch): return tvm_pred, rt_mod +############################################################################### # Initialize variables. label_name = 'softmax_label' data_shape = (3, 224, 224) ctx = mx.cpu(0) - ############################################################################### -# MXNet quantization and inference. -# --------------------------------- +# MXNet quantization and inference +# -------------------------------- +############################################################################### # Download and prepare calibrarion dataset. -download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec') -data = prepare_calib_dataset(data_shape, label_name) +calib_dataset_fname = '/tmp/val_256_q90.rec' +download_calib_dataset(dataset_url='http://data.mxnet.io/data/val_256_q90.rec', + calib_dataset_fname=calib_dataset_fname ) +data = prepare_calib_dataset(data_shape, label_name, calib_dataset_fname) +############################################################################### # Get a FP32 Resnet 50 MXNet model. sym, arg_params, aux_params = get_mxnet_fp32_model() +############################################################################### # Quantize the MXNet model using MXNet quantizer. qsym, qarg_params, qaux_params = quantize_model(sym, arg_params, aux_params, data, ctx, label_name) +############################################################################### # Get the testing image from the MXNet data iterator. batch = data.next() +############################################################################### # Run MXNet inference on the quantized model. mxnet_pred = run_mxnet(qsym, data, batch, ctx, label_name) - ############################################################################### -# TVM compilation of pre-quantized model and inference. -# ----------------------------------------------------- +# TVM compilation and inference +# ---------------------------------------------------- -# Use MXNet-Relay parser. Note that the frontend parser call is exactly same as frontend parser call -# for a FP32 model. +############################################################################### +# We use the MXNet-Relay parser to conver the MXNet pre-quantized graph into Relay IR. Note that the +# frontend parser call for a pre-quantized model is exactly same as frontend parser call for a FP32 +# model. We encourage you to remove the comment from print(mod) and inspect the Relay module. You +# will see many QNN operators, like, Requantize, Quantize and QNN Conv2D. input_shape = [1] + list(data_shape) input_dict = {'data': input_shape} mod, params = relay.frontend.from_mxnet(qsym, @@ -174,36 +208,35 @@ def run_tvm(graph, lib, params, batch): shape=input_dict, arg_params=qarg_params, aux_params=qaux_params) - -# Please inspect the module. You will have QNN operators like requantize, quantize, conv2d. # print(mod) -# Compile Relay module. Set the target platform. Replace the target with the your target type. +############################################################################### +# Lets now the compile the Relay module. We use the "llvm" target here. Please replace it with the +# target platform that you are interested in. 
 
-# Call inference on the compiled module.
+###############################################################################
+# Finally, let's call inference on the TVM-compiled module.
 tvm_pred, rt_mod = run_tvm(graph, lib, params, batch)
 
 ###############################################################################
-# Accuracy comparison.
-# --------------------
+# Accuracy comparison
+# -------------------
 
+###############################################################################
 # Print the top-5 labels for MXNet and TVM inference. Note that final tensors can slightly differ
 # between MXNet and TVM quantized inference, but the classification accuracy is not significantly
-# affected. Output of the following code is as follows
-#
-# TVM Top-5 labels: [236 211 178 165 168]
-# MXNet Top-5 labels: [236 211 178 165 168]
+# affected.
 print("TVM Top-5 labels:", tvm_pred)
 print("MXNet Top-5 labels:", mxnet_pred)
 
 ##########################################################################
-# Measure performance.
-# --------------------
+# Measure performance
+# -------------------
 # Here we give an example of how to measure performance of TVM compiled models.
 n_repeat = 100 # should be bigger to make the measurement more accurate
 ctx = tvm.cpu(0)
@@ -228,5 +261,5 @@ def run_tvm(graph, lib, params, batch):
 #   * Set the environment variable TVM_NUM_THREADS to the number of physical cores
 #   * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or
 #     "llvm -mcpu=cascadelake" (more CPUs with AVX512 will come in the future)
-#   * Perform autotuning - `Auto-tuning a convolution network for x86 CPU 
+#   * Perform autotuning - `Auto-tuning a convolution network for x86 CPU
 #     `_.
diff --git a/tutorials/frontend/deploy_prequantized_pytorch.py b/tutorials/frontend/deploy_prequantized_pytorch.py
index b9e2cb79e315..6363a13a4d13 100644
--- a/tutorials/frontend/deploy_prequantized_pytorch.py
+++ b/tutorials/frontend/deploy_prequantized_pytorch.py
@@ -24,11 +24,8 @@
 the quantization story in TVM can be found
 `here `_.
 
-In this series of tutorials, we demonstrate how to load and run models quantized by PyTorch (Part
-1), MXNet (Part 2), and TFLite (Part 3). Once loaded, we can run compiled, quantized models on any
-hardware TVM supports.
-
-This is part 1 of the tutorial, where we will focus on PyTorch-prequantized models.
+Here, we demonstrate how to load and run models quantized by PyTorch. Once loaded, we can run
+compiled, quantized models on any hardware TVM supports.
 """
 
 #################################################################################