Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions src/transformers/loss/loss_d_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,37 +337,43 @@ def DFineForObjectDetectionLoss(
auxiliary_outputs = None
if config.auxiliary_loss:
if denoising_meta_values is not None:
dn_out_coord, outputs_coord = torch.split(
dn_out_coord, normal_out_coord = torch.split(
outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2
)
dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
dn_out_class, normal_out_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2)
dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2)
else:
normal_out_coord = outputs_coord.clamp(min=0, max=1)
normal_out_class = outputs_class
out_corners = predicted_corners
out_refs = initial_reference_points

if config.auxiliary_loss:
auxiliary_outputs = _set_aux_loss2(
outputs_class[:, :-1].transpose(0, 1),
outputs_coord[:, :-1].transpose(0, 1),
normal_out_class[:, :-1].transpose(0, 1),
normal_out_coord[:, :-1].transpose(0, 1),
out_corners[:, :-1].transpose(0, 1),
out_refs[:, :-1].transpose(0, 1),
out_corners[:, -1],
outputs_class[:, -1],
normal_out_class[:, -1],
)

outputs_loss["auxiliary_outputs"] = auxiliary_outputs
outputs_loss["auxiliary_outputs"].extend(
_set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)])
)

dn_auxiliary_outputs = _set_aux_loss2(
dn_out_class.transpose(0, 1),
dn_out_coord.transpose(0, 1),
dn_out_corners.transpose(0, 1),
dn_out_refs.transpose(0, 1),
dn_out_corners[:, -1],
dn_out_class[:, -1],
)
outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs
outputs_loss["denoising_meta_values"] = denoising_meta_values
if denoising_meta_values is not None:
dn_auxiliary_outputs = _set_aux_loss2(
dn_out_class.transpose(0, 1),
dn_out_coord.transpose(0, 1),
dn_out_corners.transpose(0, 1),
dn_out_refs.transpose(0, 1),
dn_out_corners[:, -1],
dn_out_class[:, -1],
)
outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs
outputs_loss["denoising_meta_values"] = denoising_meta_values

loss_dict = criterion(outputs_loss, labels)

Expand Down
43 changes: 43 additions & 0 deletions tests/models/d_fine/test_modeling_d_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,49 @@ def _validate_backbone_init(config):
config = config.__class__(**config_dict)
_validate_backbone_init(config)

    def test_auxiliary_losses_without_denoising(self):
        """Auxiliary losses should still be computed when num_denoising=0. Regression test for #45593."""
        # Configuration under test: denoising disabled while auxiliary losses are
        # enabled — the combination the referenced issue reported as broken.
        config = copy.deepcopy(self.model_tester.get_config())
        config.num_denoising = 0
        config.auxiliary_loss = True
        config.num_labels = self.model_tester.num_labels

        model = DFineForObjectDetection(config)
        model.to(torch_device)
        # Train mode so the forward pass computes the loss/criterion path.
        model.train()

        # Random image batch shaped from the shared model-tester dimensions
        # (batch, channels, H, W); exact pixel values are irrelevant here.
        pixel_values = torch.rand(
            self.model_tester.batch_size,
            self.model_tester.num_channels,
            self.model_tester.image_size,
            self.model_tester.image_size,
        ).to(torch_device)
        # One DETR-style target dict per image: integer class labels plus
        # `n_targets` boxes with 4 coordinates each (random, in [0, 1)).
        labels = []
        for _ in range(self.model_tester.batch_size):
            labels.append(
                {
                    "class_labels": torch.randint(0, self.model_tester.num_labels, (self.model_tester.n_targets,)).to(
                        torch_device
                    ),
                    "boxes": torch.rand(self.model_tester.n_targets, 4).to(torch_device),
                }
            )

        outputs = model(pixel_values=pixel_values, labels=labels)

        # Main loss must exist
        self.assertIsNotNone(outputs.loss)

        # Aux losses MUST exist when denoising is off
        self.assertTrue(
            any("aux" in k for k in outputs.loss_dict), "Auxiliary losses should be computed even when num_denoising=0"
        )

        # Denoising losses must NOT exist when denoising is off
        self.assertFalse(
            any("dn_" in k for k in outputs.loss_dict), "Denoising losses should not be present when num_denoising=0"
        )

@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_accelerator
@slow
Expand Down
Loading