54 changes: 21 additions & 33 deletions src/transformers/models/cvt/modeling_cvt.py
@@ -23,7 +23,7 @@
from ... import initialization as init
from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_cvt import CvtConfig


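For context, a minimal sketch of the decorator pattern that `can_return_tuple` implements (illustrative only, not the library's actual implementation): it wraps `forward`, resolves `return_dict` from the call or the config, and converts the returned `ModelOutput` to a tuple when a tuple is requested.

```python
# Illustrative sketch of the can_return_tuple pattern; the real utility lives
# in transformers.utils and handles more edge cases than shown here.
import functools

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        # Fall back to the config default, as forward() used to do by hand.
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        output = forward(self, *args, **kwargs)
        # ModelOutput dataclasses expose to_tuple(), which keeps non-None fields.
        return output if return_dict else output.to_tuple()
    return wrapper
```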
@@ -461,23 +461,15 @@ def __init__(self, config):
for stage_idx in range(len(config.depth)):
self.stages.append(CvtStage(config, stage_idx))

def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
def forward(self, pixel_values):
hidden_state = pixel_values

cls_token = None
for _, (stage_module) in enumerate(self.stages):
for stage_module in self.stages:
hidden_state, cls_token = stage_module(hidden_state)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)

if not return_dict:
return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

return BaseModelOutputWithCLSToken(
last_hidden_state=hidden_state,
cls_token_value=cls_token,
hidden_states=all_hidden_states,
)


@@ -519,36 +511,42 @@ def __init__(self, config, add_pooling_layer=True):
self.encoder = CvtEncoder(config)
self.post_init()

@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
output_attentions: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> tuple | BaseModelOutputWithCLSToken:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")

encoder_outputs = self.encoder(
pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
# Manually collect hidden states from encoder stages
all_hidden_states = () if output_hidden_states else None
hidden_state = pixel_values
cls_token = None

for stage_module in self.encoder.stages:
hidden_state, cls_token = stage_module(hidden_state)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)

if not return_dict:
return (sequence_output,) + encoder_outputs[1:]
return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

return BaseModelOutputWithCLSToken(
last_hidden_state=sequence_output,
cls_token_value=encoder_outputs.cls_token_value,
hidden_states=encoder_outputs.hidden_states,
last_hidden_state=hidden_state,
cls_token_value=cls_token,
hidden_states=all_hidden_states,
)


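To illustrate the effect on callers, a hedged usage sketch of CvtModel (the checkpoint name is only an example; `@can_return_tuple` performs the `return_dict=False` conversion that the removed code did inline):

```python
import torch
from transformers import CvtModel

# "microsoft/cvt-13" is just an example checkpoint; any CvT checkpoint works.
model = CvtModel.from_pretrained("microsoft/cvt-13")
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for real processor output

# Default call: BaseModelOutputWithCLSToken, hidden states collected per stage.
outputs = model(pixel_values, output_hidden_states=True)
print(outputs.last_hidden_state.shape)
print(len(outputs.hidden_states))  # one hidden state per CvT stage

# return_dict=False is handled by the @can_return_tuple decorator,
# which converts the ModelOutput into a plain tuple.
tuple_outputs = model(pixel_values, return_dict=False)
print(type(tuple_outputs))  # tuple
```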
@@ -573,13 +571,12 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> tuple | ImageClassifierOutputWithNoAttention:
r"""
@@ -588,12 +585,7 @@ def forward(
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.cvt(
pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
outputs = self.cvt(pixel_values, **kwargs)

sequence_output = outputs[0]
cls_token = outputs[1]
@@ -631,10 +623,6 @@ def forward(
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


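A similar sketch for the classification head, showing the `labels` argument described in the docstring above (randomly initialised weights, purely illustrative):

```python
import torch
from transformers import CvtConfig, CvtForImageClassification

# Randomly initialised model, just to illustrate the labels/loss behaviour.
config = CvtConfig(num_labels=10)
model = CvtForImageClassification(config)

pixel_values = torch.randn(2, 3, 224, 224)
labels = torch.tensor([3, 7])  # class indices in [0, config.num_labels - 1]

outputs = model(pixel_values, labels=labels)
print(outputs.loss)          # cross-entropy loss, since num_labels > 1 and labels are ints
print(outputs.logits.shape)  # torch.Size([2, 10])
```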