diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 5471f67ea106..a5ae959901ba 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -22,6 +22,7 @@
 
 import numpy as np
 import tvm
+from tvm import relay
 from tvm.ir import IRModule
 from tvm.topi.utils import get_const_tuple
 
@@ -32,23 +33,23 @@
 from .. import loops as _loops
 from .. import op as _op
 from .. import qnn as _qnn
+from .. import random as _random
 from .. import ty as _ty
 from .. import vision as _vision
-from .. import random as _random
 from .common import (
     AttrCvt,
     Renamer,
     fold_constant,
     get_name,
     get_relay_op,
+    gru_cell,
     infer_channels,
     infer_shape,
     infer_type,
     infer_value,
+    lstm_cell,
     new_var,
     unbind,
-    gru_cell,
-    lstm_cell,
 )
 
 __all__ = ["from_onnx"]
@@ -3481,6 +3482,62 @@ def _impl_v1(cls, inputs, attr, params):
         return vals
 
 
+class NegativeLogLikelihoodLoss(OnnxOpConverter):
+    """Operator converter for NegativeLogLikelihoodLoss"""
+
+    VALID_REDUCTIONS = {"mean", "sum", "none"}
+
+    @classmethod
+    def _impl_v13(cls, inputs, attr, params):
+        ignore_index = attr.get("ignore_index", None)
+        reduction = attr.get("reduction", b"mean").decode("utf-8")
+
+        if reduction not in cls.VALID_REDUCTIONS:
+            raise ValueError(
+                f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
+            )
+
+        input_tensor, target_tensor = inputs[0], inputs[1]
+        if len(inputs) == 3:
+            weight_tensor = inputs[2]
+        else:
+            channels = infer_shape(input_tensor)[1]
+            weight_tensor = relay.ones(
+                [channels],
+                dtype=infer_type(input_tensor).checked_type.dtype,
+            )
+
+        loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1))
+        loss = relay.squeeze(loss, axis=[1])
+
+        expanded_target_tensor = relay.expand_dims(target_tensor, 0)
+        expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor)
+        flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor)
+        select_weights = relay.reshape_like(flattened_weights, loss)
+        loss *= select_weights
+
+        if ignore_index is not None:
+            # Zero out ("ignore") entries whose target equals ignore_index.
+            mask_tensor = relay.equal(
+                target_tensor,
+                relay.const(ignore_index, dtype=infer_type(target_tensor).checked_type.dtype),
+            )
+            mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8")
+            loss *= relay.cast_like(mask_tensor, loss)
+
+            # The ONNX spec is not explicit about this, but ignored values must also
+            # be excluded from the weight total used for the mean reduction.
+            select_weights *= relay.cast_like(mask_tensor, select_weights)
+
+        weight_total = relay.sum(select_weights)
+
+        if reduction == "mean":
+            return relay.sum(loss) / weight_total
+        if reduction == "sum":
+            return relay.sum(loss)
+        # Case reduction == "none"
+        return loss
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -3663,6 +3720,8 @@ def _get_convert_map(opset):
         "ConvInteger": ConvInteger.get_converter(opset),
         # Random number generation.
"RandomUniform": RandomUniform.get_converter(opset), + # Loss functions + "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e0eb1f75217..7693b636e373 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4752,41 +4752,24 @@ def verify_eyelike(indata): "test_momentum_multiple", "test_mvn", "test_nesterov_momentum", - "test_nllloss_NC", + # When unsqueeze is fully supported, remaining nllloss tests should work: "test_nllloss_NC_expanded", - "test_nllloss_NCd1", "test_nllloss_NCd1_expanded", - "test_nllloss_NCd1_ii", "test_nllloss_NCd1_ii_expanded", - "test_nllloss_NCd1_mean_weight_negative_ii", "test_nllloss_NCd1_mean_weight_negative_ii_expanded", - "test_nllloss_NCd1_weight", "test_nllloss_NCd1_weight_expanded", - "test_nllloss_NCd1_weight_ii", "test_nllloss_NCd1_weight_ii_expanded", - "test_nllloss_NCd1d2", "test_nllloss_NCd1d2_expanded", - "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", - "test_nllloss_NCd1d2_reduction_mean", "test_nllloss_NCd1d2_reduction_mean_expanded", - "test_nllloss_NCd1d2_reduction_sum", "test_nllloss_NCd1d2_reduction_sum_expanded", - "test_nllloss_NCd1d2_with_weight", "test_nllloss_NCd1d2_with_weight_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_mean", "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_sum", "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", - "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", - "test_nllloss_NCd1d2d3_sum_weight_high_ii", "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", - "test_nllloss_NCd1d2d3d4d5_mean_weight", "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", - "test_nllloss_NCd1d2d3d4d5_none_no_weight", "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", "test_pow_types_float", "test_pow_types_float32_int32",