Fix label truncation for per-sample nested structures in Trainer #43395

Open

raimbekovm wants to merge 3 commits into huggingface:main from raimbekovm:fix-mask2former-label-truncation

Conversation

@raimbekovm
Contributor

@raimbekovm raimbekovm commented Jan 21, 2026

What does this PR do?

This PR fixes incorrect label handling in Trainer.evaluation_loop for models that use per-sample nested label structures like tuple[list[Tensor], list[Tensor]] (e.g., Mask2Former for instance segmentation).

The Problem

When labels are structured as (mask_labels, class_labels) where each is a list of tensors (one per image with varying sizes), gather_for_metrics truncates incorrectly:

  • With use_gather_object=False: truncates tensor dimensions instead of list length → instances are lost
  • With use_gather_object=True: truncates the tuple itself → when remainder=1, class_labels is completely dropped

This caused metrics degradation of 3-12+ mAP points depending on batch configuration, and potential crashes when class_labels was lost entirely.

Detailed Illustration

Mask2Former label structure:

labels = (
    [mask_img0, mask_img1, mask_img2, mask_img3],  # mask_labels: list of tensors
    [class_img0, class_img1, class_img2, class_img3]  # class_labels: list of tensors
)
# Each tensor has shape [num_instances, ...] where num_instances varies per image
# Example: mask_img0.shape = [5, 256, 256]  (5 instances)
#          mask_img1.shape = [3, 256, 256]  (3 instances)

Scenario: batch_size=8, last batch has 4 images, so remainder=4

Bug with use_gather_object=False:

# gather_for_metrics does: recursively_apply(lambda t: t[:remainder], labels)
# This truncates EACH TENSOR to 4 elements in first dimension:
mask_img0[:4]  # Shape [5, 256, 256] → [4, 256, 256] — LOST 1 INSTANCE!
mask_img1[:4]  # Shape [3, 256, 256] → [3, 256, 256] — ok (less than 4)

Bug with use_gather_object=True:

# gather_for_metrics does: labels[:remainder]
labels[:4]  # Returns full tuple (ok when remainder >= 2)
labels[:1]  # Returns (mask_labels,) — CLASS_LABELS COMPLETELY LOST!
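Both failure modes can be reproduced with plain Python lists standing in for tensors (only the first dimension matters for the slicing); this is an illustrative sketch, not the accelerate code itself:

```python
# Plain-Python repro of both truncation bugs. Inner lists stand in for
# tensors whose first dimension is the number of instances per image.
mask_labels = [list(range(5)), list(range(3)), list(range(7)), list(range(2))]
class_labels = [list(range(5)), list(range(3)), list(range(7)), list(range(2))]
labels = (mask_labels, class_labels)

# Bug with use_gather_object=False: each leaf tensor is sliced to
# `remainder` along its first dimension, so an image with 5 instances
# silently loses one.
remainder = 4
sliced = mask_labels[0][:remainder]
assert len(sliced) == 4            # was 5 instances; one dropped

# Bug with use_gather_object=True: the outer tuple itself is sliced, so
# with remainder=1 only mask_labels survives.
remainder = 1
truncated = labels[:remainder]
assert truncated == (mask_labels,) # class_labels is gone entirely
```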

What we need:

# Truncate at IMAGE level (list length), not at instance level (tensor dim)
result = (
    mask_labels[:num_samples],   # Keep first N images
    class_labels[:num_samples]   # Keep first N images
)

The Solution

Added two helper functions in trainer_pt_utils.py:

  • is_per_sample_nested() — detects tuple[list[Tensor], ...] structures
  • flatten_per_sample_nested_batches() — correctly flattens and truncates at list level
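The helper names come from the PR description; the bodies below are a sketch of what such helpers could look like, not the merged code. For the sketch, anything with a `.shape` attribute counts as a tensor, so it runs without torch:

```python
def is_per_sample_nested(obj):
    # Detect a tuple of per-sample lists, e.g. (mask_labels, class_labels),
    # where each element is a list with one entry per image. Here any
    # object with a `.shape` attribute stands in for a torch.Tensor.
    return (
        isinstance(obj, tuple)
        and len(obj) > 0
        and all(isinstance(part, list) and len(part) > 0 for part in obj)
        and all(hasattr(item, "shape") for part in obj for item in part)
    )


def flatten_per_sample_nested_batches(batches, num_samples):
    # `batches` is a list of (mask_labels, class_labels) tuples collected
    # across eval steps. Concatenate the inner lists per label type, then
    # truncate at the image level, never inside a tensor.
    num_parts = len(batches[0])
    flat = tuple(
        [item for batch in batches for item in batch[i]]
        for i in range(num_parts)
    )
    return tuple(part[:num_samples] for part in flat)


class T:  # minimal tensor stand-in with a shape attribute
    def __init__(self, *shape):
        self.shape = shape


batch1 = ([T(5, 256, 256), T(3, 256, 256)], [T(5), T(3)])
batch2 = ([T(7, 256, 256)], [T(7)])
assert is_per_sample_nested(batch1)

masks, classes = flatten_per_sample_nested_batches([batch1, batch2], num_samples=2)
assert len(masks) == 2 and len(classes) == 2   # truncated at image level
assert masks[0].shape == (5, 256, 256)         # instances left untouched
```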

Modified Trainer.evaluation_loop to:

  1. Detect per-sample nested labels with is_per_sample_nested()
  2. Use gather_object directly (bypassing gather_for_metrics truncation)
  3. Accumulate labels from all processes with extend()
  4. Flatten and truncate correctly at the end with flatten_per_sample_nested_batches()

This also properly handles distributed training where gather_object returns labels from all GPU processes.

Fixes #43388

Before submitting

Who can review?

@SunMarc @muellerzr


@J-Bracke J-Bracke left a comment


Code looks like it fixes the issue.

@karthik-0306

Thanks for the fix — the overall approach makes sense and aligns well with the issue.

I’m working on a small follow-up PR that builds on this change by:

  • adding regression tests for per-sample nested label structures
  • tightening guards to avoid applying truncation logic to non-label data (e.g. processors)
  • covering edge cases like gradient_state.remainder == 1

I’ll reference this PR once the follow-up is submitted.

Member

@SunMarc SunMarc left a comment


Thanks for this! Is there a way to fix this simply by modifying gather_for_metrics instead of doing a separate implementation? Also, can you put a more descriptive illustration of the issue in the description? In this issue, the user should indeed set use_gather_object=True.

Maybe the issue you are facing is that the data returned is a tuple in a specific case where there is only one item? So there is indeed an issue when we truncate?

    if use_gather_object:
        data = gather_object(input_data)
    else:
        data = self.gather(input_data)

    try:
        if self.gradient_state.end_of_dataloader:
            # at the end of a dataloader, `gather_for_metrics` regresses to
            # `gather` unless the dataset has a remainder so log.
            if self.gradient_state.remainder == -1:
                logger.info(
                    "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself."
                )
                return data
            elif self.gradient_state.remainder > 0:
                # Last batch needs to be truncated on distributed systems as it contains additional samples
                def _adjust_samples(tensor):
                    return tensor[: self.gradient_state.remainder]

                if use_gather_object:
                    # gather_object put the objects in a list
                    return _adjust_samples(data)
                else:
                    return recursively_apply(_adjust_samples, data)

Can you check whether this fixes your issue if you modify data to be a list just after data = gather_object(input_data)? cc @J-Bracke

@raimbekovm
Contributor Author

Thanks for the review @SunMarc!

I checked the list conversion idea — unfortunately it doesn't help. The issue is that data[:remainder] truncates the outer tuple/list, so with remainder=1 we get (mask_labels,) and lose class_labels entirely. Converting to list first gives the same result. We need to truncate the inner lists (images), not the outer structure (label types).
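The list-conversion behavior can be checked in isolation (a minimal sketch with plain lists standing in for tensors):

```python
# Two images' worth of labels; inner lists stand in for tensors.
mask_labels = [[0, 1, 2], [0]]
class_labels = [[0, 1, 2], [0]]
labels = (mask_labels, class_labels)

remainder = 1
as_tuple = labels[:remainder]       # (mask_labels,) — class_labels gone
as_list = list(labels)[:remainder]  # [mask_labels] — same loss
assert as_tuple == (mask_labels,)
assert as_list == [mask_labels]

# What is actually needed: slice each inner list (images),
# keeping the outer (mask_labels, class_labels) pair intact.
wanted = tuple(part[:remainder] for part in labels)
assert wanted == ([mask_labels[0]], [class_labels[0]])
```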

As for fixing gather_for_metrics directly — yeah, that's definitely an option. It would need the same is_per_sample_nested detection though. I went with the transformers-only approach since it's a pretty model-specific pattern and doesn't require waiting for an accelerate release, but I'm open to opening a PR there instead if you think that's cleaner.

Happy to add a more detailed illustration in the description if that would help. Let me know what you'd prefer!

@SunMarc
Member

SunMarc commented Jan 28, 2026

Happy to add a more detailed illustration in the description if that would help. Let me know what you'd prefer!

If you can do that, that would be for the best!

I checked the list conversion idea — unfortunately it doesn't help. The issue is that data[:remainder] truncates the outer tuple/list, so with remainder=1 we get (mask_labels,) and lose class_labels entirely. Converting to list first gives the same result. We need to truncate the inner lists (images), not the outer structure (label types).

Oh indeed, if the images are in the inner lists, then it won't work. The issue is that in a distributed setup, we won't be able to correctly gather the metrics with your solution, no?

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43395&sha=87af30

@SunMarc
Member

SunMarc commented Jan 30, 2026

Can you try again @J-Bracke and check if it also works in a distributed case?

Successfully merging this pull request may close these issues.

gather_for_metrics incorrectly drops label elements in the last batch when labels is a tuple with several label types e.g. used by mask2former

4 participants