Merged
Commits
32 commits
All 32 commits are by kevalmorabia97.

2401210  Bump transformers to 5.0 (Mar 4, 2026)
b269824  Fix Bert Gradnas tracing for transformers 5.0 (Mar 6, 2026)
66ec553  Add more fixes (Mar 24, 2026)
bdaa515  Fix Bert and DBRX unit tests (Mar 24, 2026)
c72454c  Fix transformers load and test_llm_qat (Mar 24, 2026)
46348a0  Remove tokenizer.batch_encode_plus (Mar 24, 2026)
1d9155b  Remove deprecated transformers arguments (Mar 24, 2026)
ee51fd7  Rename torch_dtype to dtype (Mar 24, 2026)
aa6c3ce  Remove hard-coded trust_remote_code=True (Mar 24, 2026)
7343c4f  Fix unit tests (Mar 25, 2026)
31efc36  Enable some quantizer manual tests (Mar 25, 2026)
f69d9fa  fix test (Mar 25, 2026)
2dc3140  Set min transformers 5.0 (Mar 25, 2026)
b37545b  Fix more tests (Mar 25, 2026)
1024528  Fix for TRT-LLM (Mar 25, 2026)
1e45639  Let PTQ example tests run with transformers<5.0 (Mar 25, 2026)
38e26e3  fix tests (Mar 26, 2026)
26cf04a  minor fixes (Mar 30, 2026)
6d3af7c  Remove transformers 5.0 compatibility patch for trtllm; disable MOE c… (Mar 31, 2026)
f707ce5  fix for cppimport container test (Mar 31, 2026)
d5b61cb  Fix spec dec example tests (Mar 31, 2026)
22e9d4d  minor (Apr 1, 2026)
fdeb1ab  Add back windows accuracy_benchmark dependencies + trust_remote_code fix (Apr 1, 2026)
6061218  revert onnx extension file back to logger (Apr 1, 2026)
c74d5ec  Pin transformers<5.4 in spec dec example (Apr 1, 2026)
48bb138  Merge branch 'main' into kmorabi/bump-transformers-5.0 (Apr 7, 2026)
c3f1e87  Fix pyproject.toml version (Apr 7, 2026)
1ede794  Fix HFEagleModel for transformers 5.5 (Apr 7, 2026)
54dd176  Sparse Sequential MoE fixes (Apr 7, 2026)
7dad663  Merge remote-tracking branch 'origin/main' into kmorabi/bump-transfor… (Apr 7, 2026)
ddde158  Merge branch 'main' into kmorabi/bump-transformers-5.0 (Apr 7, 2026)
644e0e0  Merge branch 'main' into kmorabi/bump-transformers-5.0 (Apr 8, 2026)
8 changes: 4 additions & 4 deletions .github/workflows/example_tests.yml
@@ -70,7 +70,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
-docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3"
+docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.03' }}-py3"
example: ${{ matrix.example }}
timeout_minutes: 30
pip_install_extras: "[hf,dev-test]"
@@ -82,7 +82,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
-docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3"
+docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.03' }}-py3"
example: ${{ matrix.example }}
timeout_minutes: 30
pip_install_extras: "[hf,dev-test]"
@@ -99,7 +99,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
-docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5"
+docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10"
example: ${{ matrix.example }}
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-rtxpro6000-latest-1
@@ -113,7 +113,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
-docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5"
+docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10"
example: ${{ matrix.example }}
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-rtxpro6000-latest-2
5 changes: 3 additions & 2 deletions .github/workflows/gpu_tests.yml
@@ -65,18 +65,19 @@ jobs:
- example: gpu
timeout: 45
container_image: pytorch:26.01-py3
+# tests/gpu/_extensions/test_onnx_extensions.py fails for newer containers until https://github.com/tbenthompson/cppimport/pull/98
- example: gpu-megatron
timeout: 45
container_image: pytorch:26.01-py3
- example: gpu-trtllm
timeout: 30
-container_image: tensorrt-llm/release:1.3.0rc5
+container_image: tensorrt-llm/release:1.3.0rc10
runs-on: linux-amd64-gpu-rtxpro6000-latest-1
timeout-minutes: ${{ matrix.timeout }}
container: &gpu_container
image: nvcr.io/nvidia/${{ matrix.container_image }}
env:
-GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py
+GIT_DEPTH: 1000 # For correct version
PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages
HF_TOKEN: ${{ secrets.HF_TOKEN }}
steps: &gpu_steps
12 changes: 8 additions & 4 deletions .github/workflows/unit_tests.yml
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v6
- uses: ./.github/actions/ubuntu-setup
- name: Run unit tests
-run: pip install tox && COV_ARGS="--cov" tox -e py312-torch210-tf_latest-unit
+run: pip install tox && COV_ARGS="--cov" tox -e py312-torch211-tf_latest-unit
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
@@ -65,6 +65,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
+fail-fast: false
matrix:
py: [10, 11, 13]
steps:
@@ -73,15 +74,16 @@
with:
python-version: "3.${{ matrix.py }}"
- name: Run unit tests
-run: pip install tox && tox -e py3${{ matrix.py }}-torch210-tf_latest-unit
+run: pip install tox && tox -e py3${{ matrix.py }}-torch211-tf_latest-unit
multi-torch:
if: github.event_name == 'pull_request'
needs: [linux]
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
+fail-fast: false
matrix:
-torch: [26, 27, 28, 29]
+torch: [28, 29, 210]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/ubuntu-setup
@@ -93,13 +95,14 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
+fail-fast: false
matrix:
tf: [min]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/ubuntu-setup
- name: Run unit tests
-run: pip install tox && tox -e py312-torch210-tf_${{ matrix.tf }}-unit
+run: pip install tox && tox -e py312-torch211-tf_${{ matrix.tf }}-unit
launcher:
if: github.event_name == 'pull_request'
needs: [linux]
@@ -123,6 +126,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
+fail-fast: false
matrix:
test-env: [onnx, torch]
steps:
7 changes: 5 additions & 2 deletions CHANGELOG.rst
@@ -1,5 +1,6 @@
-NVIDIA Model Optimizer Changelog
-================================
+Changelog
+=========

0.44 (2026-05-xx)
^^^^^^^^^^^^^^^^^

@@ -25,6 +26,8 @@ NVIDIA Model Optimizer Changelog
**Misc**

- [Security] Changed the default of ``weights_only`` to ``True`` in ``torch.load`` for secure checkpoint loading. If you need to load a checkpoint that requires unpickling arbitrary objects, first register the class in ``torch.serialization.add_safe_globals([cls])`` before loading. Added :meth:`safe_save <modelopt.torch.utils.serialization.safe_save>` and :meth:`safe_load <modelopt.torch.utils.serialization.safe_load>` API to save and load checkpoints securely.
+- Bump minimum required PyTorch version to 2.8.
+- [Experimental] Add support for transformers>=5.0. Unified Hugging Face checkpoint export for quantized checkpoints may not work for MoE models with transformers>=5.0 yet.

0.43 (2026-04-09)
^^^^^^^^^^^^^^^^^
2 changes: 1 addition & 1 deletion docs/source/getting_started/_installation_for_Linux.rst
@@ -16,7 +16,7 @@ Latest Model Optimizer (``nvidia-modelopt``) currently has the following system
+-------------------------+-----------------------------+
| CUDA | 12.x, 13.x |
+-------------------------+-----------------------------+
-| PyTorch | >=2.6 |
+| PyTorch | >=2.8 |
+-------------------------+-----------------------------+
| TensorRT-LLM (Optional) | >=1.0 |
+-------------------------+-----------------------------+
6 changes: 3 additions & 3 deletions examples/gpt-oss/configs/sft_full.yaml
@@ -16,7 +16,7 @@ per_device_train_batch_size: 2
per_device_eval_batch_size: 2
gradient_accumulation_steps: 2
max_length: 4096
-warmup_ratio: 0.03
+warmup_steps: 0.03 # use warmup_ratio if using transformers<5.0
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
@@ -30,6 +30,6 @@ eval_steps: 8
dataset_test_split: test

# ModelOpt Quantization Parameters
quant_cfg: # Examples: MXFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_ONLY_CFG
# For the full list of supported configs, do: mtq.config.choices
quant_cfg: # Examples: MXFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_ONLY_CFG
# For the full list of supported configs, do: mtq.config.choices
calib_size: 128
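The `warmup_steps: 0.03` value in the config above relies on newer transformers accepting a fractional `warmup_steps` as a ratio of total training steps, replacing the removed `warmup_ratio`. A rough sketch of that interpretation, assuming the ratio semantics for floats below 1 (the `resolve_warmup_steps` helper is illustrative, not a transformers API):

```python
import math


def resolve_warmup_steps(warmup_steps: float, max_steps: int) -> int:
    """Interpret warmup_steps the way transformers>=5.0 is assumed to:
    a float in [0, 1) is a ratio of max_steps (the old warmup_ratio
    semantics); anything else is an absolute step count."""
    if 0.0 <= warmup_steps < 1.0:
        # ceil matches the old warmup_ratio rounding in transformers
        return math.ceil(max_steps * warmup_steps)
    return int(warmup_steps)


print(resolve_warmup_steps(0.03, 1000))  # 30 warmup steps for a 1000-step run
```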
6 changes: 3 additions & 3 deletions examples/gpt-oss/configs/sft_lora.yaml
@@ -21,7 +21,7 @@ lora_alpha: 16
lora_dropout: 0.0
lora_target_modules: all-linear
max_length: 4096
-warmup_ratio: 0.03
+warmup_steps: 0.03 # use warmup_ratio if using transformers<5.0
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
@@ -35,6 +35,6 @@ eval_steps: 8
dataset_test_split: test

# ModelOpt Quantization Parameters
quant_cfg: # Examples: MXFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_ONLY_CFG
# For the full list of supported configs, do: mtq.config.choices
quant_cfg: # Examples: MXFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_WEIGHT_ONLY_CFG, NVFP4_MLP_ONLY_CFG
# For the full list of supported configs, do: mtq.config.choices
calib_size: 128
6 changes: 1 addition & 5 deletions examples/gpt-oss/convert_oai_mxfp4_weight_only.py
@@ -122,11 +122,7 @@ def create_parser():
parser = create_parser()
args = parser.parse_args()

-kwargs = {
-    "device_map": "auto",
-    "torch_dtype": "auto",
-    "trust_remote_code": args.trust_remote_code,
-}
+kwargs = {"device_map": "auto", "dtype": "auto", "trust_remote_code": args.trust_remote_code}
if args.lora_path:
assert args.model_path is None, "You can only specify lora_path or model_path, not both."
model_path = args.base_path
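The `torch_dtype` → `dtype` rename in this hunk is tied to the transformers 5.0 boundary. Code that must run on both sides of the boundary can select the keyword at runtime; a minimal sketch, where `dtype_kwarg` is a hypothetical helper (not part of this PR):

```python
def dtype_kwarg(transformers_version: str) -> str:
    """Pick the from_pretrained keyword for the model dtype:
    'torch_dtype' for transformers<5.0, 'dtype' for >=5.0."""
    major = int(transformers_version.split(".", 1)[0])
    return "dtype" if major >= 5 else "torch_dtype"


# In real code the version would come from transformers.__version__.
version = "5.0.0"
kwargs = {"device_map": "auto", dtype_kwarg(version): "auto"}
print(kwargs)  # {'device_map': 'auto', 'dtype': 'auto'}
```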
2 changes: 1 addition & 1 deletion examples/gpt-oss/qat-finetune-transformers.ipynb
@@ -207,7 +207,7 @@
" per_device_eval_batch_size=1,\n",
" gradient_accumulation_steps=2,\n",
" max_length=4096,\n",
-"    warmup_ratio=0.03,\n",
+"    warmup_steps=0.03,  # use warmup_ratio if using transformers<5.0\n",
" eval_strategy=\"steps\",\n",
" eval_on_start=True,\n",
" logging_steps=10,\n",
2 changes: 0 additions & 2 deletions examples/gpt-oss/requirements.txt
@@ -1,5 +1,3 @@
kernels>=0.9.0
-torch>2.7.1
trackio
-transformers>=4.55.0
trl>=0.21.0
2 changes: 1 addition & 1 deletion examples/gpt-oss/sft.py
@@ -72,7 +72,7 @@ def main(script_args, training_args, model_args, quant_args):
"revision": model_args.model_revision,
"trust_remote_code": model_args.trust_remote_code,
"attn_implementation": model_args.attn_implementation,
-"torch_dtype": getattr(model_args, "dtype", "bfloat16"),
+"dtype": getattr(model_args, "dtype", "bfloat16"),
"use_cache": not training_args.gradient_checkpointing,
}

2 changes: 1 addition & 1 deletion examples/llm_autodeploy/run_auto_quantize.py
@@ -132,7 +132,7 @@ def modelopt_ptq(
) -> torch.nn.Module:
"""Quantize the model with modelopt."""
model = AutoModelForCausalLM.from_pretrained(
-model_path, trust_remote_code=trust_remote_code, torch_dtype="auto", device_map="auto"
+model_path, trust_remote_code=trust_remote_code, dtype="auto", device_map="auto"
)
model.eval()

1 change: 0 additions & 1 deletion examples/llm_distill/requirements.txt
@@ -1,4 +1,3 @@
pyarrow
torchao>=0.14.1
-transformers<5.0
trl>=0.23.0
6 changes: 3 additions & 3 deletions examples/llm_eval/modeling.py
@@ -187,7 +187,7 @@ def load(self):
args.update(device_map="auto")
if self.load_8bit:
args.update(device_map="auto", load_in_8bit=True)
-args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+args.update(dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
if self.attn_implementation:
args["attn_implementation"] = self.attn_implementation
self.model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -246,7 +246,7 @@ def load(self):
args.update(device_map="auto")
if self.load_8bit:
args.update(device_map="auto", load_in_8bit=True)
-args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+args.update(dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
if self.attn_implementation:
args["attn_implementation"] = self.attn_implementation
self.model = AutoModelForCausalLM.from_pretrained(
@@ -327,7 +327,7 @@ def load(self):
args.update(device_map="auto")
if self.load_8bit:
args.update(device_map="auto", load_in_8bit=True)
-args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+args.update(dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
self.model = LlamaForCausalLM.from_pretrained(self.model_path, **args)
print_gpu_utilization()
if self.lora_path:
5 changes: 3 additions & 2 deletions examples/llm_ptq/README.md
@@ -115,7 +115,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| Kimi K2 | - | - | - | - | ✅ |
| MiniMax M2.1 | - | - | - | - | ✅ |
| T5 | ✅ | ✅ | ✅ | ✅ | - |
-| Whisper | ✅ | ❌ | ❌ | ❌ | - |
+| Whisper<sup>9</sup> | ✅ | ❌ | ❌ | ❌ | - |
| Nemotron-3 | ✅ | ❌ | ❌ | ❌ | ✅ |

> *This is a subset of the models supported. For the full list please check the [TensorRT-LLM support matrix](https://nvidia.github.io/TensorRT-LLM/reference/precision.html#support-matrix)*
@@ -127,7 +127,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
> *<sup>5.</sup>A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later* \
> *<sup>6.</sup>Some models currently support export to HF format only.* \
> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
-> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*
+> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* \
+> *<sup>9.</sup>Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).*

> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.*

6 changes: 3 additions & 3 deletions examples/llm_ptq/example_utils.py
@@ -583,7 +583,7 @@ def get_model(
model_kwargs = config_kwargs.copy()
# Don't set torch_dtype for VILA models as they handle it explicitly in their builder
if "vila" not in ckpt_path.lower():
-model_kwargs.setdefault("torch_dtype", "auto")
+model_kwargs.setdefault("dtype", "auto")

if "vila" in ckpt_path.lower():
hf_vila = AutoModel.from_pretrained(
@@ -634,7 +634,7 @@ def has_pack_quantized_config(config):
ckpt_path,
device_map="auto",
trust_remote_code=trust_remote_code,
-torch_dtype="auto",
+dtype="auto",
)
else:
architecture = hf_config.architectures[0]
@@ -666,7 +666,7 @@ def has_pack_quantized_config(config):
model_kwargs2 = model_kwargs.copy()
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
-model_kwargs2["torch_dtype"] = torch_dtype
+model_kwargs2["dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
model = from_config(hf_config, **model_kwargs2)

4 changes: 1 addition & 3 deletions examples/llm_ptq/multinode_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ def load_and_prepare_model(
Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader)
"""
model = AutoModelForCausalLM.from_pretrained(
-model_path,
-torch_dtype="auto",
-trust_remote_code=trust_remote_code,
+model_path, dtype="auto", trust_remote_code=trust_remote_code
)
model.eval()
model_type = get_model_type(model)
1 change: 0 additions & 1 deletion examples/llm_ptq/requirements-t5.txt

This file was deleted.

2 changes: 0 additions & 2 deletions examples/llm_ptq/requirements-whisper.txt

This file was deleted.

1 change: 1 addition & 0 deletions examples/llm_ptq/requirements.txt
@@ -2,5 +2,6 @@ compressed-tensors==0.12.0
fire
flash-attn>=2.6.0
rouge_score>=0.1.2
+transformers<5.0
transformers_stream_generator
zstandard
2 changes: 1 addition & 1 deletion examples/llm_qat/launch.sh
@@ -165,7 +165,7 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
--save_total_limit 2 \
--learning_rate $LR \
--weight_decay 0.0 \
---warmup_ratio 0.1 \
+--warmup_steps 0.1 \
--lr_scheduler_type linear \
--logging_steps 1 \
--report_to tensorboard \
6 changes: 2 additions & 4 deletions examples/llm_qat/main.py
@@ -174,9 +174,7 @@ def train():
print_rank_0(f"Last checkpoint detected: {last_checkpoint}")

model = transformers.AutoModelForCausalLM.from_pretrained(
-model_args.model_name_or_path,
-cache_dir=training_args.cache_dir,
-torch_dtype=torch.bfloat16,
+model_args.model_name_or_path, cache_dir=training_args.cache_dir, dtype=torch.bfloat16
)
model.generation_config.do_sample = True
tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -231,7 +229,7 @@ def train():
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_model,
cache_dir=training_args.cache_dir,
-torch_dtype=torch.bfloat16,
+dtype=torch.bfloat16,
)
distill_config = {
"teacher_model": teacher_model,
4 changes: 2 additions & 2 deletions examples/llm_qat/notebooks/QAT_QAD_Walkthrough.ipynb
@@ -275,7 +275,7 @@
},
{
"cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
"id": "0bf60614-99a0-48b0-85a8-1d88cd7c72ba",
"metadata": {},
"outputs": [],
@@ -290,7 +290,7 @@
" per_device_eval_batch_size=1,\n",
" gradient_accumulation_steps=2,\n",
" max_length=4096,\n",
-"    warmup_ratio=0.03,\n",
+"    warmup_steps=0.03,  # use warmup_ratio if using transformers<5.0\n",
" eval_strategy=\"steps\",\n",
" eval_on_start=True,\n",
" logging_steps=50,\n",
5 changes: 1 addition & 4 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
@@ -143,10 +143,7 @@ def main(args):
# No need to specify attn_implementation here — mtsa.sparsify() sets it
# automatically ("eager" for pytorch backend, "modelopt_triton" for triton).
model = AutoModelForCausalLM.from_pretrained(
-args.pyt_ckpt_path,
-attn_implementation="eager",
-torch_dtype="auto",
-device_map="auto",
+args.pyt_ckpt_path, attn_implementation="eager", dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
