Support .to(device) or Device Aware Handling for Segmentation Labels in EOMTImageProcessor #42205 #42228
Chenhao-Guan wants to merge 3 commits into huggingface:main
[For maintainers] Suggested jobs to run (before merge): run-slow: eomt, mask2former, maskformer, oneformer
Added the same handling to mask2former, maskformer, and oneformer, following the suggestion of @NielsRogge. However, the mask2former, maskformer, and oneformer changes were not tested on the actual models.
Thanks a lot! Will assign @yonigozlan for review
yonigozlan left a comment
Happy to add this for now. A better solution imo would be to allow the .to method of BatchFeature objects to work on list of tensors. I think there was some pushback against this in the past, but I'll check internally.
Thanks @Chenhao-Guan for contributing this!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Fixes a RuntimeError: Expected all tensors to be on the same device... when training EomtForUniversalSegmentation.
The EomtImageProcessor returns labels (mask_labels, class_labels) as a list[torch.Tensor] because the number of masks per image varies. The standard BatchFeature.to(device) call in a training loop does not move tensors inside these lists, leaving them on the CPU.
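The failure mode can be reproduced in miniature. The sketch below is a simplified model, not the actual BatchFeature implementation: `FakeTensor` and `shallow_to` are hypothetical stand-ins that only track a device string, and the sketch assumes that a batch-level move converts top-level tensor values and passes everything else through unchanged.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device=device)


def shallow_to(batch: dict, device: str) -> dict:
    """Assumed behaviour: move top-level tensors only; leave lists untouched."""
    return {
        key: value.to(device) if isinstance(value, FakeTensor) else value
        for key, value in batch.items()
    }


batch = {
    "pixel_values": FakeTensor(),
    # One entry per image, because the number of masks varies per image.
    "mask_labels": [FakeTensor(), FakeTensor()],
    "class_labels": [FakeTensor(), FakeTensor()],
}

moved = shallow_to(batch, "cuda:0")
print(moved["pixel_values"].device)              # cuda:0
print([t.device for t in moved["mask_labels"]])  # ['cpu', 'cpu']
```

Under that assumption, pixel_values ends up on the accelerator while every label tensor stays on the CPU, which is the mismatch the loss computation then trips over.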
This PR adds a check at the beginning of the EomtForUniversalSegmentation.forward() method. It moves any provided mask_labels or class_labels to the same device as pixel_values, so that all inputs are on the same device before the loss is computed.
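A minimal sketch of that kind of check, not the code from this PR: `move_labels_to_device` is a hypothetical helper name, and it assumes only that each label exposes a `.to(device)` method, as `torch.Tensor` does. `FakeTensor` is a stand-in so the sketch runs without PyTorch or a GPU.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device=device)


def move_labels_to_device(labels: Optional[list], device: str) -> Optional[list]:
    """Return the labels with every element on `device`; None passes through."""
    if labels is None:
        return None
    return [label.to(device) for label in labels]


# Inside a forward() method, this would run before the loss is computed.
pixel_values = FakeTensor("cuda:0")
mask_labels = [FakeTensor(), FakeTensor()]
class_labels = None  # labels are optional, e.g. at inference time

mask_labels = move_labels_to_device(mask_labels, pixel_values.device)
class_labels = move_labels_to_device(class_labels, pixel_values.device)
```

With real tensors, calling `.to` on a tensor that is already on the target device returns the tensor itself, so a check like this should add little overhead when the user has already moved the labels.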
This fixes the device mismatch bug internally without requiring any changes from the user.
Fixes #42205
Before submitting
[x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
[x] Did you read the contributor guideline, Pull Request section?
[x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. (Fixes #42205)
[x] Did you make sure to update the documentation with your changes? (N/A - Bugfix)
[x] Did you write any new necessary tests? (Verified with issue's code)
Who can review?
@NielsRogge @molbap @merveenoyan