Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,13 +1944,11 @@ def infer_value(self, input_val, params, mod=None):
self._infer_simulated = False
self._mod = mod
return self.visit(input_val).data
#return _infer_value(input_val, params, mod)

def infer_value_simulated(self, input_val, params):
self._tmp_params = params
self._infer_simulated = True
return self.visit(input_val).data
#return _infer_value_simulated(input_val, params)

def infer(self, expr):
if self._infer_simulated:
Expand Down Expand Up @@ -1978,7 +1976,10 @@ def visit_let(self, let):
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return self.infer(Call(new_fn, new_args, call.attrs))
call = Call(new_fn, new_args, call.attrs)
if new_fn == _op.get("nn.batch_norm"):
return call
return self.infer(call)

def visit_var(self, var):
return self.infer(var)
Expand Down
85 changes: 85 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,89 @@ def test_or():
verify_or(indata=[x, y], dtype=bool)


def test_batch_norm():
def verify_batch_norm(in_shape):
batchnorm = onnx.helper.make_node('BatchNormalization',
inputs=["x", "scale", "B", "mean", "var"],
outputs=['Y'])

graph = helper.make_graph([batchnorm],
"batchnorm_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("scale",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("B",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("mean",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("var",
TensorProto.FLOAT, [in_shape[1]]),
],
outputs=[helper.make_tensor_value_info("Y",
TensorProto.FLOAT, list(in_shape))])

model = helper.make_model(graph, producer_name='batchnorm_test')

for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
scale = np.random.uniform(size=in_shape[1]).astype('float32')
b = np.random.uniform(size=in_shape[1]).astype('float32')
mean = np.random.uniform(size=in_shape[1]).astype('float32')
var = np.random.uniform(size=in_shape[1]).astype('float32')
onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], 'float32')[0]
tvm_out = get_tvm_output(model, [x, scale, b, mean, var], target, ctx, in_shape, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

verify_batch_norm([1, 3, 224, 224])
verify_batch_norm([1, 3, 24, 24])
verify_batch_norm([16, 3, 24, 24])
verify_batch_norm([16, 16, 24, 24])
verify_batch_norm([16, 16, 10, 10])


def test_batch_norm_dynamic_subgraph():
def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):
batchnorm = onnx.helper.make_node('BatchNormalization',
inputs=["x", "scale", "B", "mean", "var"],
outputs=['Y'])

shape_node = helper.make_node("Shape", ['Y'], ['shape'])
reshape_node = helper.make_node("Reshape", ["in", "shape"], ["out"])
graph = helper.make_graph([batchnorm, shape_node, reshape_node],
"batchnorm_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(o_shape)),
helper.make_tensor_value_info("scale",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("B",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("mean",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("var",
TensorProto.FLOAT, [in_shape[1]]),
],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(in_shape))])

model = helper.make_model(graph, producer_name='batchnorm_test')

for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
inp = np.random.uniform(size=o_shape).astype('float32')
scale = np.random.uniform(size=in_shape[1]).astype('float32')
b = np.random.uniform(size=in_shape[1]).astype('float32')
mean = np.random.uniform(size=in_shape[1]).astype('float32')
var = np.random.uniform(size=in_shape[1]).astype('float32')
onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0]
tvm_out = get_tvm_output(model, [x, inp, scale, b, mean, var], target, ctx, in_shape, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])


def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False):
if unset_pad:
node = helper.make_node('Conv',
Expand Down Expand Up @@ -2893,6 +2976,8 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_or()
test_depth_to_space()
test_space_to_depth()
test_batch_norm()
test_batch_norm_dynamic_subgraph()
test_conv()
test_convtranspose()
test_unsqueeze_constant()
Expand Down