diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ba2c6b4b54e7..02b899a251e7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1496,7 +1496,7 @@ def _impl_v13(cls, inputs, attr, params):
             axis = relay.TupleGetItem(axes, i)
             # Unpack scalar
             axis = relay.reshape(axis, [])
-            axis = relay.If(
+            axis = relay.where(
                 axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
             )
             result = _op.expand_dims(result, axis)
@@ -1509,12 +1509,18 @@ class Squeeze(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         axis = attr.get("axes", None)
-        return _op.squeeze(*inputs, axis)
+        return _op.squeeze(inputs[0], axis)
 
     @classmethod
     def _impl_v13(cls, inputs, attr, params):
         axis = inputs[1]
         dtype = infer_type(axis).checked_type.dtype
+
+        if isinstance(axis, _expr.Constant):
+            constant_axes = list(inputs[1].data.numpy())
+            constant_axes = list(map(int, constant_axes))
+            return _op.squeeze(inputs[0], constant_axes)
+
         rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype)
         axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)
         return _op.squeeze(inputs[0], fold_constant(axis))
@@ -1640,7 +1646,7 @@ def normalize_gather_indices(data, indices, axis):
     """Make sure gather indicies aren't negative"""
     ind_dtype = infer_type(indices).checked_type.dtype
     # Normalize the indices to a positive range
-    s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
+    s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis, dtype="int64"))
     cond = fold_constant(indices < _op.const(0, ind_dtype))
     if isinstance(cond, _expr.Constant):
         val = cond.data.numpy()
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1cf6ffff762c..709c7c06d5b1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4939,25 +4939,8 @@ def verify_eyelike(indata):
     "test_maxpool_with_argmax_2d_precomputed_strides",
     "test_maxunpool_export_with_output_shape",
     "test_mvn",
-    # When unsqueeze is fully supported, remaining nllloss tests should work:
-    "test_nllloss_NC_expanded",
-    "test_nllloss_NCd1_expanded",
-    "test_nllloss_NCd1_ii_expanded",
-    "test_nllloss_NCd1_mean_weight_negative_ii_expanded",
-    "test_nllloss_NCd1_weight_expanded",
-    "test_nllloss_NCd1_weight_ii_expanded",
-    "test_nllloss_NCd1d2_expanded",
-    "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded",
-    "test_nllloss_NCd1d2_reduction_mean_expanded",
-    "test_nllloss_NCd1d2_reduction_sum_expanded",
-    "test_nllloss_NCd1d2_with_weight_expanded",
-    "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded",
-    "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded",
-    "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded",
+    # This test fails llvm with a lowering error:
     "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
-    "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
-    "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
-    "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
     "test_qlinearmatmul_2D",
     "test_qlinearmatmul_3D",
     "test_range_float_type_positive_delta_expanded",
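Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of one way to exercise the new constant-axes fast path in `Squeeze._impl_v13`. The model built with `onnx.helper` is purely illustrative, and `freeze_params=True` is assumed so the `axes` initializer reaches the converter as a `relay.Constant`, letting the importer emit `_op.squeeze` with static integer axes instead of the dynamic `shape_of`/`where` lowering.

```python
# Illustrative only: build an opset-13 Squeeze whose "axes" input is an
# initializer, then import it so the converter sees a relay.Constant.
from onnx import helper, TensorProto

from tvm import relay

node = helper.make_node("Squeeze", inputs=["x", "axes"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "squeeze_const_axes",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
    # Squeeze axis 1, supplied as a constant initializer rather than a dynamic input.
    initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

# freeze_params=True turns the initializer into a relay.Constant, so the
# importer can take the static-axes branch added in this patch.
mod, params = relay.frontend.from_onnx(model, shape={"x": (3, 1, 4)}, freeze_params=True)
print(mod)  # the resulting squeeze should carry axis=[1] statically
```

If `axes` arrives as a genuinely dynamic input instead of a constant, the converter still falls back to the `shape_of`/`where`/`fold_constant` path that was already present.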