From 3fe5894181e6c845b590465053912d19171c8916 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Mon, 24 Feb 2020 21:57:06 -0800 Subject: [PATCH 1/8] Add a tutorial for PyTorch --- tutorials/frontend/from_pytorch.py | 166 +++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 tutorials/frontend/from_pytorch.py diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py new file mode 100644 index 000000000000..675cf06d8a77 --- /dev/null +++ b/tutorials/frontend/from_pytorch.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile PyTorch Models +===================== +**Author**: `Alex Wong `_ + +This article is an introductory tutorial to deploy PyTorch models with Relay. + +For us to begin with, PyTorch should be installed. +TorchVision is also required since we will be using it as our model zoo. + +A quick solution is to install via pip + +.. code-block:: bash + + pip install torch==1.4.0 + pip install torchvision==0.5.0 + +or please refer to official site +https://pytorch.org/get-started/locally/ +""" + +# tvm, relay +import tvm +from tvm import relay + +# numpy, packaging +import numpy as np +from packaging import version +from tvm.contrib.download import download_testdata + +# PyTorch imports +import torch +import torchvision +if (version.parse(torch.__version__) > version.parse("1.4.0")) \ + or (version.parse(torchvision.__version__) > version.parse("0.5.0")): + assert "Please ensure version of PyTorch is supported by TVM" + +###################################################################### +# Load a pretrained PyTorch model +# ------------ +model_name = 'resnet18' +model = getattr(torchvision.models, model_name)(pretrained=True) +model = model.float().eval() + +# We grab the TorchScripted model via tracing +input_shape = [1, 3, 224, 224] +input_data = torch.randn(input_shape).float() +scripted_model = torch.jit.trace(model, input_data).float().eval() + +###################################################################### +# Load a test image +# ------------------ +# Classic cat example! +from PIL import Image +img_url = 'https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/master/data/cat_224.jpg' +img_path = download_testdata(img_url, 'cat_224.png', module='data') +img = Image.open(img_path) + +# Preprocess the image and convert to tensor +from torchvision import transforms +my_preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) +]) +img = my_preprocess(img) +img = np.expand_dims(img, 0) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# Convert PyTorch graph to Relay graph. +shape_dict = {'input0': img.shape} +mod, params = relay.frontend.from_pytorch(scripted_model, + shape_dict) + +###################################################################### +# Relay Build +# ----------- +# Compile the graph to llvm target with given input specification. +target = 'llvm' +target_host = 'llvm' +layout = None +ctx = tvm.cpu(0) +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, + target=target, + target_host=target_host, + params=params) + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now we can try deploying the compiled model on target. + +from tvm.contrib import graph_runtime +dtype = 'float32' +m = graph_runtime.create(graph, lib, ctx) +# Set inputs +m.set_input('input0', tvm.nd.array(img.astype(dtype))) +m.set_input(**params) +# Execute +m.run() +# Get outputs +tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) + +##################################################################### +# Look up synset name +# ------------------- +# Look up prediction top 1 index in 1000 class synset. +synset_url = ''.join(['https://raw.githubusercontent.com/Cadene/', + 'pretrained-models.pytorch/master/data/', + 'imagenet_synsets.txt']) +synset_name = 'imagenet_synsets.txt' +synset_path = download_testdata(synset_url, synset_name, module='data') +with open(synset_path) as f: + synsets = f.readlines() + +synsets = [x.strip() for x in synsets] +splits = [line.split(' ') for line in synsets] +key_to_classname = {spl[0]:' '.join(spl[1:]) for spl in splits} + +class_url = ''.join(['https://raw.githubusercontent.com/Cadene/', + 'pretrained-models.pytorch/master/data/', + 'imagenet_classes.txt']) +class_name = 'imagenet_classes.txt' +class_path = download_testdata(class_url, class_name, module='data') +with open(class_path) as f: + class_id_to_key = f.readlines() + +class_id_to_key = [x.strip() for x in class_id_to_key] + +# Get top-1 result for TVM +top1_tvm = np.argmax(tvm_output.asnumpy()[0]) +tvm_class_key = class_id_to_key[top1_tvm] + +# Convert input to PyTorch variable and get PyTorch result for comparison +torch_img = torch.from_numpy(img) +from torch.autograd import Variable +torch_img = Variable(torch_img).float() +output = model(torch_img) + +# Get top-1 result for PyTorch +top1_torch = np.argmax(output.detach().numpy()) +torch_class_key = class_id_to_key[top1_torch] + +print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) +print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) \ No newline at end of file From dac502154e5494db7392ed7a23085d255193bfb4 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 11:20:57 -0800 Subject: [PATCH 2/8] Fix sphinx formatting, add version support --- tutorials/frontend/from_pytorch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 675cf06d8a77..0d11e3e04459 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -16,7 +16,7 @@ # under the License. """ Compile PyTorch Models -===================== +====================== **Author**: `Alex Wong `_ This article is an introductory tutorial to deploy PyTorch models with Relay. @@ -33,6 +33,12 @@ or please refer to official site https://pytorch.org/get-started/locally/ + +PyTorch versions should be backwards compatible but should be used +with the proper TorchVision version. + +Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may +be unstable. """ # tvm, relay @@ -53,7 +59,7 @@ ###################################################################### # Load a pretrained PyTorch model -# ------------ +# ------------------------------- model_name = 'resnet18' model = getattr(torchvision.models, model_name)(pretrained=True) model = model.float().eval() @@ -65,7 +71,7 @@ ###################################################################### # Load a test image -# ------------------ +# ----------------- # Classic cat example! from PIL import Image img_url = 'https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/master/data/cat_224.jpg' From 72769430f2a56b605c2ab43cb925bc0ffb90c8d2 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 11:25:20 -0800 Subject: [PATCH 3/8] Remove space --- tutorials/frontend/from_pytorch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 0d11e3e04459..01cba505846c 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -116,7 +116,6 @@ # Execute the portable graph on TVM # --------------------------------- # Now we can try deploying the compiled model on target. - from tvm.contrib import graph_runtime dtype = 'float32' m = graph_runtime.create(graph, lib, ctx) From 85368d36f2eda0767d4efb04649be758abf8dada Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 15:31:49 -0800 Subject: [PATCH 4/8] Remove version check --- tutorials/frontend/from_pytorch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 01cba505846c..8f3e708eb79f 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -53,9 +53,6 @@ # PyTorch imports import torch import torchvision -if (version.parse(torch.__version__) > version.parse("1.4.0")) \ - or (version.parse(torchvision.__version__) > version.parse("0.5.0")): - assert "Please ensure version of PyTorch is supported by TVM" ###################################################################### # Load a pretrained PyTorch model From 2ba82c4ccf0892d950777f5b77bad3f963c2a64b Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 16:16:35 -0800 Subject: [PATCH 5/8] Some refactoring --- tutorials/frontend/from_pytorch.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 8f3e708eb79f..d58c3b7a76fe 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -59,12 +59,12 @@ # ------------------------------- model_name = 'resnet18' model = getattr(torchvision.models, model_name)(pretrained=True) -model = model.float().eval() +model = model.eval() # We grab the TorchScripted model via tracing input_shape = [1, 3, 224, 224] -input_data = torch.randn(input_shape).float() -scripted_model = torch.jit.trace(model, input_data).float().eval() +input_data = torch.randn(input_shape) +scripted_model = torch.jit.trace(model, input_data).eval() ###################################################################### # Load a test image @@ -101,7 +101,6 @@ # Compile the graph to llvm target with given input specification. target = 'llvm' target_host = 'llvm' -layout = None ctx = tvm.cpu(0) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, @@ -156,8 +155,6 @@ # Convert input to PyTorch variable and get PyTorch result for comparison torch_img = torch.from_numpy(img) -from torch.autograd import Variable -torch_img = Variable(torch_img).float() output = model(torch_img) # Get top-1 result for PyTorch From ef416260ac355a07bd2c80aee1e11ab153ae0dd2 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 16:23:55 -0800 Subject: [PATCH 6/8] Use no grad --- tutorials/frontend/from_pytorch.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index d58c3b7a76fe..dc92e835999b 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -154,12 +154,13 @@ tvm_class_key = class_id_to_key[top1_tvm] # Convert input to PyTorch variable and get PyTorch result for comparison -torch_img = torch.from_numpy(img) -output = model(torch_img) +with torch.no_grad(): + torch_img = torch.from_numpy(img) + output = model(torch_img) -# Get top-1 result for PyTorch -top1_torch = np.argmax(output.detach().numpy()) -torch_class_key = class_id_to_key[top1_torch] + # Get top-1 result for PyTorch + top1_torch = np.argmax(output.numpy()) + torch_class_key = class_id_to_key[top1_torch] print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) \ No newline at end of file From 3097ddbd146f9182bb143cc45233f9590469ff73 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 16:27:47 -0800 Subject: [PATCH 7/8] Rename input --- tutorials/frontend/from_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index dc92e835999b..22c649712af0 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -91,7 +91,7 @@ # Import the graph to Relay # ------------------------- # Convert PyTorch graph to Relay graph. -shape_dict = {'input0': img.shape} +shape_dict = {'img': img.shape} mod, params = relay.frontend.from_pytorch(scripted_model, shape_dict) @@ -116,7 +116,7 @@ dtype = 'float32' m = graph_runtime.create(graph, lib, ctx) # Set inputs -m.set_input('input0', tvm.nd.array(img.astype(dtype))) +m.set_input('img', tvm.nd.array(img.astype(dtype))) m.set_input(**params) # Execute m.run() From 38464d7f9fd478266a7b35d2931ca002b6032cdd Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 16:44:02 -0800 Subject: [PATCH 8/8] Update cat img source --- tutorials/frontend/from_pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 22c649712af0..c280c259c1fe 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -71,9 +71,9 @@ # ----------------- # Classic cat example! from PIL import Image -img_url = 'https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/master/data/cat_224.jpg' -img_path = download_testdata(img_url, 'cat_224.png', module='data') -img = Image.open(img_path) +img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +img_path = download_testdata(img_url, 'cat.png', module='data') +img = Image.open(img_path).resize((224, 224)) # Preprocess the image and convert to tensor from torchvision import transforms