From 31ee8a32d04bca2c327920cf255820434fc1f537 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 28 Apr 2026 21:26:00 +0530 Subject: [PATCH] fix: improve weight offloading to handle plain tensor attrs and use to_empty() Signed-off-by: Rishin Raj --- QEfficient/base/modeling_qeff.py | 24 ++++++++------ .../unit_test/base/test_modeling_qeff_base.py | 31 +++++++++++++------ 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 861992b70..cf8c556b0 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -136,18 +136,22 @@ def _offload_model_weights(self, offload_pt_weights: bool) -> bool: """Clear PyTorch model weights to reduce memory usage after ONNX export.""" if offload_pt_weights and not self._is_weights_offloaded: try: - for param in self.model.parameters(): - if param.storage(): - param.storage().resize_(0) - for buffer in self.model.buffers(): - if buffer.storage(): - buffer.storage().resize_(0) - - meta_model = self.model.to("meta") - del self.model + # Clear plain tensor attributes that are not registered as parameters + # or buffers (e.g. stacked expert weights in MoE models). These are not + # handled by to_empty(). + param_data_ptrs = {p.data_ptr() for p in self.model.parameters()} + buf_data_ptrs = {b.data_ptr() for b in self.model.buffers()} + registered_ptrs = param_data_ptrs | buf_data_ptrs + for module in self.model.modules(): + for attr_name in list(vars(module).keys()): + attr = getattr(module, attr_name, None) + if isinstance(attr, torch.Tensor) and attr.data_ptr() not in registered_ptrs: + setattr(module, attr_name, torch.empty(0, device="meta")) + + # Move all parameters and buffers to meta device with empty storage. 
+ self.model.to_empty(device="meta") gc.collect() - self.model = meta_model self._is_weights_offloaded = True return True except Exception as e: diff --git a/tests/unit_test/base/test_modeling_qeff_base.py b/tests/unit_test/base/test_modeling_qeff_base.py index 305b6dca5..1a4432354 100644 --- a/tests/unit_test/base/test_modeling_qeff_base.py +++ b/tests/unit_test/base/test_modeling_qeff_base.py @@ -12,6 +12,7 @@ """ import pytest +import torch from transformers import GPT2Config, GPT2LMHeadModel, LlamaConfig, LlamaForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM @@ -163,20 +164,32 @@ def test_model_offloaded_check_passes_when_not_offloaded(self): qeff._model_offloaded_check() def test_offload_clears_parameter_storage(self): - """_offload_model_weights clears parameter storage.""" + """_offload_model_weights moves all parameters and buffers to meta device.""" model, cfg = make_tiny_gpt2() qeff = QEFFAutoModelForCausalLM(model) - # Check that parameters have storage before offloading - has_storage_before = any(p.storage() and p.storage().size() > 0 for p in qeff.model.parameters()) - assert has_storage_before + # Check that parameters are NOT on meta before offloading + assert not any(p.is_meta for p in qeff.model.parameters()) qeff._offload_model_weights(offload_pt_weights=True) - # After offloading, parameters should have no storage or be on meta device - has_storage_after = any( - p.storage() and p.storage().size() > 0 for p in qeff.model.parameters() if not p.is_meta - ) - assert not has_storage_after + # After offloading, ALL parameters and buffers must be on meta device + assert all(p.is_meta for p in qeff.model.parameters()) + assert all(b.is_meta for b in qeff.model.buffers()) + + def test_offload_clears_plain_tensor_attributes(self): + """_offload_model_weights clears plain tensor attributes (not params/buffers).""" + model, cfg = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + + # Attach a plain 
tensor attribute to the root module (simulates MoE stacked weights)
+        first_child = next(iter(qeff.model.modules()))
+        first_child.extra_weight = torch.randn(8, 8)
+        assert not first_child.extra_weight.is_meta
+
+        qeff._offload_model_weights(offload_pt_weights=True)
+
+        # The plain tensor attribute should also be on meta device
+        assert first_child.extra_weight.is_meta
 
 
 @pytest.mark.cpu_only