Skip to content

[Bug] ROCm platform results are not correct #13666

@wangzy0327

Description

@wangzy0327

I tried to run the MNIST model with TVM on the ROCm platform (ROCm 5.2). The results of the execution are incorrect.

Expected behavior

The results on the ROCm platform match the results on the CUDA or OpenCL platforms.

Actual behavior

The results on the ROCm platform do not match the results on the CUDA or OpenCL platforms.

Environment

Operating System: Ubuntu 20.04
TVM version: 7f1856d
Platform: ROCm 5.2

Any environment details, such as: Operating System, TVM version, etc

Steps to reproduce

Here is the test code.

onnx_rocm.py
from csv import writer
from pyexpat import model
import onnx
#from tvm.driver import tvmc
import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor
import tvm.testing
import numpy as np
import os


class NetworkData:
    """Description of one family of ONNX networks to benchmark.

    Attributes:
        name: family name; also used as the prefix of the results CSV file.
        net_set: model file stems belonging to the family (e.g. "mnist-7").
        prefix_str: directory of the model files, relative to the common prefix.
        suffix_str: model file extension (e.g. ".onnx").
        input_name: name of the graph input that receives the test tensor.
        input: shape of the input tensor.
        output: shape of output 0 of the graph.
    """

    def __init__(self, name: str, net_set: list, prefix_str: str, suffix_str: str,
                 input_name: str, input: tuple, output: tuple):
        # Where the model files live...
        self.name, self.net_set = name, net_set
        self.prefix_str, self.suffix_str = prefix_str, suffix_str
        # ...and how to feed / read the graph.
        self.input_name, self.input, self.output = input_name, input, output

# Registry entry for the MNIST family: models "mnist-7" and "mnist-8" under
# mnist/model/, fed a (1, 1, 28, 28) tensor named 'Input3' and producing a
# (1, 10) output.
mnist_networkData = NetworkData(
    name="mnist",
    net_set=["mnist-7", "mnist-8"],
    prefix_str="mnist/model/",
    suffix_str=".onnx",
    input_name='Input3',
    input=(1, 1, 28, 28),
    output=(1, 10),
)

# Networks to benchmark, keyed by family name.
MODEL_NAME = {mnist_networkData.name: mnist_networkData}


# Element type of the random input tensor.
dtype = "float32"
# Root directory that every model's prefix_str is appended to.
common_prefix_str = "onnx-model/vision/classification/"
# Tolerances tried in ascending order (used as both rtol and atol) when
# comparing two backends' outputs.
tol_paras = [1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]

import logging
import warnings

# Only surface ERROR-level log records and silence all Python warnings.
logging.basicConfig(level=logging.ERROR)
warnings.filterwarnings('ignore')


def build(target:str,mod:tvm.IRModule, params:dict, input_name:str, input_data:np.ndarray, input:tuple, output: tuple) -> np.ndarray:
    """Compile *mod* for *target*, run it once on *input_data*, and return output 0.

    Args:
        target: TVM target string for the device code (e.g. "rocm", "opencl");
            also used as the device name passed to ``tvm.device``.
        mod: Relay module to compile.
        params: parameters bound into the module at compile time.
        input_name: name of the graph input that receives *input_data*.
        input_data: value fed to *input_name*.
        input: input shape; not used here, kept for interface compatibility.
        output: shape of output 0; used to allocate the host buffer that the
            result is copied into.

    Returns:
        Output 0 of the graph as a NumPy array.
    """
    # Compile with the composed Target so the host side is explicitly "llvm".
    # Previously this object was built but the raw string was passed instead.
    tgt = tvm.target.Target(target=target, host="llvm")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=tgt, params=params)
    dev = tvm.device(str(target), 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input(input_name, input_data)
    module.run()
    # NOTE(review): tvm.nd.empty is called without a dtype, so the buffer uses
    # its default dtype; this assumes output 0 matches it (float32 for MNIST) -
    # confirm for other models.
    tvm_output = module.get_output(0, tvm.nd.empty(output)).numpy()
    return tvm_output

def _run_and_time(target, model_network, mod, params, input_data):
    """Build and run *mod* on *target* via ``build``; return (output, elapsed timedelta)."""
    import datetime
    start = datetime.datetime.now()
    result = build(target, mod=mod, params=params,
                   input_name=model_network.input_name,
                   input_data=input_data, input=input_data.shape,
                   output=model_network.output)
    # Elapsed time covers compilation as well as the single inference run.
    return result, datetime.datetime.now() - start


def _as_2d(arr):
    """Collapse an output with more than two dims to (batch, -1); leave others as-is."""
    if arr.ndim > 2:
        return arr.reshape(arr.shape[0], -1)
    return arr


def main(model_network : NetworkData):
    """Compare ROCm against OpenCL for each model of *model_network* and log to CSV.

    Every model in ``model_network.net_set`` is loaded from
    ``common_prefix_str + prefix_str + <model> + suffix_str``, converted with
    the Relay ONNX frontend, then built and run once on "rocm" and once on
    "opencl" with the same seeded random input. The smallest entry of
    ``tol_paras`` for which ``np.allclose`` (rtol == atol == tolerance) holds
    is recorded. If no tolerance holds, ``">{largest}"`` is recorded so the
    mismatch shows up in the results. Rows are appended to
    ``<name>_network_test_data.csv``; the header is written only when that
    file does not exist yet.
    """
    import csv

    # Fixed random seed so every run and both backends see the same input.
    np.random.seed(0)
    I_np = np.random.uniform(size = model_network.input).astype(dtype)
    print(I_np[0][0][0][:10])
    header = ['network_name','network_sub_name','input','output','tolerance','rocm_cost_time','opencl_cost_time']
    rows = []
    for child_model_network in model_network.net_set:
        print("--------"+child_model_network+"----start-------------")
        onnx_model = onnx.load(common_prefix_str +
                               model_network.prefix_str +
                               child_model_network +
                               model_network.suffix_str)
        shape_dict = {model_network.input_name: I_np.shape}
        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

        rocm_output, rocm_duringtime = _run_and_time("rocm", model_network, mod, params, I_np)
        print("%15s network rocm cost time is %s s"%(child_model_network,rocm_duringtime))
        opencl_output, opencl_duringtime = _run_and_time("opencl", model_network, mod, params, I_np)
        print("%15s network opencl cost time is %s s"%(child_model_network,opencl_duringtime))

        # Reshape each output independently; the original keyed both on the
        # ROCm output's rank only.
        rocm_output = _as_2d(rocm_output)
        opencl_output = _as_2d(opencl_output)
        print(rocm_output[0][:10])
        print(opencl_output[0][:10])

        row = {'network_name': model_network.name,'network_sub_name':child_model_network, 'input':model_network.input, 'output':model_network.output, 'rocm_cost_time':rocm_duringtime,'opencl_cost_time':opencl_duringtime}
        for para in tol_paras:
            if np.allclose(rocm_output,opencl_output,rtol=para, atol=para):
                row["tolerance"] = para
                print("%15s opencl network tolerance is %g"%(child_model_network,para))
                break
        else:
            # No tolerance matched. This is exactly the failure under
            # investigation, so record it rather than silently dropping the row.
            row["tolerance"] = ">%g" % tol_paras[-1]
            print("%15s rocm and opencl outputs differ beyond tolerance %g"
                  % (child_model_network, tol_paras[-1]))
        rows.append(row)

    model_network_file = model_network.name+'_network_test_data.csv'
    # Append to an existing results file; write the header only for a new one.
    file_exist = os.path.exists(model_network_file)
    access_mode = 'a+' if file_exist else 'w+'
    with open(model_network_file,access_mode,encoding='utf-8',newline='') as f:
        csv_writer = csv.DictWriter(f,header)
        if not file_exist:
            csv_writer.writeheader()
        csv_writer.writerows(rows)

# Drive the comparison for every registered network family.
for net_key, net_info in MODEL_NAME.items():
    banner = "-----------" + net_key + "----start----------------"
    print(banner)
    main(net_info)

The program output is as follows.

image

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions