From a2f641c7f364f84ffc3bd2986724f12aeb564080 Mon Sep 17 00:00:00 2001
From: zilinzhu
Date: Mon, 17 May 2021 22:40:42 +0800
Subject: [PATCH 1/7] add nll_loss

---
 include/tvm/relay/attrs/nn.h         | 13 +++++++
 include/tvm/topi/nn.h                | 49 +++++++++++++++++++++++++
 python/tvm/relay/frontend/pytorch.py | 15 ++++++++
 python/tvm/relay/op/nn/_nn.py        | 11 ++++++
 python/tvm/relay/op/nn/nn.py         | 28 ++++++++++++++
 python/tvm/relay/op/op_attrs.py      |  5 +++
 python/tvm/topi/nn/__init__.py       |  1 +
 python/tvm/topi/nn/loss.py           | 53 +++++++++++++++++++++++++++
 src/relay/op/nn/nn.cc                | 55 ++++++++++++++++++++++++++++
 src/topi/nn.cc                       |  4 ++
 10 files changed, 234 insertions(+)
 create mode 100644 python/tvm/topi/nn/loss.py

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 15f6b03f0c06..1f9fef35409d 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1426,6 +1426,19 @@ struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
   }
 };  // struct BatchToSpaceNDAttrs
 
+/*! \brief Attributes used in NLLLoss operator */
+struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
+  std::string reduction;
+  int ignore_index;
+
+  TVM_DECLARE_ATTRS(NLLLossAttrs, "relay.attrs.NLLLossAttrs") {
+    TVM_ATTR_FIELD(reduction).set_default("mean").describe(
+        "The reduction method to apply to the output. Can be "
+        "'none', 'mean' or 'sum'.");
+    TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore.");
+  }
+};  // struct NLLLossAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index 29c3156ab5d6..f649e1ee49a6 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -642,6 +643,54 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
   out = strided_slice(out, begin_idx, end_idx, strides);
   return out;
 }
+
+/*!
+ * \brief Negative log likelihood loss.
+ *
+ * \param input The input tensor.
+ * \param target The target tensor.
+ * \param weight A manual rescaling weight given to each class.
+ * \param reduction The reduction method to apply to the output.
+ * \param ignore_index The target value to ignore.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * + * \return A Tensor whose op member is the batch_to_space_nd operation + */ +inline Tensor nll_loss(const Tensor& input, const Tensor& target, const Tensor& weight, + std::string reduction = "mean", int ignore_index = -100, + const std::string name = "nll_loss", const std::string tag = kBroadcast) { + auto T = tvm::te::compute( + target->shape, + [&](const tvm::Array& target_indices) { + auto c = target(target_indices); + tvm::Array input_indices; + for (size_t i = 0; i < target_indices.size(); i++) { + input_indices.push_back(target_indices[i]); + if (i == 0) { + input_indices.push_back(c); + } + } + return tvm::tir::Select(c != ignore_index, -input(input_indices) * weight(c), + tvm::tir::make_const(input->dtype, 0)); + }, + name, tag); + if (reduction == "mean") { + auto W = tvm::te::compute( + target->shape, + [&](const tvm::Array& target_indices) { + auto c = target(target_indices); + return tvm::tir::Select(c != ignore_index, weight(c), + tvm::tir::make_const(input->dtype, 0)); + }, + name, tag); + return topi::divide(topi::sum(T, {}), topi::sum(W, {})); + } else if (reduction == "sum") { + return topi::sum(T, {}); + } else { // reduction == "none" + return T; + } +} } // namespace topi } // namespace tvm #endif // TVM_TOPI_NN_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b5cfcf5e3bac..85ae33ed979c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2305,6 +2305,20 @@ def unique(self, inputs, input_types): unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") return (unique_sliced, indices) + def nll_loss(self, inputs, input_types): + assert len(inputs) == 5 + [input, target, weight, reduction, ignore_index] = inputs + num_class = self.infer_shape(input)[1] + if reduction == 0: + reduction = "none" + elif reduction == 1: + reduction = "mean" + else: + reduction = "sum" + if weight is None: + weight = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) + return _op.nn.nll_loss(input, target, weight, reduction, ignore_index) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2517,6 +2531,7 @@ def create_convert_map(self): "aten::argsort": self.argsort, "aten::sort": self.sort, "aten::_unique2": self.unique, + "aten::nll_loss": self.nll_loss, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c6c4f4bfb959..334eddc9a322 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -886,6 +886,17 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE) +# nll_loss +@reg.register_compute("nn.nll_loss") +def compute_nll_loss(attrs, inputs, out_dtype): + input, target, weights = inputs + return [topi.nn.nll_loss(input, target, weights, attrs.reduction, attrs.ignore_index)] + + +reg.register_reduce_schedule("nn.nll_loss") +reg.register_pattern("nn.nll_loss", OpPattern.OPAQUE) + + # depth_to_space @reg.register_compute("nn.depth_to_space") def compute_depth_to_space(attrs, inputs, out_dtype): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 91c148b5df2e..ee864d290bdf 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2973,6 +2973,34 @@ def cross_entropy_with_logits(predictions, targets): return _make.cross_entropy_with_logits(predictions, targets) +def nll_loss(input, target, weight, 
reduction="mean", ignore_index=-100): + """Negative log likelihood loss. + + Parameters + ---------- + input : tvm.relay.Expr + The input. + + target : tvm.relay.Expr + The target value of the input. + + weight : tvm.relay.Expr + The weight of each target value. + + reduction : string + The reduction method to apply to the output. + + ignore_index : int + The target value to ignore. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.nll_loss(input, target, weight, reduction, ignore_index) + + def depth_to_space(data, block_size, layout="NCHW", mode="DCR"): """Convert channels into spatial blocks. diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 6844d133a77e..2e13d1f042a2 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -572,3 +572,8 @@ class ThreefryGenerateAttrs(Attrs): @tvm._ffi.register_object("relay.attrs.UniformAttrs") class UniformAttrs(Attrs): """Attributes used in UniformAttrs operators""" + + +@tvm._ffi.register_object("relay.attrs.NLLLossAttrs") +class NLLLossAttrs(Attrs): + """Attributes for nn.nll_loss""" diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index 94a5b30c9b76..b5e766adbc12 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -49,3 +49,4 @@ from .space_to_depth import * from .space_to_batch_nd import * from .batch_to_space_nd import * +from .loss import * diff --git a/python/tvm/topi/nn/loss.py b/python/tvm/topi/nn/loss.py new file mode 100644 index 000000000000..f1c3cf47f4f6 --- /dev/null +++ b/python/tvm/topi/nn/loss.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""TVM operator negative log likelihood loss compute.""" +from __future__ import absolute_import +from . import cpp + + +def nll_loss(input, target, weight, reduction, ignore_index): + """Negative log likelihood loss on the input data. + + Parameters + ---------- + input : tvm.te.Tensor + (k+2)-D with shape (N, C, d_1, d_2, ..., d_k), + where C is the number of target classes + + target : tvm.te.Tensor + (k+1)-D with shape (N, d_1, d_2, ..., d_k) + The target value of the input. + + weight : tvm.te.Tensor + 1-D with shape (C,) + The weight of each target value. + + reduction : string + The reduction method to apply to output. + Can be "mean", "sum" or "none". + + ignore_index : int + The target value to ignore. + + Returns + ------- + output : tvm.te.Tensor + a scalar if the reduction type is "mean" or "sum", + otherwise the same shape as `target`. 
+ """ + return cpp.nn.nll_loss(input, target, weight, reduction, ignore_index) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 32c0a21d46c7..37981850b0b7 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1068,6 +1068,7 @@ Dilate data with given dilation value (0 by default). .set_support_level(10) .add_type_rel("Dilate", DilateRel); +// relay.nn.cross_entropy_with_logits // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { static const Op& op = Op::Get("nn.cross_entropy_with_logits"); @@ -1091,6 +1092,60 @@ Accept logits. // Depth to space and space to depth TVM_REGISTER_NODE_TYPE(SubPixelAttrs); +// relay.nn.nll_loss +TVM_REGISTER_NODE_TYPE(NLLLossAttrs); + +bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 4); + const auto* input = types[0].as(); + const auto* target = types[1].as(); + const auto* weight = types[2].as(); + const NLLLossAttrs* param = attrs.as(); + if (input == nullptr || target == nullptr || weight == nullptr) return false; + ICHECK(input->shape.size() - target->shape.size() == 1) + << "NLLLossRel: input should be one dimension larger than target, " + << "input shape = " << input->shape << ", " + << "target shape = " << target->shape; + ICHECK(weight->shape.size() == 1); + ICHECK(reporter->AssertEQ(input->shape[1], weight->shape[0])) + << "NLLLossRel: the second dimension of input should be the number of classes, " + << "which is the length of weight, " + << "input shape = " << input->shape << ", " + << "weight shape = " << weight->shape; + ICHECK(input->dtype == weight->dtype && input->dtype.is_float()); + ICHECK(target->dtype.is_int()); + // assign output type + if (param->reduction == "none") { + reporter->Assign(types[3], TensorType(target->shape, input->dtype)); + } else { + reporter->Assign(types[3], TensorType({}, input->dtype)); + } + return true; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakeNLLLoss(Expr input, Expr target, Expr weight, String reduction, int ignore_index) { + auto attrs = make_object(); + attrs->reduction = reduction; + attrs->ignore_index = ignore_index; + static const Op& op = Op::Get("nn.nll_loss"); + return Call(op, {input, target, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss); + +RELAY_REGISTER_OP("nn.nll_loss") + .describe(R"code( +Negative log likelihood loss for given input and target. 
+)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("input", "Tensor", "The input tensor.") + .add_argument("target", "Tensor", "The target tensor.") + .add_argument("weight", "Tensor", "The weight of each target values.") + .add_type_rel("NLLLoss", NLLLossRel); + bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 356f3d2ea18f..2950aee4e90d 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -65,6 +65,10 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd").set_body([](TVMArgs args, TVMRe *rv = batch_to_space_nd(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nll_loss(args[0], args[1], args[2], args[3], args[4]); +}); + /* Ops from nn/dense.h */ TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dense(args[0], args[1], args[2], args[3]); From 1053a2c323db2c99e866ec9c94852a0c241afa81 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Tue, 18 May 2021 13:48:34 +0800 Subject: [PATCH 2/7] enrich the doc and rename parameters --- include/tvm/topi/nn.h | 30 +++++++-------- python/tvm/relay/frontend/pytorch.py | 10 ++--- python/tvm/relay/op/nn/_nn.py | 4 +- python/tvm/relay/op/nn/nn.py | 22 +++++++---- python/tvm/topi/nn/loss.py | 19 +++++++--- src/relay/op/nn/nn.cc | 55 +++++++++++++++------------- 6 files changed, 80 insertions(+), 60 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index f649e1ee49a6..0328de8a9975 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -647,9 +647,9 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, /*! * \brief Negative log likelihood loss. * - * \param input The input tensor. - * \param target The target tensor. - * \param weight A manual rescaling weight given to each class. + * \param predictions The prediction tensor. + * \param targets The target tensor. + * \param weights A manual rescaling weight given to each class. * \param reduction The reduction method to apply to the output. * \param ignore_index The target value to ignore. * \param name The name of the operation. 
@@ -657,31 +657,31 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, * * \return A Tensor whose op member is the batch_to_space_nd operation */ -inline Tensor nll_loss(const Tensor& input, const Tensor& target, const Tensor& weight, +inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights, std::string reduction = "mean", int ignore_index = -100, const std::string name = "nll_loss", const std::string tag = kBroadcast) { auto T = tvm::te::compute( - target->shape, + targets->shape, [&](const tvm::Array& target_indices) { - auto c = target(target_indices); - tvm::Array input_indices; + auto c = targets(target_indices); + tvm::Array pred_indices; for (size_t i = 0; i < target_indices.size(); i++) { - input_indices.push_back(target_indices[i]); + pred_indices.push_back(target_indices[i]); if (i == 0) { - input_indices.push_back(c); + pred_indices.push_back(c); } } - return tvm::tir::Select(c != ignore_index, -input(input_indices) * weight(c), - tvm::tir::make_const(input->dtype, 0)); + return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c), + tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); if (reduction == "mean") { auto W = tvm::te::compute( - target->shape, + targets->shape, [&](const tvm::Array& target_indices) { - auto c = target(target_indices); - return tvm::tir::Select(c != ignore_index, weight(c), - tvm::tir::make_const(input->dtype, 0)); + auto c = targets(target_indices); + return tvm::tir::Select(c != ignore_index, weights(c), + tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); return topi::divide(topi::sum(T, {}), topi::sum(W, {})); diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 85ae33ed979c..5bdd0701d433 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2307,17 +2307,17 @@ def unique(self, inputs, input_types): def nll_loss(self, inputs, input_types): assert len(inputs) == 5 - [input, target, weight, reduction, ignore_index] = inputs - num_class = self.infer_shape(input)[1] + [predictions, targets, weights, reduction, ignore_index] = inputs + num_class = self.infer_shape(predictions)[1] if reduction == 0: reduction = "none" elif reduction == 1: reduction = "mean" else: reduction = "sum" - if weight is None: - weight = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) - return _op.nn.nll_loss(input, target, weight, reduction, ignore_index) + if weights is None: + weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) + return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 334eddc9a322..ee83f3151635 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -889,8 +889,8 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): # nll_loss @reg.register_compute("nn.nll_loss") def compute_nll_loss(attrs, inputs, out_dtype): - input, target, weights = inputs - return [topi.nn.nll_loss(input, target, weights, attrs.reduction, attrs.ignore_index)] + predictions, targets, weights = inputs + return [topi.nn.nll_loss(predictions, targets, weights, attrs.reduction, attrs.ignore_index)] reg.register_reduce_schedule("nn.nll_loss") diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index ee864d290bdf..801eba0f3f8c 100644 --- a/python/tvm/relay/op/nn/nn.py 
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2973,22 +2973,30 @@ def cross_entropy_with_logits(predictions, targets):
     return _make.cross_entropy_with_logits(predictions, targets)
 
 
-def nll_loss(input, target, weight, reduction="mean", ignore_index=-100):
+def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
     """Negative log likelihood loss.
 
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
     Parameters
     ----------
-    input : tvm.relay.Expr
-        The input.
+    predictions : tvm.relay.Expr
+        The predictions.
 
-    target : tvm.relay.Expr
-        The target value of the input.
+    targets : tvm.relay.Expr
+        The target value of each prediction.
 
-    weight : tvm.relay.Expr
+    weights : tvm.relay.Expr
         The weight of each target value.
 
     reduction : string
         The reduction method to apply to the output.
+        Possible values are "mean", "sum" and "none".
 
     ignore_index : int
         The target value to ignore.
@@ -2998,7 +3006,7 @@ def nll_loss(input, target, weight, reduction="mean", ignore_index=-100):
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.nll_loss(input, target, weight, reduction, ignore_index)
+    return _make.nll_loss(predictions, targets, weights, reduction, ignore_index)
 
 
 def depth_to_space(data, block_size, layout="NCHW", mode="DCR"):
diff --git a/python/tvm/topi/nn/loss.py b/python/tvm/topi/nn/loss.py
index f1c3cf47f4f6..1d6f588c7d53 100644
--- a/python/tvm/topi/nn/loss.py
+++ b/python/tvm/topi/nn/loss.py
@@ -15,25 +15,32 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-argument
-"""TVM operator negative log likelihood loss compute."""
+"""Loss function definitions."""
 from __future__ import absolute_import
 from . import cpp
 
 
-def nll_loss(input, target, weight, reduction, ignore_index):
+def nll_loss(predictions, targets, weights, reduction, ignore_index):
     """Negative log likelihood loss on the input data.
 
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
     Parameters
     ----------
-    input : tvm.te.Tensor
+    predictions : tvm.te.Tensor
         (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
         where C is the number of target classes
 
-    target : tvm.te.Tensor
+    targets : tvm.te.Tensor
         (k+1)-D with shape (N, d_1, d_2, ..., d_k)
         The target value of the input.
 
-    weight : tvm.te.Tensor
+    weights : tvm.te.Tensor
         1-D with shape (C,)
         The weight of each target value.
 
@@ -50,4 +57,4 @@ def nll_loss(predictions, targets, weights, reduction, ignore_index):
         a scalar if the reduction type is "mean" or "sum",
         otherwise the same shape as `target`.
""" - return cpp.nn.nll_loss(input, target, weight, reduction, ignore_index) + return cpp.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 37981850b0b7..67e996f555a4 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1097,53 +1097,58 @@ TVM_REGISTER_NODE_TYPE(NLLLossAttrs); bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 4); - const auto* input = types[0].as(); - const auto* target = types[1].as(); - const auto* weight = types[2].as(); + ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size() + << " were provided."; + const auto* predictions = types[0].as(); + const auto* targets = types[1].as(); + const auto* weights = types[2].as(); const NLLLossAttrs* param = attrs.as(); - if (input == nullptr || target == nullptr || weight == nullptr) return false; - ICHECK(input->shape.size() - target->shape.size() == 1) - << "NLLLossRel: input should be one dimension larger than target, " - << "input shape = " << input->shape << ", " - << "target shape = " << target->shape; - ICHECK(weight->shape.size() == 1); - ICHECK(reporter->AssertEQ(input->shape[1], weight->shape[0])) - << "NLLLossRel: the second dimension of input should be the number of classes, " - << "which is the length of weight, " - << "input shape = " << input->shape << ", " - << "weight shape = " << weight->shape; - ICHECK(input->dtype == weight->dtype && input->dtype.is_float()); - ICHECK(target->dtype.is_int()); + if (predictions == nullptr || targets == nullptr || weights == nullptr) return false; + ICHECK(predictions->shape.size() - targets->shape.size() == 1) + << "NLLLossRel: predictions should be one dimension larger than targets, " + << "predictions shape = " << predictions->shape << ", " + << "targets shape = " << targets->shape; + ICHECK(weights->shape.size() == 1) + << "NLLLossRel: weights should be a one dimension Tensor with its length " + << "the number of classes, but Tensor of dimension " << weights->shape.size() + << " were provided."; + ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0])) + << "NLLLossRel: the second dimension of predictions should be the number of classes, " + << "which is the length of weights, " + << "predictions shape = " << predictions->shape << ", " + << "weights shape = " << weights->shape; + ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float()) + << "NLLLossRel: predictions and weights should be of the same floating type."; + ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type."; // assign output type if (param->reduction == "none") { - reporter->Assign(types[3], TensorType(target->shape, input->dtype)); + reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype)); } else { - reporter->Assign(types[3], TensorType({}, input->dtype)); + reporter->Assign(types[3], TensorType({}, predictions->dtype)); } return true; } // Handler to create a call to the padding op used by front-end FFI -Expr MakeNLLLoss(Expr input, Expr target, Expr weight, String reduction, int ignore_index) { +Expr MakeNLLLoss(Expr predictions, Expr targets, Expr weights, String reduction, int ignore_index) { auto attrs = make_object(); attrs->reduction = reduction; attrs->ignore_index = ignore_index; static const Op& op = Op::Get("nn.nll_loss"); - return Call(op, {input, target, weight}, Attrs(attrs), {}); + return Call(op, {predictions, targets, 
weights}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss); RELAY_REGISTER_OP("nn.nll_loss") .describe(R"code( -Negative log likelihood loss for given input and target. +Negative log likelihood loss for given prediction and target. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(3) - .add_argument("input", "Tensor", "The input tensor.") - .add_argument("target", "Tensor", "The target tensor.") - .add_argument("weight", "Tensor", "The weight of each target values.") + .add_argument("predictions", "Tensor", "The prediction tensor.") + .add_argument("targets", "Tensor", "The target tensor.") + .add_argument("weights", "Tensor", "The weight of each target values.") .add_type_rel("NLLLoss", NLLLossRel); bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attrs, From d8f111b578321d299d50b588c41f14569ae4e73d Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Wed, 19 May 2021 12:09:12 +0800 Subject: [PATCH 3/7] update upon review --- include/tvm/topi/nn.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 0328de8a9975..b54e93b739ce 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -665,11 +665,10 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::Array& target_indices) { auto c = targets(target_indices); tvm::Array pred_indices; - for (size_t i = 0; i < target_indices.size(); i++) { - pred_indices.push_back(target_indices[i]); - if (i == 0) { - pred_indices.push_back(c); - } + pred_indices.push_back(target_indices[0]); // batch index + pred_indices.push_back(c); // class index + for (size_t i = 1; i < target_indices.size(); i++) { + pred_indices.push_back(target_indices[i]); // indices for multidimensional loss } return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c), tvm::tir::make_const(predictions->dtype, 0)); From 4b717d6b4fa467489f418ff22a095b670a6f2335 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Sun, 30 May 2021 23:51:36 +0800 Subject: [PATCH 4/7] add tests --- include/tvm/topi/nn.h | 2 +- python/tvm/relay/frontend/pytorch.py | 1 + python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/nll_loss.py | 73 +++++++++++++++++++ src/relay/op/nn/nn.cc | 52 +++++++++---- tests/python/frontend/pytorch/test_forward.py | 24 ++++++ tests/python/relay/test_op_level10.py | 44 +++++++++++ tests/python/topi/python/test_topi_loss.py | 70 ++++++++++++++++++ 8 files changed, 250 insertions(+), 17 deletions(-) create mode 100644 python/tvm/topi/testing/nll_loss.py create mode 100644 tests/python/topi/python/test_topi_loss.py diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index b54e93b739ce..33eb67a1eb9d 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -655,7 +655,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, * \param name The name of the operation. * \param tag The tag to mark the operation. * - * \return A Tensor whose op member is the batch_to_space_nd operation + * \return The negative log likelihood loss of the predictions and targets. 
 */
 inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
                        std::string reduction = "mean", int ignore_index = -100,
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5bdd0701d433..34236b537eaa 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2532,6 +2532,7 @@ def create_convert_map(self):
             "aten::sort": self.sort,
             "aten::_unique2": self.unique,
             "aten::nll_loss": self.nll_loss,
+            "aten::nll_loss2d": self.nll_loss,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index ef7d86322be7..afb251417315 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -69,3 +69,4 @@
 from .matrix_set_diag import matrix_set_diag
 from .space_to_batch_nd import space_to_batch_nd_python
 from .batch_to_space_nd import batch_to_space_nd_python
+from .nll_loss import nll_loss
diff --git a/python/tvm/topi/testing/nll_loss.py b/python/tvm/topi/testing/nll_loss.py
new file mode 100644
index 000000000000..b6eeb187d3b7
--- /dev/null
+++ b/python/tvm/topi/testing/nll_loss.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""NLLLoss in python"""
+import numpy as np
+
+
+def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
+    """nll_loss operator implemented in numpy.
+
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
+    Parameters
+    ----------
+    predictions : numpy.ndarray
+        (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
+        where C is the number of target classes
+
+    targets : numpy.ndarray
+        (k+1)-D with shape (N, d_1, d_2, ..., d_k)
+        The target value of the input.
+
+    weights : numpy.ndarray
+        1-D with shape (C,)
+        The weight of each target value.
+
+    reduction : string
+        The reduction method to apply to output.
+        Can be "mean", "sum" or "none".
+
+    ignore_index : int
+        The target value to ignore.
+
+    Returns
+    -------
+    output : numpy.ndarray
+        a scalar if the reduction type is "mean" or "sum",
+        otherwise the same shape as `target`.
+ """ + res = np.zeros(targets.shape) + weight_sum = 0.0 + for index in np.ndindex(targets.shape): + class_id = targets[index] + if class_id != ignore_index: + index_list = list(index) + pred_index = tuple(index_list[:1] + [class_id] + index_list[1:]) + res[index] = -predictions[pred_index] * weights[class_id] + weight_sum += weights[class_id] + if reduction == "mean": + return np.sum(res) / weight_sum + if reduction == "sum": + return np.sum(res) + else: + return res diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 67e996f555a4..281fc5093325 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1104,22 +1104,42 @@ bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* weights = types[2].as(); const NLLLossAttrs* param = attrs.as(); if (predictions == nullptr || targets == nullptr || weights == nullptr) return false; - ICHECK(predictions->shape.size() - targets->shape.size() == 1) - << "NLLLossRel: predictions should be one dimension larger than targets, " - << "predictions shape = " << predictions->shape << ", " - << "targets shape = " << targets->shape; - ICHECK(weights->shape.size() == 1) - << "NLLLossRel: weights should be a one dimension Tensor with its length " - << "the number of classes, but Tensor of dimension " << weights->shape.size() - << " were provided."; - ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0])) - << "NLLLossRel: the second dimension of predictions should be the number of classes, " - << "which is the length of weights, " - << "predictions shape = " << predictions->shape << ", " - << "weights shape = " << weights->shape; - ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float()) - << "NLLLossRel: predictions and weights should be of the same floating type."; - ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type."; + if (!(predictions->shape.size() - targets->shape.size() == 1)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: predictions should be one" + << " dimension larger than targets," + << "predictions shape = " << predictions->shape + << ", targets shape = " << targets->shape); + return false; + } + if (!(weights->shape.size() == 1)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: weights should be a one dimension" + << " Tensor with its length the number of classes," + << " but Tensor of dimension " << weights->shape.size() + << " were provided."); + return false; + } + if (!reporter->AssertEQ(predictions->shape[1], weights->shape[0])) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: the second dimension of predictions" + << " should be the number of classes, " + << "which is the length of weights, " + << "predictions shape = " << predictions->shape + << ", weights shape = " << weights->shape); + return false; + } + if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: predictions and weights should" + << " be of the same floating type."); + return false; + } + if (!targets->dtype.is_int()) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: targets should be of int type."); + return false; + } // assign output type if (param->reduction == "none") { reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype)); diff --git 
a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 07f0d8e75c4d..0b75961ecdc0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3866,6 +3866,29 @@ def test_fn(is_sorted, return_inverse, return_counts): verify_trace_model(test_fn(True, False, True), [in_data], targets) +def test_forward_nll_loss(): + torch.set_grad_enabled(False) + N, C = 10, 3 + predictions = torch.rand((N, C)).float() + targets = torch.randint(0, 3, (N,)) + weights = torch.tensor([1, 2, 3]).float() + verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) + + # multidimension nll loss (aten::nll_loss2d) + d1, d2 = 2, 3 + predictions = torch.rand((N, C, d1, d2)).float() + targets = torch.randint(0, 3, (N, d1, d2)) + verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets]) + verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4007,6 +4030,7 @@ def test_fn(is_sorted, return_inverse, return_counts): test_unique() test_hard_swish() test_hard_sigmoid() + test_forward_nll_loss() # Model tests test_resnet18() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 96d90b2a4f76..6a5c8c4a1f52 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -577,6 +577,49 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT") +@tvm.testing.uses_gpu +def test_nll_loss(): + def _get_oshape(target_shape, reduction): + if reduction == "none": + return target_shape + else: + return [] + + def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"): + C = prediction_shape[1] + target_shape = prediction_shape[:1] + prediction_shape[2:] + + predictions = relay.var("predictions", relay.TensorType(prediction_shape, dtype)) + targets = relay.var("targets", relay.TensorType(target_shape, "int32")) + weights = relay.var("weights", relay.TensorType((C,), dtype)) + ignore_index_const = relay.const(ignore_index) + out = relay.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + checked = run_infer_type(out) + assert checked.checked_type == relay.ty.TensorType( + _get_oshape(target_shape, reduction), dtype + ) + func = relay.Function([predictions, targets, weights], out) + predictions_np = np.random.uniform(size=prediction_shape).astype(dtype) + targets_np = np.random.randint(0, C, target_shape).astype("int32") + weights_np = np.random.uniform(size=(C,)).astype(dtype) + out_np = tvm.topi.testing.nll_loss( + predictions_np, targets_np, weights_np, reduction, ignore_index + ) + + for target, dev in tvm.testing.enabled_targets(): + for kind in 
["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + + _verify((10, 5)) + _verify((10, 5, 2, 2)) + _verify((10, 5), reduction="sum") + _verify((10, 5), reduction="none") + _verify((10, 5), ignore_index=3) + _verify((10, 5), dtype="float64") + + if __name__ == "__main__": test_adaptive_pool() test_collapse_sum_like() @@ -590,3 +633,4 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): test_one_hot() test_ndarray_size() test_matrix_set_diag() + test_nll_loss() diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py new file mode 100644 index 000000000000..0fb3f392da35 --- /dev/null +++ b/tests/python/topi/python/test_topi_loss.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for loss operators.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +import tvm.topi.testing + +import tvm.testing + + +def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"): + C = prediction_shape[1] + target_shape = prediction_shape[:1] + prediction_shape[2:] + predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) + targets = te.placeholder(shape=target_shape, name="targets", dtype="int32") + weights = te.placeholder(shape=(C,), name="weights", dtype=dtype) + nll_loss_result = topi.nn.nll_loss( + predictions, targets, weights, reduction, ignore_index + ) + + def check_device(target, dev): + print("Running on target: %s" % target) + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result) + fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") + predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) + targets_npy = np.random.randint(0, C, target_shape).astype("int32") + weights_npy = np.random.uniform(size=(C,)).astype(dtype) + out_npy = tvm.topi.testing.nll_loss(predictions_npy, targets_npy, weights_npy, reduction, ignore_index) + predictions_nd = tvm.nd.array(predictions_npy, dev) + targets_nd = tvm.nd.array(targets_npy, dev) + weights_nd = tvm.nd.array(weights_npy, dev) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) + fn(predictions_nd, targets_nd, weights_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) + + for target, dev in tvm.testing.enabled_targets(): + check_device(target, dev) + + +@tvm.testing.uses_gpu +def test_nll_loss(): + verify_nll_loss((10, 5,)) + verify_nll_loss((10, 5, 2, 2)) + verify_nll_loss((10, 5,), 
reduction="sum") + verify_nll_loss((10, 5,), reduction="none") + verify_nll_loss((10, 5,), ignore_index=3) + verify_nll_loss((10, 5,), dtype="float64") + + +if __name__ == "__main__": + test_nll_loss() From 47d9d04620e657bd564836d1c9632dc2c86a7db5 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Fri, 11 Jun 2021 11:51:48 +0800 Subject: [PATCH 5/7] update based on reviews --- python/tvm/relay/op/nn/_nn.py | 2 +- python/tvm/topi/testing/nll_loss.py | 3 +- tests/python/topi/python/test_topi_loss.py | 61 +++++++++++----------- 3 files changed, 32 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index ee83f3151635..04d38ce39422 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -894,7 +894,7 @@ def compute_nll_loss(attrs, inputs, out_dtype): reg.register_reduce_schedule("nn.nll_loss") -reg.register_pattern("nn.nll_loss", OpPattern.OPAQUE) +reg.register_pattern("nn.nll_loss", OpPattern.OUT_ELEMWISE_FUSABLE) # depth_to_space diff --git a/python/tvm/topi/testing/nll_loss.py b/python/tvm/topi/testing/nll_loss.py index b6eeb187d3b7..fd78f6f56d00 100644 --- a/python/tvm/topi/testing/nll_loss.py +++ b/python/tvm/topi/testing/nll_loss.py @@ -69,5 +69,4 @@ def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100) return np.sum(res) / weight_sum if reduction == "sum": return np.sum(res) - else: - return res + return res diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index 0fb3f392da35..3cb7172adae4 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -25,46 +25,45 @@ import tvm.testing -def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"): +def verify_nll_loss( + dev, target, prediction_shape, reduction="mean", ignore_index=-100, dtype="float32" +): C = prediction_shape[1] target_shape = prediction_shape[:1] + prediction_shape[2:] predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) targets = te.placeholder(shape=target_shape, name="targets", dtype="int32") weights = te.placeholder(shape=(C,), name="weights", dtype=dtype) - nll_loss_result = topi.nn.nll_loss( - predictions, targets, weights, reduction, ignore_index - ) + nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + + with tvm.target.Target(target): + s = tvm.te.create_schedule(nll_loss_result.op) + fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") - def check_device(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result) - fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") - predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) - targets_npy = np.random.randint(0, C, target_shape).astype("int32") - weights_npy = np.random.uniform(size=(C,)).astype(dtype) - out_npy = tvm.topi.testing.nll_loss(predictions_npy, targets_npy, weights_npy, reduction, ignore_index) - predictions_nd = tvm.nd.array(predictions_npy, dev) - targets_nd = tvm.nd.array(targets_npy, dev) - weights_nd = tvm.nd.array(weights_npy, dev) - out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) - fn(predictions_nd, targets_nd, weights_nd, out_nd) - out_topi = out_nd.asnumpy() - tvm.testing.assert_allclose(out_topi, out_npy) + 
predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) + targets_npy = np.random.randint(0, C, target_shape).astype("int32") + weights_npy = np.random.uniform(size=(C,)).astype(dtype) + out_npy = tvm.topi.testing.nll_loss( + predictions_npy, targets_npy, weights_npy, reduction, ignore_index + ) - for target, dev in tvm.testing.enabled_targets(): - check_device(target, dev) + predictions_nd = tvm.nd.array(predictions_npy, dev) + targets_nd = tvm.nd.array(targets_npy, dev) + weights_nd = tvm.nd.array(weights_npy, dev) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) + fn(predictions_nd, targets_nd, weights_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) -@tvm.testing.uses_gpu -def test_nll_loss(): - verify_nll_loss((10, 5,)) - verify_nll_loss((10, 5, 2, 2)) - verify_nll_loss((10, 5,), reduction="sum") - verify_nll_loss((10, 5,), reduction="none") - verify_nll_loss((10, 5,), ignore_index=3) - verify_nll_loss((10, 5,), dtype="float64") +@tvm.testing.parametrize_targets +def test_nll_loss(dev, target): + verify_nll_loss(dev, target, (10, 5)) + verify_nll_loss(dev, target, (10, 5, 2, 2)) + verify_nll_loss(dev, target, (10, 5), reduction="sum") + verify_nll_loss(dev, target, (10, 5), reduction="none") + verify_nll_loss(dev, target, (10, 5), ignore_index=3) + verify_nll_loss(dev, target, (10, 5), dtype="float64") if __name__ == "__main__": - test_nll_loss() + test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm")) From ca255b28980e18715d60dbfcc26f07e3cc30998a Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Thu, 17 Jun 2021 15:55:53 +0800 Subject: [PATCH 6/7] update upon reviews --- tests/python/relay/test_op_level10.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 6a5c8c4a1f52..718bc909ad4c 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -577,8 +577,8 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT") -@tvm.testing.uses_gpu -def test_nll_loss(): +@tvm.testing.parametrize_targets +def test_nll_loss(dev, target): def _get_oshape(target_shape, reduction): if reduction == "none": return target_shape @@ -592,7 +592,6 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 predictions = relay.var("predictions", relay.TensorType(prediction_shape, dtype)) targets = relay.var("targets", relay.TensorType(target_shape, "int32")) weights = relay.var("weights", relay.TensorType((C,), dtype)) - ignore_index_const = relay.const(ignore_index) out = relay.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) checked = run_infer_type(out) assert checked.checked_type == relay.ty.TensorType( @@ -606,11 +605,10 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 predictions_np, targets_np, weights_np, reduction, ignore_index ) - for target, dev in tvm.testing.enabled_targets(): - for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np) - tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np) + 
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) _verify((10, 5)) _verify((10, 5, 2, 2)) @@ -633,4 +631,4 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 test_one_hot() test_ndarray_size() test_matrix_set_diag() - test_nll_loss() + test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm")) From b7b9865c20a76ded8a7fe67c2bdd0ada3060d5cb Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Thu, 24 Jun 2021 10:45:21 +0800 Subject: [PATCH 7/7] update upon reviews --- tests/python/relay/test_op_level10.py | 1 - tests/python/topi/python/test_topi_loss.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 718bc909ad4c..040fa3fb4315 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -631,4 +631,3 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 test_one_hot() test_ndarray_size() test_matrix_set_diag() - test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm")) diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index 3cb7172adae4..7fd8238bf0ae 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -36,7 +36,8 @@ def verify_nll_loss( nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) with tvm.target.Target(target): - s = tvm.te.create_schedule(nll_loss_result.op) + fschedule = tvm.topi.testing.get_reduce_schedule(target) + s = fschedule([nll_loss_result]) fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype)
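
A note on the reduction semantics implemented by the topi compute and mirrored in the numpy reference above: "mean" divides by the sum of the class weights that actually participate (the W tensor in C++, weight_sum in numpy), not by the number of elements, and positions whose target equals ignore_index contribute zero to both numerator and denominator. A small NumPy sketch of that semantics, with made-up values (not part of the patches):

import numpy as np

# log-probabilities for 3 samples over 3 classes, e.g. the output of log_softmax
log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.8, 0.1],
                             [0.3, 0.3, 0.4]], dtype="float32"))
targets = np.array([0, 1, 2], dtype="int32")
weights = np.array([1.0, 2.0, 3.0], dtype="float32")

# per-position loss: -log_probs[n, t_n] * weights[t_n]
losses = np.array([-log_probs[n, t] * weights[t] for n, t in enumerate(targets)])
# reduction="sum" -> losses.sum(); reduction="none" -> losses
# reduction="mean" divides by the summed participating weights (1 + 2 + 3 = 6), not by N = 3
mean = losses.sum() / weights[targets].sum()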
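
A minimal end-to-end sketch of the relay op after this series, checked against the numpy reference the same way the relay test does; the shapes and values here are illustrative, not taken from the patches:

import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay

N, C = 4, 3
predictions = relay.var("predictions", relay.TensorType((N, C), "float32"))
targets = relay.var("targets", relay.TensorType((N,), "int32"))
weights = relay.var("weights", relay.TensorType((C,), "float32"))
out = relay.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100)
func = relay.Function([predictions, targets, weights], out)

# nll_loss consumes log-probabilities, so negative values stand in for log_softmax output
predictions_np = np.random.uniform(-2.0, -0.1, size=(N, C)).astype("float32")
targets_np = np.random.randint(0, C, (N,)).astype("int32")
weights_np = np.ones((C,)).astype("float32")

intrp = relay.create_executor("graph", target="llvm")
out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
out_np = tvm.topi.testing.nll_loss(predictions_np, targets_np, weights_np, "mean", -100)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-5)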
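
The frontend mapping ("aten::nll_loss" / "aten::nll_loss2d") is exercised through the usual trace-and-import path that verify_model wraps in the tests. A sketch of that flow, assuming the (name, (shape, dtype)) form of input_infos accepted by relay.frontend.from_pytorch; the module and names are hypothetical:

import torch
import tvm
from tvm import relay

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = torch.nn.NLLLoss()  # traced into aten::nll_loss

    def forward(self, predictions, targets):
        return self.loss(predictions, targets)

N, C = 10, 3
predictions = torch.rand((N, C)).float()
targets = torch.randint(0, C, (N,))  # int64, as NLLLoss expects
scripted = torch.jit.trace(Model().eval(), (predictions, targets))

shape_list = [("predictions", ((N, C), "float32")), ("targets", ((N,), "int64"))]
mod, params = relay.frontend.from_pytorch(scripted, shape_list)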