Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
e8994e4
initial dyn unsqueeze example
Sep 16, 2021
b044395
simplify, properly unpack scalar
Sep 16, 2021
3c0b669
basic tests
Sep 16, 2021
2c3facb
squish bugs -- assign proper types
AndrewZhaoLuo Sep 16, 2021
045ab70
working topi
AndrewZhaoLuo Sep 16, 2021
20f55dd
fix things
AndrewZhaoLuo Sep 16, 2021
e47d78f
temp work
AndrewZhaoLuo Sep 16, 2021
cd15f1a
fix casting to int64
AndrewZhaoLuo Sep 17, 2021
e70f24d
shape encoding method for axis
AndrewZhaoLuo Sep 17, 2021
ba609a8
working shape encoding metric
AndrewZhaoLuo Sep 17, 2021
89fd696
add comment
AndrewZhaoLuo Sep 17, 2021
e21dcc1
move to non-rank encoded axis
AndrewZhaoLuo Sep 17, 2021
7f2a602
failing regime
AndrewZhaoLuo Sep 17, 2021
5a3c47b
fix
AndrewZhaoLuo Sep 17, 2021
c83a1a6
it works!
AndrewZhaoLuo Sep 17, 2021
d5749f3
add test
AndrewZhaoLuo Sep 17, 2021
b46f70b
add comment on shape func
AndrewZhaoLuo Sep 17, 2021
f681dd3
remove unused topi
AndrewZhaoLuo Sep 17, 2021
c73020d
undo some file changes
AndrewZhaoLuo Sep 17, 2021
cf3ace8
more cleanup
AndrewZhaoLuo Sep 17, 2021
2d0d48d
newline
AndrewZhaoLuo Sep 17, 2021
7669645
clean up
AndrewZhaoLuo Sep 17, 2021
14629ec
clean up
AndrewZhaoLuo Sep 17, 2021
2aa5607
enable multiple axis tests
AndrewZhaoLuo Sep 17, 2021
b2cf8ce
move tests to dynamic op
AndrewZhaoLuo Sep 17, 2021
de7320d
Update docs
AndrewZhaoLuo Sep 17, 2021
75d29c5
add converter
AndrewZhaoLuo Sep 17, 2021
385246a
initial dyn unsqueeze example
Sep 16, 2021
533d198
simplify, properly unpack scalar
Sep 16, 2021
5b75ee8
basic tests
Sep 16, 2021
63e9ec9
squish bugs -- assign proper types
AndrewZhaoLuo Sep 16, 2021
72322d0
working topi
AndrewZhaoLuo Sep 16, 2021
9b710ca
fix things
AndrewZhaoLuo Sep 16, 2021
cbf49a5
temp work
AndrewZhaoLuo Sep 16, 2021
85c3d82
fix casting to int64
AndrewZhaoLuo Sep 17, 2021
48e9129
shape encoding method for axis
AndrewZhaoLuo Sep 17, 2021
0013dfd
working shape encoding metric
AndrewZhaoLuo Sep 17, 2021
a187e57
add comment
AndrewZhaoLuo Sep 17, 2021
bac9e50
move to non-rank encoded axis
AndrewZhaoLuo Sep 17, 2021
7decc4d
failing regime
AndrewZhaoLuo Sep 17, 2021
d26ff7a
fix
AndrewZhaoLuo Sep 17, 2021
fe3ce5c
it works!
AndrewZhaoLuo Sep 17, 2021
17c9f65
add test
AndrewZhaoLuo Sep 17, 2021
cc4f30c
add comment on shape func
AndrewZhaoLuo Sep 17, 2021
55b7875
remove unused topi
AndrewZhaoLuo Sep 17, 2021
8d72e8b
undo some file changes
AndrewZhaoLuo Sep 17, 2021
118934f
more cleanup
AndrewZhaoLuo Sep 17, 2021
0d11d7c
newline
AndrewZhaoLuo Sep 17, 2021
4451059
clean up
AndrewZhaoLuo Sep 17, 2021
4310732
clean up
AndrewZhaoLuo Sep 17, 2021
227d912
enable multiple axis tests
AndrewZhaoLuo Sep 17, 2021
720fb8e
move tests to dynamic op
AndrewZhaoLuo Sep 17, 2021
25f8c23
Update docs
AndrewZhaoLuo Sep 17, 2021
0b9ec33
add converter
AndrewZhaoLuo Sep 17, 2021
feb635c
Merge branch 'aluo/onnx/unsqueeze-alt' of github.com:AndrewZhaoLuo/tv…
Sep 17, 2021
12709e7
working tests
Sep 19, 2021
ff39840
add test, remove unneeded file
AndrewZhaoLuo Sep 19, 2021
f7b045f
fix things
Sep 19, 2021
39dba34
more lint
Sep 19, 2021
6caadf3
more lint
Sep 19, 2021
a9f5117
pick things
AndrewZhaoLuo Sep 19, 2021
c0ca64b
Merge branch 'aluo/onnx/unsqueeze-alt' of github.com:AndrewZhaoLuo/tv…
AndrewZhaoLuo Sep 19, 2021
39f729f
disable opencl tests
AndrewZhaoLuo Sep 19, 2021
4f6725e
unsqueeze tests
Sep 19, 2021
2bf1fb0
clean up
AndrewZhaoLuo Sep 20, 2021
3f5e26d
Merge branch 'main' into aluo/onnx/unsqueeze-alt
AndrewZhaoLuo Sep 20, 2021
0b58530
dyn stuff
AndrewZhaoLuo Sep 21, 2021
7e2ebac
add num_newaxis
AndrewZhaoLuo Sep 21, 2021
dc17ada
Merge branch 'main' into aluo/onnx/unsqueeze-alt
AndrewZhaoLuo Sep 21, 2021
e354563
add support
Sep 19, 2021
13165ee
black
AndrewZhaoLuo Sep 20, 2021
1d0ca0f
doc string
AndrewZhaoLuo Sep 21, 2021
8ce9b3b
Merge branch 'main' into aluo/onnx/support-nllloss-tests
Sep 22, 2021
8ef89ea
remove bad merge
Sep 22, 2021
a718b2a
fix default axis behavior
Sep 23, 2021
ec71db0
Merge branch 'main' into aluo/onnx/support-nllloss-tests
AndrewZhaoLuo Sep 28, 2021
4447044
rebase
AndrewZhaoLuo Sep 28, 2021
81a40a8
fix squeeze
Sep 28, 2021
6f5aeb1
Merge branch 'main' into aluo/onnx/support-nllloss-tests
AndrewZhaoLuo Sep 29, 2021
629244f
jostle ci
AndrewZhaoLuo Sep 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ def _impl_v13(cls, inputs, attr, params):
axis = relay.TupleGetItem(axes, i)
# Unpack scalar
axis = relay.reshape(axis, [])
axis = relay.If(
axis = relay.where(
axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
)
result = _op.expand_dims(result, axis)
Expand All @@ -1509,12 +1509,18 @@ class Squeeze(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axes", None)
return _op.squeeze(*inputs, axis)
return _op.squeeze(inputs[0], axis)

@classmethod
def _impl_v13(cls, inputs, attr, params):
axis = inputs[1]
dtype = infer_type(axis).checked_type.dtype

if isinstance(axis, _expr.Constant):
constant_axes = list(inputs[1].data.numpy())
constant_axes = list(map(int, constant_axes))
return _op.squeeze(inputs[0], constant_axes)

rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype)
axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)
return _op.squeeze(inputs[0], fold_constant(axis))
Expand Down Expand Up @@ -1640,7 +1646,7 @@ def normalize_gather_indices(data, indices, axis):
"""Make sure gather indicies aren't negative"""
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis, dtype="int64"))
cond = fold_constant(indices < _op.const(0, ind_dtype))
if isinstance(cond, _expr.Constant):
val = cond.data.numpy()
Expand Down
19 changes: 1 addition & 18 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4939,25 +4939,8 @@ def verify_eyelike(indata):
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
"test_mvn",
# When unsqueeze is fully supported, remaining nllloss tests should work:
"test_nllloss_NC_expanded",
"test_nllloss_NCd1_expanded",
"test_nllloss_NCd1_ii_expanded",
"test_nllloss_NCd1_mean_weight_negative_ii_expanded",
"test_nllloss_NCd1_weight_expanded",
"test_nllloss_NCd1_weight_ii_expanded",
"test_nllloss_NCd1d2_expanded",
"test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded",
"test_nllloss_NCd1d2_reduction_mean_expanded",
"test_nllloss_NCd1d2_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_mean_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded",
# This test fails llvm with a lowering error:
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
"test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
"test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
Expand Down