Fix label truncation for per-sample nested structures in Trainer #43395
raimbekovm wants to merge 3 commits into huggingface:main
Conversation
J-Bracke
left a comment
Code looks like it fixes the issue.
Thanks for the fix; the overall approach makes sense and aligns well with the issue. I'm working on a small follow-up PR that builds on this change.
I'll reference this PR once the follow-up is submitted.
SunMarc
left a comment
Thanks for this! Is there a way to fix this simply by modifying `gather_for_metrics` instead of doing a separate implementation? Also, can you put a more descriptive illustration of the issue? In this issue, the user should indeed set `use_gather_object=True`.
Maybe the issue you are facing is that the data returned is a tuple in a specific case where there is only one item? So there is indeed an issue when we truncate?
```python
if use_gather_object:
    data = gather_object(input_data)
else:
    data = self.gather(input_data)
try:
    if self.gradient_state.end_of_dataloader:
        # at the end of a dataloader, `gather_for_metrics` regresses to
        # `gather` unless the dataset has a remainder so log.
        if self.gradient_state.remainder == -1:
            logger.info(
                "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself."
            )
            return data
        elif self.gradient_state.remainder > 0:
            # Last batch needs to be truncated on distributed systems as it contains additional samples
            def _adjust_samples(tensor):
                return tensor[: self.gradient_state.remainder]

            if use_gather_object:
                # gather_object put the objects in a list
                return _adjust_samples(data)
            else:
                return recursively_apply(_adjust_samples, data)
```
Can you check if this fixes your issue if you modify `data` to be a list just after `data = gather_object(input_data)`? cc @J-Bracke
Thanks for the review @SunMarc! I checked the list conversion idea; unfortunately it doesn't help. The issue is that the per-image tensors live in the inner lists of the tuple, so converting the outer gathered data to a list still truncates at the wrong level. As for fixing `gather_for_metrics` itself, that would mean changing accelerate's generic truncation logic, which is why I kept the handling on the Trainer side. Happy to add a more detailed illustration in the description if that would help. Let me know what you'd prefer!
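For instance, a toy sketch of the suggested list conversion (stand-in strings instead of tensors; not accelerate's actual code):

```python
remainder = 1  # only one sample of the padded last batch is real

mask_labels = ["mask_img0", "mask_img1"]
class_labels = ["cls_img0", "cls_img1"]
data = (mask_labels, class_labels)   # what the gathered labels look like

data = list(data)                    # the suggested list conversion
truncated = data[:remainder]         # what _adjust_samples then does

print(truncated)                     # [['mask_img0', 'mask_img1']]
# class_labels is gone, and the per-image lists inside were never touched
```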
If you can do that, that would be for the best!
Oh indeed, if the images are in the inner list, then it won't work. The issue is that in a distributed setup, we won't be able to correctly gather the metrics with your solution, no?
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43395&sha=87af30
Can you try again @J-Bracke and check if it also works in a distributed case?
What does this PR do?
This PR fixes incorrect label handling in `Trainer.evaluation_loop` for models that use per-sample nested label structures like `tuple[list[Tensor], list[Tensor]]` (e.g., Mask2Former for instance segmentation).
The Problem
When labels are structured as `(mask_labels, class_labels)`, where each is a list of tensors (one per image, with varying sizes), `gather_for_metrics` truncates incorrectly:
- `use_gather_object=False`: truncates tensor dimensions instead of list length → instances are lost
- `use_gather_object=True`: truncates the tuple itself → when `remainder=1`, `class_labels` is completely dropped
This caused metric degradation of 3-12+ mAP points depending on batch configuration, and potential crashes when `class_labels` was lost entirely.
Detailed Illustration
Mask2Former label structure:
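A representative sketch (shapes and class ids illustrative):

```python
import torch

# One entry per image; instance counts vary across images
mask_labels = [
    torch.zeros(3, 512, 512),  # image 0: 3 instance masks
    torch.zeros(5, 512, 512),  # image 1: 5 instance masks
]
class_labels = [
    torch.tensor([1, 2, 7]),        # image 0: one class id per instance
    torch.tensor([0, 0, 3, 4, 9]),  # image 1
]
labels = (mask_labels, class_labels)  # tuple[list[Tensor], list[Tensor]]
```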
Scenario: `batch_size=8`, last batch has 4 images, so `remainder=4`.
Bug with `use_gather_object=False`: each gathered tensor is sliced along its first dimension, cutting per-image instance masks down to `remainder` entries and losing instances.
Bug with `use_gather_object=True`: the gathered object is sliced as a whole, so the label tuple itself gets truncated; with `remainder=1` only `mask_labels` survives.
What we need: truncation at the per-image list level, keeping the tuple structure and tensor shapes intact (see the sketch below).
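A toy walk-through of all three cases, with strings standing in for the per-image tensors:

```python
remainder = 4
# Gathered last batch: 8 entries per list, but only the first 4 images are real
mask_labels = [f"mask_img{i}" for i in range(8)]
class_labels = [f"cls_img{i}" for i in range(8)]
labels = (mask_labels, class_labels)

# use_gather_object=False: recursively_apply slices every *tensor* along dim 0,
# so a (5, H, W) instance-mask tensor becomes (4, H, W) and instances vanish.

# use_gather_object=True: the gathered object itself is sliced.
buggy = labels[:remainder]  # slicing a 2-tuple: a no-op for remainder=4,
                            # and for remainder=1 it drops class_labels entirely

# What we need: slice each per-image list, leave tuple and tensors intact
fixed = tuple(per_image_list[:remainder] for per_image_list in labels)
# -> (['mask_img0', ..., 'mask_img3'], ['cls_img0', ..., 'cls_img3'])
```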
The Solution
Added two helper functions in `trainer_pt_utils.py` (sketched below):
- `is_per_sample_nested()` — detects `tuple[list[Tensor], ...]` structures
- `flatten_per_sample_nested_batches()` — correctly flattens and truncates at the list level
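The real implementations live in the diff; a minimal sketch of the behavior they are described as having (bodies illustrative, not the PR's code):

```python
import torch

def is_per_sample_nested(obj) -> bool:
    # A non-empty tuple whose elements are lists of tensors, one per sample
    return (
        isinstance(obj, tuple)
        and len(obj) > 0
        and all(
            isinstance(elem, list)
            and all(isinstance(t, torch.Tensor) for t in elem)
            for elem in obj
        )
    )

def flatten_per_sample_nested_batches(batches, num_samples):
    # Concatenate the per-sample lists across gathered batches, then truncate
    # at the *list* level so individual tensor shapes are untouched
    num_fields = len(batches[0])  # e.g. 2 for (mask_labels, class_labels)
    flat = tuple(
        [t for batch in batches for t in batch[i]] for i in range(num_fields)
    )
    return tuple(per_sample[:num_samples] for per_sample in flat)
```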
Modified `Trainer.evaluation_loop` to:
- detect these label structures with `is_per_sample_nested()`
- gather them with `gather_object` directly (bypassing the `gather_for_metrics` truncation)
- accumulate the gathered batches with `extend()`
- flatten and truncate them with `flatten_per_sample_nested_batches()`
This also properly handles distributed training, where `gather_object` returns labels from all GPU processes, as in the sketch below.
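Roughly, the gather path becomes something like this (simplified; `eval_dataloader` and `num_samples` stand in for the Trainer's actual variables, and the helpers are the ones sketched above):

```python
from accelerate.utils import gather_object

all_label_batches = []
for inputs in eval_dataloader:
    labels = inputs["labels"]
    if is_per_sample_nested(labels):
        # gather_object returns one copy of `labels` per process;
        # extend() accumulates them without gather_for_metrics' truncation
        all_label_batches.extend(gather_object([labels]))

# One pass at the end: flatten across batches, truncate at the list level
all_labels = flatten_per_sample_nested_batches(all_label_batches, num_samples)
```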
Fixes #43388
Before submitting
Who can review?
@SunMarc @muellerzr