
checkpoint utility: optimize to_maxtext, add deepseek #3184

Merged
copybara-service[bot] merged 1 commit into main from shuningjin-ckpt-opt3 on Mar 27, 2026

Conversation

@shuningjin (Collaborator) commented Feb 18, 2026

Description

Optimize to_maxtext loading and saving

  • The main goal is to optimize the to_maxtext eager loading and saving pipelines. By controlling the data-type footprint, we reduce memory usage and deliver a speedup.
  • The optimization unblocks large-scale models in the checkpoint conversion utility. As a result, we are able to onboard the DeepSeek model family at up to 671B scale (deepseek3.2-671b, deepseek3-671b, deepseek2-16b).
  • See this document for detailed logic and tests (http://shortn/_KlHIRwUxvI).
  • Fix: memory optimization (b/452391831), speed optimization (b/477316979), ds3.2 (b/469550012, b/469550011), ds2 (b/459536844), ds3 (b/457820372, b/457820735), ds (b/452392346)

Problem

Previously, eager loading defaulted to transformers_class.from_pretrained(...), which loaded, converted, and saved checkpoints in float32.

  • Latency (Loading Bottleneck): The previous implementation performed an implicit conversion to float32 during the load/save cycle. Since most source tensors (e.g., in MaxText) are natively torch.bfloat16, this conversion added significant overhead to startup times.
  • Memory Inefficiency: Forcing float32 doubles the memory footprint compared to bfloat16, leading to higher peak memory usage and potential OOM (Out of Memory) issues during the loading phase on constrained hardware.
  • Version Incompatibility (Transformers 5.0+): While the current dependency is pinned to 4.57.3, the latest 5.0.0 release introduces a breaking change: NumPy's lack of native bfloat16 support causes end-to-end (e2e) failures when the library attempts to bridge tensor data.
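
The upcast is easy to see in isolation. A minimal sketch (the model id is a placeholder, and older transformers releases spell the keyword torch_dtype rather than dtype):

from transformers import AutoModelForCausalLM

# "org/bf16-model" is a placeholder for any checkpoint serialized in bfloat16.
model = AutoModelForCausalLM.from_pretrained("org/bf16-model")
print(next(model.parameters()).dtype)  # torch.float32 -- implicit upcast

model = AutoModelForCausalLM.from_pretrained("org/bf16-model", dtype="auto")
print(next(model.parameters()).dtype)  # torch.bfloat16 -- kept from the checkpoint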

What Changed

This PR introduces two optimized eager loading methods and adds the ability to save in bfloat16:

  • Eager load Method 1: transformers_class.from_pretrained(..., dtype="auto") to load tensors in their original dtype.
  • Eager load Method 2: safetensors.safe_open(..., framework="pt") to load natively from safetensors. Like Method 1, it can process either a remote repo or a local path (see the sketch after the flags below).
  • Save Dtype: Added bfloat16 as the recommended save option (with float32 retained as a backup). This works for both eager and lazy load.
# to_maxtext flag
--eager_load_method=<transformers (default) | safetensors>
--save_dtype=<bfloat16 (default) | float32>
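
Under the hood, Method 2 amounts to something like the following minimal sketch (the repo id and the conversion step are placeholders; safe_open itself reads local files, so a remote repo is first fetched shard by shard via huggingface_hub):

import json
from huggingface_hub import hf_hub_download
from safetensors import safe_open

repo = "org/some-model"  # placeholder repo id
index_path = hf_hub_download(repo, "model.safetensors.index.json")
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]  # tensor name -> shard file

for shard in sorted(set(weight_map.values())):
    shard_path = hf_hub_download(repo, shard)
    with safe_open(shard_path, framework="pt") as st:
        for name in st.keys():  # enumerates every serialized tensor
            tensor = st.get_tensor(name)  # keeps the on-disk dtype, e.g. torch.bfloat16
            # ... map name/tensor into the MaxText param tree here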

Why It Matters (Impact & Benefits)

  1. 2x Memory Reduction: Peak memory usage is cut in half across the board.
  • gpt-oss-120b: 1009.72 GB -> 511.47 GB
  • approx. 8y GB -> 4y GB, where y is the parameter count in billions (see the worked example after this list)
  2. Speedup for Loading Alone: Loading time is drastically reduced by avoiding the implicit float32 cast and NumPy bottlenecks.
  • gpt-oss-120b: 78 min -> 1 s
  • deepseek3-671b: 7.5 hr -> 4 min
  3. Speedup for Total Conversion: The total conversion lifecycle is shorter for large models.
  • gpt-oss-120b: 134.86 min -> 96.22 min, a 30% speedup at 120B scale.
  • Expect larger speedups for larger models.
  4. Unblocked Scalability: End-to-end conversions of massive models are now practical. deepseek3-671b previously OOM'd on 3.7 TB RAM; it now completes with a peak of 2854.90 GB and a total conversion time of ~9.5 hours.
  5. Reduced Storage: Checkpoint sizes are smaller (e.g., gpt-oss-120b dropped from 100.17 GiB to 74.23 GiB).
  6. Increased Flexibility: Method 2 (safetensors) allows us to
  • convert models even if the HuggingFace modeling code isn't fully available yet (e.g., deepseek-ai/DeepSeek-V3.2 is still in PR as of 2026-03), and
  • convert weights omitted by the Transformers class (e.g., the Multi-Token-Prediction weights at layers.61 are not loaded by deepseek-ai/DeepSeek-V3).
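
As a sanity check on the 8y/4y rule of thumb in item 1, plugging in y = 120 for gpt-oss-120b:

8 × 120 = 960 GB ≈ 1009.72 GB measured (float32 peak)
4 × 120 = 480 GB ≈ 511.47 GB measured (bfloat16 peak)

which is consistent with the peak holding roughly two full copies of the weights at 4 (float32) or 2 (bfloat16) bytes per parameter.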

Other changes

  • to_maxtext: Reuse HF_MODEL_CONFIGS rather than transformers.AutoConfig. This accommodates models without full HuggingFace code support (e.g., deepseek3.2) and aligns with how to_huggingface uses the config.
  • to_huggingface: Initially, maxtext weights were loaded via set_decode_state, which uses config.weight_dtype. This was subsequently changed to orbax restore, which loads the weights as-is. To control the save dtype, we now explicitly cast to config.weight_type in utils._process (see the sketch after the flag below).
# to_huggingface flag
weight_type=<bfloat16 | float32 (default) | float16>
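
A minimal sketch of the explicit cast (the helper name is illustrative; the PR performs this inside utils._process):

import jax
import jax.numpy as jnp

def cast_weights(params, dtype=jnp.bfloat16):
  # Cast every leaf array of the orbax-restored pytree to the target dtype.
  return jax.tree_util.tree_map(lambda x: x.astype(dtype), params)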

Tests

Test details in doc.

1. Performance (gpt-oss-120b)

  • Converted hf-bf16 to maxtext scanned.
  • Compared the previous method against Method 2 + bfloat16 save.
  • Result: Verified 2x memory reduction, significant total conversion time reduction (134.86 min -> 96.22 min), a smaller checkpoint size, and confirmed the logits remain very close to the baseline.

2. Functionality (qwen3-0.6b)

  • Converted hf-bf16 to maxtext scanned.
  • Tested the matrix of {lazy, method1, method2} x {bfloat16, float32}.
  • Result: All load/save modes are functional and produce correct outputs, verified via logit checks.

3. Scalability (deepseek3-671b)

  • Tested {to_maxtext} x {scanned}.
  • Result: Successfully scaled to_maxtext for this model class. (Previously, only to_huggingface was feasible due to OOM constraints).

4. New DeepSeek Mappings

  • deepseek2-16b: {to_maxtext, to_huggingface} x {scanned, unscanned}
  • deepseek3-671b: {to_maxtext} x {unscanned}
  • deepseek3.2: {to_maxtext} x {scanned, unscanned}. Note to_huggingface is not enabled, as DeepSeek32ForCausalLM is not supported yet; follow-up tracked in b/496411531.
  • Result: Verified new mappings are fully functional and correct via logit checks.
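
The logit checks referenced throughout follow this general pattern; a minimal sketch (the array names and tolerance are illustrative, not the repo's forward_pass_logit_checker):

import numpy as np

def check_logits(hf_logits, maxtext_logits, atol=1e-2):
  # Compare logits from the source HF model and the converted MaxText model
  # on the same prompt; both inputs are same-shape float arrays.
  hf = np.asarray(hf_logits, dtype=np.float32)
  mt = np.asarray(maxtext_logits, dtype=np.float32)
  max_abs_diff = float(np.max(np.abs(hf - mt)))
  cos = float(np.dot(hf.ravel(), mt.ravel()) / (np.linalg.norm(hf) * np.linalg.norm(mt)))
  print(f"max abs diff = {max_abs_diff:.4g}, cosine similarity = {cos:.6f}")
  assert max_abs_diff < atol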

Examples:

# deepseek2-16b
# to_maxtext, eager load with transformers (default), save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek2-16b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=transformers --save_dtype=bfloat16

# deepseek2-16b
# to_maxtext, eager load with safetensors, save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek2-16b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=safetensors --save_dtype=bfloat16

# deepseek3.2-671b
# to_maxtext, eager load with safetensors, save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek3.2-671b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=safetensors --save_dtype=bfloat16 \
--hf_model_path=$CUSTOM_PATH_HF_BF16 # original fp8 checkpoint dequantized to bf16

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov (Bot) commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 12.19512% with 36 lines in your changes missing coverage. Please review.

Files with missing lines                                 Patch %   Lines
...xtext/checkpoint_conversion/utils/param_mapping.py      0.00%   33 Missing ⚠️
...xt/checkpoint_conversion/utils/hf_model_configs.py     71.42%   1 Missing and 1 partial ⚠️
src/maxtext/utils/muon_utils.py                            0.00%   1 Missing ⚠️


@shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 0798438 to 5702326 on March 24, 2026 21:36
@shuningjin changed the title from "Checkpoint conversion tool: Optimize to_maxtext & Onboard deepseek2/3/3.2" to "checkpoint utility: optimize to_maxtext, add deepseek" on Mar 24, 2026
@shuningjin marked this pull request as ready for review on March 24, 2026 21:53
@AI-Hypercomputer deleted a comment from the github-actions Bot on Mar 25, 2026
@github-actions (Bot) commented

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions (Bot) left a comment

## 📋 Review Summary

This Pull Request introduces significant improvements and optimizations to the checkpoint conversion process in MaxText, specifically focusing on DeepSeek model support (V2-16B, V3-671B, and V3.2-671B). The implementation of LazyHFLoader and the adoption of dtype="auto" in Hugging Face loading are excellent additions that substantially reduce memory overhead, making the conversion of extremely large models more feasible.

🔍 General Feedback

  • Efficiency: The shift towards memory-efficient loading strategies is a major highlight. Using safetensors on-demand avoids redundant memory consumption during the to_maxtext conversion.
  • Support: Comprehensive support for DeepSeek's MLA architecture and MoE experts is well-integrated into both hf_shape.py and param_mapping.py.
  • Maintainability: The refactoring of forward_pass_logit_checker.py and the grouping of reshape hooks in param_mapping.py significantly improve code clarity and ease of future extension.

Comment thread src/maxtext/checkpoint_conversion/utils/utils.py
Comment thread src/maxtext/checkpoint_conversion/utils/param_mapping.py

@RissyRan (Collaborator) left a comment

LGTM at high level! A few minor comments.

Comment thread docs/guides/checkpointing_solutions/convert_checkpoint.md
Comment thread src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py Outdated
Comment thread src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Comment thread src/maxtext/checkpoint_conversion/utils/hf_model_configs.py Outdated
Comment thread src/maxtext/checkpoint_conversion/utils/hf_shape.py
Comment thread src/maxtext/checkpoint_conversion/to_maxtext.py Outdated
Comment thread src/maxtext/checkpoint_conversion/to_maxtext.py Outdated
Comment thread src/maxtext/checkpoint_conversion/to_maxtext.py

@hengtaoguo (Collaborator) left a comment

LGTM!

Comment thread src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
@shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 91acedc to efdfc8e on March 26, 2026 00:21
@AI-Hypercomputer deleted a comment from the github-actions Bot on Mar 26, 2026
@AI-Hypercomputer deleted a comment from the github-actions Bot on Mar 26, 2026
@shuningjin force-pushed the shuningjin-ckpt-opt3 branch 3 times, most recently from 190dee7 to 22da4de on March 26, 2026 15:48

@RissyRan (Collaborator) left a comment

Thanks for the change! Just a minor comment to make logging more useful.

@shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 150f104 to c93fde2 on March 27, 2026 18:05
@shuningjin force-pushed the shuningjin-ckpt-opt3 branch from c93fde2 to b1a5feb on March 27, 2026 18:48
@copybara-service (Bot) merged commit cd7a1eb into main on Mar 27, 2026
45 of 48 checks passed
@copybara-service (Bot) deleted the shuningjin-ckpt-opt3 branch on March 27, 2026 21:14