
fix: add identity reverse_op to dequantize ops for save_pretrained #44983

Merged
SunMarc merged 1 commit into huggingface:main from Hyungkeun-Park-Nota:fix/mxfp4-dequantize-reverse-op on Mar 27, 2026

Conversation

@Hyungkeun-Park-Nota (Contributor) commented Mar 25, 2026

What does this PR do?

Fixes save_pretrained() for models loaded with dequantize=True.

save_pretrained calls reverse_op on all weight conversion operations from loading. Dequantize ops (Mxfp4Dequantize, Fp8Dequantize, MetalDequantize) don't have reverse_op implemented, so it raises NotImplementedError.

Since dequantized weights are already in their target dtype, the reverse op should just pass them through. Added _IdentityOp in core_model_loading.py and set it as reverse_op on all three dequantize operations.

Tested with

  • MXFP4: openai/gpt-oss-20b + Mxfp4Config(dequantize=True)
  • FP8: Qwen/Qwen3-0.6B-FP8 + FineGrainedFP8Config(dequantize=True)
  • Metal: medmekk/Llama-3.2-1B-Instruct-metal + MetalConfig(dequantize=True)

All three: save ✓, no quantization_config in saved config.json ✓, reload ✓ (0 meta params)
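
For illustration, the following is a minimal, self-contained sketch of the pass-through operation described above. The ConversionOps stand-in and its convert method name and signature are assumptions made for the sketch; the real base class in core_model_loading.py may differ.

import torch

class ConversionOps:
    # Stand-in for the real base class in core_model_loading.py (assumed interface).
    @property
    def reverse_op(self) -> "ConversionOps":
        raise NotImplementedError

    def convert(self, value: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

class _IdentityOp(ConversionOps):
    # Dequantized weights are already in their target dtype, so saving just
    # passes them through unchanged.
    @property
    def reverse_op(self) -> "ConversionOps":
        return _IdentityOp()

    def convert(self, value: torch.Tensor) -> torch.Tensor:
        return value

class Mxfp4Dequantize(ConversionOps):
    # Sketch of the wiring: each dequantize op reports the identity op as its reverse,
    # so save_pretrained no longer hits NotImplementedError.
    @property
    def reverse_op(self) -> "ConversionOps":
        return _IdentityOp()

    def convert(self, value: torch.Tensor) -> torch.Tensor:
        # The real MXFP4 dequantization logic lives in the transformers integration.
        raise NotImplementedError

Fp8Dequantize and MetalDequantize get the same reverse_op wiring.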

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from 1742755 to 9c59dda Compare March 25, 2026 01:24
@Rocketknight1 (Member):

cc @SunMarc for quantization maybe?

@SunMarc (Member) commented Mar 25, 2026:

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=Mxfp4Config(dequantize=True),
)
model.save_pretrained("/tmp/test")  # NotImplementedError before this fix

Hmm, this shouldn't trigger the reverse ops when we dequantize the model. I think the right behavior here would be to just save the model in its dequantized form.

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from 9c59dda to b676da0 Compare March 26, 2026 01:34
@Hyungkeun-Park-Nota (Contributor, Author) commented Mar 26, 2026

@SunMarc Thanks for the review! Updated the PR based on your feedback:

  1. Removed re-quantization logic — replaced Mxfp4ReverseDequantize with Mxfp4IdentityOp that simply passes through bf16 weights as-is during save
  2. Removed quantization_config after dequantize — in _process_model_after_weight_loading, when dequantize=True, we delete model.config.quantization_config so the saved model reloads as a regular bf16 model without triggering the MXFP4 loading path (sketched below)
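
Roughly, change 2 looked like the sketch below. The quantizer class name and constructor are assumptions for the sketch (the hook name and the deleted attribute come from the diff discussed later), and this deletion was ultimately dropped in favor of the shared cleanup in base.py.

class Mxfp4HfQuantizer:
    # Sketch only; the real quantizer class lives in the transformers mxfp4 integration.
    def __init__(self, quantization_config):
        self.quantization_config = quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        if self.quantization_config.dequantize and hasattr(model.config, "quantization_config"):
            # Saved checkpoints should reload as plain bf16 models, so drop the
            # quantization config once the weights have been dequantized.
            del model.config.quantization_config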

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: mxfp4

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from cd2c8fe to 13f9355 Compare March 26, 2026 02:11
@SunMarc (Member) left a comment:

Left a few questions, thanks for iterating.

Comment on lines +177 to +178
if self.quantization_config.dequantize and hasattr(model.config, "quantization_config"):
    del model.config.quantization_config
@SunMarc (Member):

When calling dequantize, I think there is a function that is triggered to remove all traces of quantization, no? Maybe we don't need to do this.

@Hyungkeun-Park-Nota (Contributor, Author) replied Mar 26, 2026:

You're right. remove_quantization_config in base.py already handles this during postprocess_model. Removed the mxfp4-specific deletion.

Comment thread: src/transformers/integrations/mxfp4.py (outdated)
Comment on lines +148 to +150
@property
def reverse_op(self) -> ConversionOps:
    return Mxfp4IdentityOp()
@SunMarc (Member):

Even after dequantizing, do we still need the reverse ops? Can you also check the behavior with fp8 when dequantize is called?

@Hyungkeun-Park-Nota (Contributor, Author) replied Mar 26, 2026:

The same issue exists. Tested with Qwen/Qwen3-0.6B-FP8 using FineGrainedFP8Config(dequantize=True): save_pretrained also raises NotImplementedError because _weight_conversions remains after remove_quantization_config.

So the fix is now in base.py: remove_quantization_config clears _weight_conversions alongside hf_quantizer and quantization_config. This makes the mxfp4-specific reverse_op and config removal unnecessary, so I've removed them. The fix covers both mxfp4 and fp8.
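
As a rough sketch of that intermediate version (the attribute names come from the discussion above; the function structure and where the attributes live are assumptions, and this approach was later replaced by the identity reverse_op):

def remove_quantization_config(model):
    # Clean up quantization traces after dequantization so the model saves and
    # reloads as a regular (e.g. bf16) checkpoint.
    if hasattr(model, "hf_quantizer"):
        del model.hf_quantizer
    if hasattr(model.config, "quantization_config"):
        del model.config.quantization_config
    # Proposed 2-line addition: also drop the loading-time weight conversions so
    # save_pretrained does not try to reverse them.
    if hasattr(model, "_weight_conversions"):
        del model._weight_conversions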

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from 13f9355 to ae3e643 Compare March 26, 2026 15:52
@Hyungkeun-Park-Nota (Contributor, Author) commented Mar 26, 2026

Updated based on review feedback. The fix is now a 2-line change: clearing _weight_conversions in remove_quantization_config.

Verified with both affected quantizers:

MXFP4 (openai/gpt-oss-20b, Mxfp4Config(dequantize=True)):

  • save_pretrained succeeds
  • quantization_config absent from saved config.json
  • Reload succeeds (0 meta-device params)

FP8 (Qwen/Qwen3-0.6B-FP8, FineGrainedFP8Config(dequantize=True)):

  • save_pretrained succeeds (was also failing with NotImplementedError before this fix)
  • quantization_config absent from saved config.json
  • Reload succeeds (0 meta-device params)
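
For reference, a sketch of the MXFP4 verification; the output path is a placeholder, and the FP8 case is the same with Qwen/Qwen3-0.6B-FP8 and FineGrainedFP8Config(dequantize=True).

import json

from transformers import AutoModelForCausalLM, Mxfp4Config

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=Mxfp4Config(dequantize=True),
)
model.save_pretrained("/tmp/gpt-oss-20b-dequantized")  # raised NotImplementedError before the fix

# quantization_config should be absent from the saved config.json
with open("/tmp/gpt-oss-20b-dequantized/config.json") as f:
    assert "quantization_config" not in json.load(f)

# Reload and confirm no parameters were left on the meta device
reloaded = AutoModelForCausalLM.from_pretrained("/tmp/gpt-oss-20b-dequantized")
assert not [n for n, p in reloaded.named_parameters() if p.device.type == "meta"]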

@Hyungkeun-Park-Nota changed the title from "fix: implement Mxfp4Dequantize.reverse_op for save_pretrained support" to "fix: clear _weight_conversions in remove_quantization_config for save_pretrained support" on Mar 26, 2026
@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from ae3e643 to 9553783 Compare March 26, 2026 16:02
Comment thread: src/transformers/quantizers/base.py (outdated)
Comment on lines +206 to +207
if hasattr(model, "_weight_conversions"):
    del model._weight_conversions
@SunMarc (Member):

Hmm, but we still shouldn't just remove all weight_conversions; not all weight_conversions are related to quantization. Can we find a way to remove only the weight conversions that make sense? Otherwise, we can just update the weight conversion ops if that is simpler.

@Hyungkeun-Park-Nota (Contributor, Author) replied Mar 26, 2026:

Didn't realize not all weight_conversions are quantization-related; thanks for catching that.

Between the two options, adding identity reverse_op on each dequantize op seemed simpler, so went with that. Added _IdentityOp to core_model_loading.py and set it as reverse_op on Mxfp4Dequantize, Fp8Dequantize, and MetalDequantize (all three had the same issue).

Tested save + reload on:

  • MXFP4: openai/gpt-oss-20b
  • FP8: Qwen/Qwen3-0.6B-FP8
  • Metal: medmekk/Llama-3.2-1B-Instruct-metal

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch 2 times, most recently from de90e5b to ece5d90 Compare March 26, 2026 16:45
@Hyungkeun-Park-Nota changed the title from "fix: clear _weight_conversions in remove_quantization_config for save_pretrained support" to "fix: add identity reverse_op to dequantize ops for save_pretrained" on Mar 26, 2026
@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from ece5d90 to 61347c0 Compare March 26, 2026 17:02
@SunMarc (Member) left a comment:

Thanks

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

fix: add identity reverse_op to dequantize ops for save_pretrained

Dequantize operations (Mxfp4Dequantize, Fp8Dequantize, MetalDequantize)
raise NotImplementedError on reverse_op, causing save_pretrained to fail
for models loaded with dequantize=True.

Add _IdentityOp as the reverse_op so dequantized weights are saved as-is.
@Hyungkeun-Park-Nota Hyungkeun-Park-Nota force-pushed the fix/mxfp4-dequantize-reverse-op branch from 0d756c1 to 4212234 Compare March 27, 2026 14:36
@SunMarc SunMarc enabled auto-merge March 27, 2026 17:00
@SunMarc SunMarc added this pull request to the merge queue Mar 27, 2026
Merged via the queue into huggingface:main with commit ecdf95c Mar 27, 2026
29 checks passed
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Mar 30, 2026
fix: add identity reverse_op to dequantize ops for save_pretrained (huggingface#44983)

fix: add identity reverse_op to dequantize operations for save_pretrained

Dequantize operations (Mxfp4Dequantize, Fp8Dequantize, MetalDequantize)
raise NotImplementedError on reverse_op, causing save_pretrained to fail
for models loaded with dequantize=True.

Add _IdentityOp as the reverse_op so dequantized weights are saved as-is.
3outeille added a commit that referenced this pull request Apr 14, 2026
The _IdentityOp class (added by PR #44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
3outeille added a commit that referenced this pull request Apr 14, 2026
* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR #44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
