Skip to content

Add SoftmaxCrossEntropyLossInternal to Support Dynamic ignore_index Input#7899

Merged
Lafi7e merged 9 commits intomasterfrom
weicwang/scel
Jun 9, 2021
Merged

Add SoftmaxCrossEntropyLossInternal to Support Dynamic ignore_index Input#7899
Lafi7e merged 9 commits intomasterfrom
weicwang/scel

Conversation

@Lafi7e
Copy link
Contributor

@Lafi7e Lafi7e commented Jun 1, 2021

Add SoftmaxCrossEntropyLossInternal and its gradient to support dynamic ignore_index input.

@Lafi7e Lafi7e added training issues related to ONNX Runtime training; typically submitted using template component:ortmodule labels Jun 1, 2021
mrry
mrry previously approved these changes Jun 1, 2021
Copy link
Contributor

@mrry mrry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, thanks Vincent!

_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5)

# This model with dynamic ignore_index requires torch version from commit 645119eaefd0dcf1afc54e9ad58678b5245dea78.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iK1D Does this mean that the fix also depends on a new version of PyTorch? If so, is there any way to backport it to work with PyTorch 1.8.x?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite farmiliar with exporter. From this commit: pytorch/pytorch@645119e#diff-a915656e4a7a07553917af566ed0f559610872d9025063d82dacb9f5c37974c9, it said to lowering NLLLoss/CrossEntropyLoss to ATen code. I think without this, we cannot override the exporter to export the CrossEntropyLoss to ATenOp, then maybe we need exporter team to change their C++ code to support that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm reading the diff correctly, it looks like previous behavior was for cross_entropy_loss() to be a Python wrapper around some code that dispatches to torch._C._nn.nll_loss() or torch._C._nn.nll_loss2d(). After that commit, it dispatches directly to torch._C._nn.cross_entropy_loss(), which gave us the new ability to register a symbolic for it.

I think the way to make this work with the old version of PyTorch would be to register a similar NegativeLogLikelihoodLossInternal op that supports dynamic ignore_index, and this would give us the ability to work with PyTorch 1.8.x (and support models using NLLLoss in the future). Does that make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mrry It makes sense. The tricky part is, there are C++ code in exporter to fuse LogSoftmax->NegativeLogLikelihoodLoss to single SoftmaxCrossEntropyLoss. When we use NegativeLogLikelihoodLossInternal, this fuse will not work, so we also need a new transformer in our ORT side to do this fusion to reuse the kernel implementation of SoftmaxCrossEntropyLossInternal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's a pity! Thanks for the explanation though.

Since it looks like a 1.9 release candidate branch is being prepared for PyTorch, we can probably just wait for the new version to be released, rather than adding a temporary new transformer in ORT.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just didn't see your comment before I pushed this new commit to support torch 1.8.1... I think it's good that we can support this torch version, what if there are users working on this old version...

Copy link
Contributor

@mrry mrry Jun 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow, thanks for doing that - I didn't expect you to go to the extra trouble, but I appreciate it! Agreed that supporting 1.8.1 is worthwhile :).

mrry
mrry previously approved these changes Jun 4, 2021
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain target to integer types")
.TypeConstraint("I", {"tensor(int64)"}, "Constrain ignore_index tensor to int64")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); })
.SetDoc(R"DOC(NegativeLogLikelihoodLossInternal)DOC");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, if a user tries to use torch.nn.NLLLoss with the current PR, it will fail because there's no kernel for NegativeLogLikelihoodLossInternal - is that right?

If so, should we add a function body implementation here, based on this code for the base ONNX op?

https://github.com/onnx/onnx/blob/d08b3e951be607b9638ab84c340fabec1fdbec83/onnx/defs/math/defs.cc#L3160-L3161

...but e.g. with the ignore_index input passed in directly, instead of converted to a constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mrry for this. I've added a function body for the new Op in a new commit. But OpSet version seems a problem here. When I added a function body for an Op from com.microsoft domain and version 1, I don't know which OpSet version to use for ONNX domain, especially the Squeeze/Unsqueeze used in the function body has different input definition between OpSet12 and OpSet13. Since currently the register_custom_op is enabled for ORTModule case only, and currently the ORTModule default OpSet is version 12, I am using OpSet12 for both the function body and the UTs.

@thiagocrepaldi
Copy link
Contributor

Is this PR still needed if #7937 is merged?

@mrry
Copy link
Contributor

mrry commented Jun 4, 2021

Is this PR still needed if #7937 is merged?

We should merge both. Even if #7937 fixes the specific issue that led to the creation of this PR, this PR strictly increases our coverage of PyTorch semantics, and may save work in the future.

mrry
mrry previously approved these changes Jun 7, 2021
pengwa
pengwa previously approved these changes Jun 8, 2021
thiagocrepaldi
thiagocrepaldi previously approved these changes Jun 8, 2021
@Lafi7e Lafi7e dismissed stale reviews from thiagocrepaldi and pengwa via 8cb7b7d June 8, 2021 23:39
@Lafi7e Lafi7e merged commit f0f3012 into master Jun 9, 2021
@Lafi7e Lafi7e deleted the weicwang/scel branch June 9, 2021 02:29
harshithapv pushed a commit that referenced this pull request Jun 16, 2021
…nput (#7899)

* add SoftmaxCrossEntropyLossInternal

* bugfix and ut

* fix ut

* fix ut

* support torch1.8.1

* function body for nll_loss_internal
harshithapv added a commit that referenced this pull request Jun 18, 2021
* Cache initializers and avoid device check ot end of forward (#7905)

* ATenOp Enhancement (#7725)

* config parser, default argument values

* ut

* win build

* maxpool2d

* fix win build

* fix build

* unfold atenop

* Update CMakeLists.txt for openvino EP (#7980)

* Add SoftmaxCrossEntropyLossInternal to Support Dynamic ignore_index Input (#7899)

* add SoftmaxCrossEntropyLossInternal

* bugfix and ut

* fix ut

* fix ut

* support torch1.8.1

* function body for nll_loss_internal

* Override ORTModule named_modules to support extra arg (#7954)

* add missing provider_options.h in packages (#7995)

* consolidate copy binary script for gpu/trt tarball package

* add provider_options.h

* add provider_options.h

* Add cuda provides files (#8002)

* Save module output for backward if needed (#8010)

* Save module output for backward if needed

* Make logic in InsertCastTransformer around forcing a node to fp32 more precise. (#8018)

* Address #7981

Reworked the logic around forcing a node to run on fp32 even if it was supported on fp16.

The github issue had multiple factors. In ORT 1.8 we remove Identity nodes that produce graph outputs as they're not needed. That resulted in a Loop node no longer having output nodes (it produces graph outputs instead), which meant the check in IsSingleInputNodeFloat16Node returned true as there was no longer a downstream Identity node processing fp16 data.

We shouldn't only force a node to fp32 in very specific circumstances, and the changes hopefully check for those more precisely.

* Fix Memory Leak from DlpackToOrtValue (#8029)

* Update DirectML EP changes from DmlDev as of 2021-06-07 (#7987)

* Merged PR 6093117: Fix test_DynamicQuantizedLinear_max_adjusted_expanded by allowing Identity operator to run on non-float inputs

Motivation:
As part of the OnnxConformance Backend tests, DynamicQuantizedLinear_max_adjusted_expanded is failing.

Root Cause:
- The test model has `Identity` operator as one of the node. The input of this node is of non-float data type.
- In DML, `Identity` operator is registered as operator which requires floating input.
- As per `DirectMLSchema.h`, support for non-float input has been added for `Identity` operator in DML but the same has not been reflected in the `OperatorRegistration.cpp`.

Changes:
- Removed all traces of the requiresFloatFormatsForGraph flag from it's definition and usage. This flag was only used for Identity and it's related operator.
- Added null check for the graphOutput nodeArg in GraphDescBuilder.cpp to stop the crash of the test.

Related work items: #33076298

* Merged PR 6103324: Remove usage of non-generic error code (FWP_E_NULL_POINTER)

Motivation:
Addressing Dwayne comment on the previous PR. [Ref: [6093117](https://dev.azure.com/microsoft/WindowsAI/_git/onnxruntime/pullrequest/6093117?discussionId=44292162&path=%2Fonnxruntime%2Fcore%2Fproviders%2Fdml%2FDmlExecutionProvider%2Fsrc%2FGraphPartitioner.cpp)]

Changes:
Inside the DML EP, we should not use some other platform specific error codes. Instead we should a appropriate generic error code.

Related work items: #33076298

Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>

* [js/react_native] Use a mobile ORT instead of a full ORT (#8042)

* Change full ort to mobile ort

* Update Android example to load mobile ort

* Change the format of test models to ort

* update ios to use mobile ort

* revise README

* use onnxruntime-mobile-c CocoaPods in a npm package

* fix PATH addition in windows

should set PATH, not add to the tail the copy of PATH

* Reduce Kernel Optimization (#8067)

* reduce optimization

* bug fix

* add a check

* add ut

* refactor

* add ut cases for keepdims=true

Co-authored-by: baijumeswani <bmeswani@microsoft.com>
Co-authored-by: Vincent Wang <wangwchpku@outlook.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: George Wu <jywu@microsoft.com>
Co-authored-by: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: sumitsays <sumitagarwal330@gmail.com>
Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
Co-authored-by: Sunghoon <35605090+hanbitmyths@users.noreply.github.com>
Co-authored-by: iperov <lepersorium@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

training issues related to ONNX Runtime training; typically submitted using template

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants