diff --git a/rust/tvm/README.md b/rust/tvm/README.md index b1bb4687679e..3455975ad81d 100644 --- a/rust/tvm/README.md +++ b/rust/tvm/README.md @@ -26,7 +26,7 @@ You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/t The goal of this crate is to provide bindings to both the TVM compiler and runtime APIs. First train your **Deep Learning** model using any major framework such as -[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/). +[PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/). Then use **TVM** to build and deploy optimized model artifacts on a supported devices such as CPU, GPU, OpenCL and specialized accelerators. The Rust bindings are composed of a few crates: diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md index d6e32f7fa768..ad76ac0048a0 100644 --- a/rust/tvm/examples/resnet/README.md +++ b/rust/tvm/examples/resnet/README.md @@ -21,7 +21,7 @@ This end-to-end example shows how to: * build `Resnet 18` with `tvm` from Python * use the provided Rust frontend API to test for an input image -To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +To run the example with pretrained resnet weights, first `tvm` and `torchvision` must be installed for the python build. To install torchvision for cpu, run `pip install torch torchvision` and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). * **Build the example**: `cargo build diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index 45e4d6d658d5..9e3a76433ffc 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -21,10 +21,6 @@ use anyhow::{Context, Result}; use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { - // Currently disabled, as it depends on the no-longer-supported - // mxnet repo to download resnet. - - /* let out_dir = std::env::var("CARGO_MANIFEST_DIR")?; let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"); let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt"); @@ -57,7 +53,5 @@ fn main() -> Result<()> { ); println!("cargo:rustc-link-search=native={}", out_dir); - */ - Ok(()) } diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index df02dd78f57c..4e8ae01c413b 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -17,22 +17,18 @@ # under the License. import argparse -import csv import logging -from os import path as osp -import sys import shutil +from os import path as osp import numpy as np - +import torch +import torchvision import tvm -from tvm import te -from tvm import relay, runtime -from tvm.relay import testing -from tvm.contrib import graph_executor, cc from PIL import Image +from tvm import relay, runtime +from tvm.contrib import cc, graph_executor from tvm.contrib.download import download_testdata -from mxnet.gluon.model_zoo.vision import get_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -64,11 +60,16 @@ def build(target_dir): """Compiles resnet18 with TVM""" - # Download the pretrained model in MxNet's format. - block = get_model("resnet18_v1", pretrained=True) + # Download the pretrained model from Torchvision. + weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + torch_model = torchvision.models.resnet18(weights=weights).eval() + + input_shape = [1, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(torch_model, input_data) + input_infos = [("data", input_data.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) - shape_dict = {"data": (1, 3, 224, 224)} - mod, params = relay.frontend.from_mxnet(block, shape_dict) # Add softmax to do classification in last layer. func = mod["main"] func = relay.Function( @@ -93,7 +94,6 @@ def build(target_dir): def download_img_labels(): """Download an image and imagenet1k class labels for test""" - from mxnet.gluon.utils import download synset_url = "".join( [ diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 0ea8c4cf8bb5..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -31,10 +31,6 @@ use tvm_rt::graph_rt::GraphRt; use tvm_rt::*; fn main() -> anyhow::Result<()> { - // Currently disabled, as it depends on the no-longer-supported - // mxnet repo to download resnet. - - /* let dev = Device::cpu(0); println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); @@ -138,7 +134,6 @@ fn main() -> anyhow::Result<()> { "input image belongs to the class `{}` with probability {}", label, max_prob ); - */ Ok(()) }