Fix loss computation in TFWav2Vec2ForCTC #18014
Rocketknight1 merged 1 commit into huggingface:main
This looks good to me! We just made the same change to several other losses (reshaping the output from a scalar to a tensor with shape My only concern is that I have no idea how the lack of
@Sreyan88 I'm happy with this and I think it's ready to merge now - if you want to make any other changes, now's the time. If not, ping me and I'll merge it!
@Rocketknight1 I'm happy, you can merge!
Co-authored-by: Sreyan-G@NVIDIA <sreyang@nvidia.com>
What does this PR do?
The TFWav2Vec2ForCTC implementation was incorrect: the CTC loss was not computed properly. The root cause was that the CTC target labels never reached the loss calculation, so labels was None there. Adding @unpack_inputs unpacks the inputs properly, so the loss is now computed correctly. Additionally, the loss needed to be reshaped for backpropagation.
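To illustrate why the labels were lost, here is a minimal, framework-free sketch. It is not the transformers implementation of @unpack_inputs: the decorator `unpack_inputs_sketch`, the class `ToyCTCModel`, and the returned strings are hypothetical names made up for this example. It only assumes that the whole batch arrives as a single dict in the first positional argument, which would leave a separate `labels` parameter at its default of None unless something unpacks the dict.

```python
import functools
import inspect


def unpack_inputs_sketch(func):
    """Hypothetical stand-in for an input-unpacking decorator (illustration only)."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, inputs=None, **kwargs):
        if isinstance(inputs, dict):
            # The whole batch arrived as one dict: spread it into keyword arguments.
            merged = {**inputs, **kwargs}
        else:
            # Otherwise bind the value to the first parameter after `self`.
            first_name = list(signature.parameters)[1]
            merged = {first_name: inputs, **kwargs}
        return func(self, **merged)

    return wrapper


class ToyCTCModel:
    """Toy model that only reports whether labels reached the loss branch."""

    def call_without_unpacking(self, input_values=None, labels=None):
        # The dict lands in `input_values`, so `labels` keeps its default of None.
        return "loss computed" if labels is not None else "no loss (labels is None)"

    @unpack_inputs_sketch
    def call(self, input_values=None, labels=None):
        # After unpacking, `labels` is bound to the value stored in the dict.
        return "loss computed" if labels is not None else "no loss (labels is None)"


batch = {"input_values": [0.1, 0.2], "labels": [3, 7]}
model = ToyCTCModel()
without_fix = model.call_without_unpacking(batch)
with_fix = model.call(batch)
```

In this sketch, `call_without_unpacking` never sees the labels because the dict is bound to `input_values` as a whole, while the decorated `call` receives `labels` as its own argument and can take the loss branch.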
Fixes #18009