Refactor GPT2-based models to standardized output collection interface #44015
akashadsare wants to merge 1 commit into huggingface:main
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: decision_transformer, gpt2, gpt_bigcode
Pull request overview
This PR migrates GPT-2-family PyTorch model implementations (GPT-2, GPTBigCode, DecisionTransformer’s GPT-2 backbone) to the standardized output-collection interface based on @capture_outputs / @can_return_tuple and per-model _can_record_outputs metadata.
Changes:
- Replaced manual hidden-state/attention collection loops with hook-based output capturing.
- Updated `forward` signatures to use `**kwargs: Unpack[TransformersKwargs]` and rely on decorators for `return_dict` handling (a minimal sketch of this pattern follows the list).
- Added `_can_record_outputs` mappings for GPT-2, GPTBigCode, and DecisionTransformer components.
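As a rough illustration of the signature pattern described above, the sketch below shows the decorator-plus-kwargs shape. The decorator and type names are the ones quoted in this PR, while the argument list is abbreviated and hypothetical rather than copied from the diff.

```python
# Hedged sketch of the signature pattern described above, not the PR's actual diff.
# `can_return_tuple`, `capture_outputs`, `TransformersKwargs`, and `Unpack` are the
# names quoted in this PR; the arguments shown are abbreviated for illustration.
@can_return_tuple
@capture_outputs
def forward(self, input_ids=None, attention_mask=None, **kwargs: Unpack[TransformersKwargs]):
    # return_dict handling is delegated to the decorators; hidden states and
    # attentions are collected by hooks, so no manual accumulation loops are needed.
    ...
```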
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py` | Adds `_can_record_outputs` and moves GPTBigCode to decorator-based hidden-state/attention capture. |
| `src/transformers/models/gpt2/modeling_gpt2.py` | Refactors GPT-2 and its heads to decorator-based output capture and standardized kwargs. |
| `src/transformers/models/decision_transformer/modeling_decision_transformer.py` | Updates DecisionTransformer's GPT-2 backbone and top-level model to the new capture interface and standardized kwargs. |
| "attentions": OutputRecorder(GPTBigCodeAttention, index=1), | ||
| "cross_attentions": OutputRecorder(GPTBigCodeAttention, index=1), |
Both `attentions` and `cross_attentions` are configured to record from all `GPTBigCodeAttention` modules with no `layer_name` filter. This will mix self-attention and cross-attention weights into both outputs. Use `layer_name` filters (e.g. `attn` vs `crossattention`), similar to the GPT-2 setup, so each key captures only the intended modules.
| "attentions": OutputRecorder(GPTBigCodeAttention, index=1), | |
| "cross_attentions": OutputRecorder(GPTBigCodeAttention, index=1), | |
| "attentions": OutputRecorder(GPTBigCodeAttention, index=1, layer_name="attn"), | |
| "cross_attentions": OutputRecorder(GPTBigCodeAttention, index=1, layer_name="crossattention"), |
| from ...modeling_utils import (
|     ALL_ATTENTION_FUNCTIONS,
|     OutputRecorder,
|     PreTrainedModel,
|     can_return_tuple,
|     capture_outputs,
| )
| from ...processing_utils import Unpack
`can_return_tuple` and `capture_outputs` are imported from `modeling_utils`, but those symbols are not defined there (only `OutputRecorder` is re-exported). This will raise an `ImportError` at import time. Import `can_return_tuple` from `...utils` and `capture_outputs` from `...utils.output_capturing` (and keep `OutputRecorder` consistent with the other model files).
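For illustration, the corrected import block might look roughly like the following; the module paths are taken from the comment above rather than verified against the source tree, so treat this as a sketch.

```python
# Sketch of the import layout suggested in the comment above; the module paths are
# quoted from that comment and should be checked against the current source tree.
from ...modeling_utils import (
    ALL_ATTENTION_FUNCTIONS,
    OutputRecorder,
    PreTrainedModel,
)
from ...processing_utils import Unpack
from ...utils import can_return_tuple
from ...utils.output_capturing import capture_outputs
```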
| output_attentions=output_attentions,
| )
|
| hidden_states = outputs[0]
`DecisionTransformerGPT2Block.forward()` now returns a tensor, but this loop still treats the return value as a tuple (`hidden_states = outputs[0]`). That will index into the batch dimension and corrupt the hidden states. Assign the block output directly (and consider renaming `outputs` to `hidden_states` for clarity).
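A minimal sketch of the fix being described, assuming the usual GPT-2 layer loop; `self.h` and the forwarded arguments are assumptions for illustration, not copied from the diff.

```python
# Sketch of the fix described above; `self.h` and the forwarded arguments are
# assumed from the usual GPT-2 layout, not taken from the PR diff.
for block in self.h:
    # The block now returns the hidden-state tensor directly, so assign it as-is.
    # Indexing with [0] here would slice off the first batch element instead.
    hidden_states = block(
        hidden_states,
        attention_mask=attention_mask,
        **kwargs,
    )
```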
| "hidden_states": "encoder", | ||
| "attentions": "encoder", |
This `_can_record_outputs` mapping targets the submodule named `encoder`, which causes hooks to record the entire `BaseModelOutput...` object returned by the encoder (for `attentions`, this gets appended into the collected attention list). This will produce malformed `attentions`/`hidden_states` outputs when `output_attentions`/`output_hidden_states` are enabled. If the intent is just to expose these keys at the top level while relying on the encoder's own `_can_record_outputs`, set the values to an empty list (so no hook is installed) but keep the keys present.
| "hidden_states": "encoder", | |
| "attentions": "encoder", | |
| "hidden_states": [], | |
| "attentions": [], |
This PR migrates GPT-2 and its derivatives (GPTBigCode, Decision Transformer) to the new standardized output collection interface using the `@capture_outputs` and `@can_return_tuple` decorators.
By leveraging the `_can_record_outputs` metadata in the model classes, we remove the need for manual loops and boilerplate code to collect hidden states and attention weights in the `forward` pass. This improves maintainability and ensures consistency with the library's evolving standards.
Key Changes:
- Standardized Collection: Replaced manual `output_hidden_states` and `output_attentions` logic with decorators.
- Hook-based Capture: Defined `_can_record_outputs` for the GPT-2, GPTBigCode, and DecisionTransformer components (see the sketch after this list).
- Attention Capture: Utilized `OutputRecorder` to target the second element (index 1) of attention outputs.
- Improved Signatures: Updated `forward` signatures to use `Unpack[TransformersKwargs]`.
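As an illustration of this pattern, a `_can_record_outputs` mapping for the GPT-2 backbone could look roughly like the sketch below. The class names are the standard ones from `modeling_gpt2.py`, and the `layer_name` filters follow the review discussion above; the exact entries in the PR's diff may differ.

```python
# Rough sketch of a _can_record_outputs mapping in the style this PR adopts;
# the exact entries in the PR's diff may differ.
_can_record_outputs = {
    # Record each block's output tensor as a hidden state.
    "hidden_states": GPT2Block,
    # Record element 1 of the attention module's return value, filtered by the
    # submodule name so self- and cross-attention weights are not mixed together.
    "attentions": OutputRecorder(GPT2Attention, index=1, layer_name="attn"),
    "cross_attentions": OutputRecorder(GPT2Attention, index=1, layer_name="crossattention"),
}
```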
Validation:
- Verified GPT-2 modeling tests: `pytest tests/models/gpt2/test_modeling_gpt2.py` (all tests passed).
- Ensured repo consistency via `make fix-repo`.