-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Closed
Description
Hi,
I am getting lower accuracy with TVM when targeting both cuda and cpu as compared to running with a pytorch model. This is a variant of a Resnet-18 model. Find the link to download the model below.
You will have to download the imagenet validation dataset and extract/sort it into a folder. Replace imagenet/data with the name of that folder.
You can download the model from here, untar, and pass the path to the script below.
Environment:
TVM installed from source
Pytorch 1.8.1
Python 3.7
OS: Ubuntu 18.04
Cuda 11.1
GPUs: 8 NVidia GA100
Code:
import torch
import metrics
from torch.utils.data import DataLoader
#from fuzzer.datasets.ImageNetDataset import ImageNetDataset
import sys
#from training_utils import eval_model_vision, eval_model_tvm
from torchvision import datasets
from torchvision.transforms import transforms
import numpy as np
def eval_model_tvm(model, dataset, device):
    """Compile `model` with TVM and evaluate it on `dataset`.

    Args:
        model: a PyTorch module (traced on CPU; TVM executes the compiled graph).
        dataset: a torch-style dataset yielding (images, targets) pairs.
        device: "cpu" (or anything containing "cpu") selects the LLVM target;
            everything else selects the CUDA target.

    Returns:
        dict with mean top-1 ('acc1') and top-5 ('acc5') accuracy across batches.
    """
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor
    import logging

    logging.getLogger('compile_engine').setLevel(logging.ERROR)

    validation_dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

    if "cpu" in device.lower():
        target = tvm.target.Target("llvm", host="llvm")
    else:
        target = tvm.target.cuda()
    print("target", target)
    dev = tvm.device(str(target))

    # Tracing happens on CPU regardless of the TVM target; TVM moves the
    # compiled graph to `dev` itself.
    model = model.to("cpu")
    model.eval()

    input_name = "input0"
    runtime = None          # graph_executor.GraphModule, built once from the first batch
    built_shape = None      # the fixed input shape the graph was compiled for
    acc1s = []
    acc5s = []
    for i, (images, targets) in enumerate(validation_dataloader):
        print(i)
        # Feed TVM a float32 numpy array, not a torch.Tensor (was: tvm.nd.array(images)).
        batch = images.numpy().astype("float32")
        if runtime is None:
            with torch.no_grad():
                scripted_model = torch.jit.trace(model, images).eval()
            print("scripted")
            shape_list = [(input_name, batch.shape)]
            mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
            with tvm.transform.PassContext(opt_level=3):
                lib = relay.build(mod, target=target, params=params)
            # Build the executor once; the original rebuilt it every iteration.
            runtime = graph_executor.GraphModule(lib["default"](dev))
            built_shape = batch.shape
        if batch.shape != built_shape:
            # The graph is compiled for a fixed shape; a ragged final batch
            # would otherwise produce garbage or crash. Skip it.
            print("skipping batch {0}: shape {1} != {2}".format(i, batch.shape, built_shape))
            continue
        runtime.set_input(input_name, tvm.nd.array(batch))
        runtime.run()
        output = runtime.get_output(0).numpy()
        acc1, acc5 = metrics.accuracy(torch.tensor(output), targets, topk=(1, 5))
        print("Batch {0}, acc1: {1} acc5: {2}".format(i, acc1, acc5))
        acc1s.append(acc1)
        acc5s.append(acc5)
    return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc5s)}
def eval_model_vision(model, dataset, device, criterion, compute_metrics_fn):
    """Evaluate a PyTorch model on a validation dataset.

    Args:
        model: module to evaluate; wrapped in DataParallel if it isn't already.
        dataset: a Dataset (wrapped in a DataLoader here) or a ready DataLoader.
        device: torch device string, e.g. "cuda" or "cpu".
        criterion: loss function called as criterion(output, target); the loss
            is only printed, never aggregated.
        compute_metrics_fn: callable (output, target, topk=(1, 5)) -> (acc1, acc5)
            returning scalar tensors.

    Returns:
        dict with mean top-1 ('acc1') and top-5 ('acc5') accuracy across batches.
    """
    print("Running validation...")
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    if not isinstance(dataset, DataLoader):
        # shuffle=False: validation must be deterministic. The original used
        # shuffle=True, which changes batch composition (and hence the mean of
        # per-batch accuracies) between runs.
        validation_dataloader = DataLoader(dataset, batch_size=100, shuffle=False)
    else:
        validation_dataloader = dataset

    acc1s = []
    acc5s = []  # was misnamed acc2s; it holds top-5 accuracies
    model.to(device)
    model.eval()
    print("Val size ", len(validation_dataloader))
    with torch.no_grad():
        for i, (images, target) in enumerate(validation_dataloader):
            images = images.to(device)
            target = target.to(device)
            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = compute_metrics_fn(output, target, topk=(1, 5))
            acc1s.append(acc1.item())
            acc5s.append(acc5.item())
            if i % 10 == 0:
                print(i, loss)
    return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc5s)}
def load_dataset():
    """Load the ImageNet validation set and split it into train/val/test subsets.

    Keeps the first 10000 validation images and partitions them 70%/10%/10%
    into "train", "val", and "test" Subsets.

    NOTE(review): the fractions sum to 0.9, so the final 10% of the 10000
    images is left unused — confirm this is intentional.

    Returns:
        dict mapping split name -> torch.utils.data.Subset.
    """
    tr = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet normalization constants.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageNet("./imagenet/data", split="val", transform=tr)
    dataset = torch.utils.data.Subset(dataset, range(10000))

    split_datasets = dict()
    start = 0
    for name, frac in [("train", 0.70), ("val", 0.10), ("test", 0.10)]:
        count = int(frac * len(dataset))
        split_datasets[name] = torch.utils.data.Subset(dataset, range(start, start + count))
        # Advance by count. The original used `indices[-1] + 1`, which raises
        # IndexError whenever a split fraction rounds down to zero elements.
        start += count
    return split_datasets
# --- script entry point ---
if len(sys.argv) < 2:
    sys.exit("usage: python script.py <model_dir>   (directory containing model.pt)")

# map_location lets a checkpoint saved from CUDA load on a CPU-only host;
# the model is moved to the evaluation device below.
model = torch.load(sys.argv[1] + '/model.pt', map_location='cpu')
dataset = load_dataset()
DEVICE = "cuda"

# Baseline: evaluate the PyTorch model directly.
res1 = eval_model_vision(model, dataset["val"], device=DEVICE,
                         criterion=torch.nn.CrossEntropyLoss(),
                         compute_metrics_fn=metrics.accuracy)
print(res1)

# Then evaluate the same model compiled through TVM for comparison.
res2 = eval_model_tvm(model, dataset["val"], DEVICE)
print(res2)

Output:
Running validation...
Val size 10
0 tensor(12.4782, device='cuda:0')
{'acc1': 40.2, 'acc5': 63.9}
target cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32
0
scripted
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
{'acc1': 7.5, 'acc5': 21.7}
Please let me know if you need more info. Thanks!
Metadata
Assignees
Labels
No labels