24 changes: 14 additions & 10 deletions QEfficient/base/modeling_qeff.py
@@ -136,18 +136,22 @@ def _offload_model_weights(self, offload_pt_weights: bool) -> bool:
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
if offload_pt_weights and not self._is_weights_offloaded:
try:
for param in self.model.parameters():
if param.storage():
param.storage().resize_(0)
for buffer in self.model.buffers():
if buffer.storage():
buffer.storage().resize_(0)

meta_model = self.model.to("meta")
del self.model
# Clear plain tensor attributes that are not registered as parameters
# or buffers (e.g. stacked expert weights in MoE models). These are not
# handled by to_empty().
param_data_ptrs = {p.data_ptr() for p in self.model.parameters()}
buf_data_ptrs = {b.data_ptr() for b in self.model.buffers()}
registered_ptrs = param_data_ptrs | buf_data_ptrs
for module in self.model.modules():
for attr_name in list(vars(module).keys()):
attr = getattr(module, attr_name, None)
if isinstance(attr, torch.Tensor) and attr.data_ptr() not in registered_ptrs:
setattr(module, attr_name, torch.empty(0, device="meta"))

# Move all parameters and buffers to meta device with empty storage.
self.model.to_empty(device="meta")
gc.collect()

self.model = meta_model
self._is_weights_offloaded = True
return True
except Exception as e:
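For context on the change above: `nn.Module.to_empty()` only migrates registered parameters and buffers, so a tensor stored as a plain attribute keeps its real storage alive, which is the MoE stacked-weight case the new sweep handles. Below is a minimal standalone sketch of that gap and of the data-pointer bookkeeping; the `StackedExperts` module and `stacked_weight` attribute are illustrative names, not from this PR.

import torch
import torch.nn as nn


class StackedExperts(nn.Module):
    """Toy module mimicking MoE code that keeps a plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)                 # registered parameter
        self.stacked_weight = torch.randn(2, 4, 4)  # plain attribute, NOT registered


m = StackedExperts()

# Mirror the PR's ordering: snapshot registered data pointers while tensors
# still have real storage, sweep unregistered tensor attributes, then to_empty().
registered = {p.data_ptr() for p in m.parameters()} | {b.data_ptr() for b in m.buffers()}
for mod in m.modules():
    for name in list(vars(mod).keys()):
        attr = getattr(mod, name, None)
        if isinstance(attr, torch.Tensor) and attr.data_ptr() not in registered:
            setattr(mod, name, torch.empty(0, device="meta"))
m.to_empty(device="meta")

assert m.proj.weight.is_meta     # covered by to_empty()
assert m.stacked_weight.is_meta  # covered only by the explicit sweep

Note the ordering matters: the data pointers must be snapshotted before `to_empty()`, since meta tensors no longer expose real storage.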
31 changes: 22 additions & 9 deletions tests/unit_test/base/test_modeling_qeff_base.py
@@ -12,6 +12,7 @@
"""

import pytest
import torch
from transformers import GPT2Config, GPT2LMHeadModel, LlamaConfig, LlamaForCausalLM

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
@@ -163,20 +164,32 @@ def test_model_offloaded_check_passes_when_not_offloaded(self):
         qeff._model_offloaded_check()
 
     def test_offload_clears_parameter_storage(self):
-        """_offload_model_weights clears parameter storage."""
+        """_offload_model_weights moves all parameters and buffers to meta device."""
         model, cfg = make_tiny_gpt2()
         qeff = QEFFAutoModelForCausalLM(model)
-        # Check that parameters have storage before offloading
-        has_storage_before = any(p.storage() and p.storage().size() > 0 for p in qeff.model.parameters())
-        assert has_storage_before
+        # Check that parameters are NOT on meta before offloading
+        assert not any(p.is_meta for p in qeff.model.parameters())
 
         qeff._offload_model_weights(offload_pt_weights=True)
 
-        # After offloading, parameters should have no storage or be on meta device
-        has_storage_after = any(
-            p.storage() and p.storage().size() > 0 for p in qeff.model.parameters() if not p.is_meta
-        )
-        assert not has_storage_after
+        # After offloading, ALL parameters and buffers must be on meta device
+        assert all(p.is_meta for p in qeff.model.parameters())
+        assert all(b.is_meta for b in qeff.model.buffers())
+
+    def test_offload_clears_plain_tensor_attributes(self):
+        """_offload_model_weights clears plain tensor attributes (not params/buffers)."""
+        model, cfg = make_tiny_gpt2()
+        qeff = QEFFAutoModelForCausalLM(model)
+
+        # Attach a plain tensor attribute to a submodule (simulates MoE stacked weights)
+        first_child = next(iter(qeff.model.modules()))
+        first_child.extra_weight = torch.randn(8, 8)
+        assert not first_child.extra_weight.is_meta
+
+        qeff._offload_model_weights(offload_pt_weights=True)
+
+        # The plain tensor attribute should also be on meta device
+        assert first_child.extra_weight.is_meta
 
 
 @pytest.mark.cpu_only
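To exercise just the tests touched by this diff, an invocation along these lines should work; the `-k offload` filter is illustrative and matches both `test_offload_*` methods shown above.

# Programmatic pytest invocation (equivalent to running pytest from the shell).
import pytest

raise SystemExit(pytest.main([
    "tests/unit_test/base/test_modeling_qeff_base.py",
    "-k", "offload",
]))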