Skip to content

fix torchao memory check#12539

Closed
jiqing-feng wants to merge 4 commits intohuggingface:mainfrom
jiqing-feng:mem
Closed

fix torchao memory check#12539
jiqing-feng wants to merge 4 commits intohuggingface:mainfrom
jiqing-feng:mem

Conversation

@jiqing-feng
Copy link
Contributor

@jiqing-feng jiqing-feng commented Oct 24, 2025

The test pytest -rA tests/quantization/torchao/test_torchao.py::TorchAoTest::test_model_memory_usage failed with

>       assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
E       assert (1416704 / 1382912) >= 2.0
tests/quantization/torchao/test_torchao.py:512: AssertionError                                                                       

on A100. I guess it is because the model is too small, most memories are consumed on cuda kernel launch instead of model weight. If we change it to a large model like black-forest-labs/FLUX.1-dev, the ratio will be 24244073472 / 12473665536 = 1.9436206143278139

@sayakpaul . Please review this PR. Thanks!

@jiqing-feng
Copy link
Contributor Author

For the black-forest-labs/FLUX.1-dev script, you can run

import torch
import os
from diffusers import FluxTransformer2DModel
from transformers import TorchAoConfig

torch.use_deterministic_algorithms(True)

def get_dummy_tensor_inputs(device=None, seed: int = 0):
    batch_size = 1
    num_latent_channels = 4096
    num_image_channels = 3
    height = width = 8
    sequence_length = 48
    embedding_dim = 768
    torch.manual_seed(seed)
    hidden_states = torch.randn((batch_size, num_latent_channels, height*width)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    encoder_hidden_states = torch.randn((batch_size, sequence_length, num_latent_channels)).to(
        device, dtype=torch.bfloat16
    )
    torch.manual_seed(seed)
    pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    text_ids = torch.randn((num_latent_channels, num_image_channels)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    image_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
    timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "pooled_projections": pooled_prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": image_ids,
        "timestep": timestep,
        "guidance": timestep * 3054,
    }

@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
    device_module.reset_peak_memory_stats()
    device_module.empty_cache()
    model(**inputs)
    max_mem_allocated = device_module.max_memory_allocated()
    return max_mem_allocated

torch_device = "xpu" if torch.xpu.is_available() else "cuda"
if torch_device == "cuda":
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
device_module = torch.xpu if torch.xpu.is_available() else torch.cuda
model_id = "black-forest-labs/FLUX.1-dev"
print(f"max allocated memory before loading: {device_module.max_memory_allocated()}")
inputs = get_dummy_tensor_inputs(device=torch_device)
print(f"max allocated memory after get inputs: {device_module.max_memory_allocated()}")
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", quantization_config=None, torch_dtype=torch.bfloat16).to(torch_device)
print(f"max allocated memory after get bf16 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after bf16 model inference: {device_module.max_memory_allocated()}")

del transformer
device_module.reset_peak_memory_stats()
device_module.empty_cache()

transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", quantization_config=TorchAoConfig("int8_weight_only"), torch_dtype=torch.bfloat16).to(torch_device)
print(f"max allocated memory after get int8 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after int8 model inference: {device_module.max_memory_allocated()}")

@sayakpaul
Copy link
Member

I just ran the test on an H100 and it worked fine.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Oct 24, 2025

I just ran the test on an H100 and it worked fine.

It seems like a device-related issue. Can we change it to a big model so other devices can also work? Or change the ratio as in this PR. We want to make the test pass on XPU and A100

@jiqing-feng
Copy link
Contributor Author

I just ran the test on an H100 and it worked fine.

It seems like a device-related issue. Can we change it to a big model so other devices can also work? Or change the ratio as in this PR. We want to make the test pass on XPU and A100

Hi @sayakpaul , could you share the ratio on H100?

@sywangyi
Copy link
Contributor

sywangyi commented Dec 1, 2025

@sayakpaul @DN6 I'd like to jump into this case again, since we need to extend this case in xpu as well. I test this case in A100 as well. and I find it fail as well. I track the detail memory in A100

memory(B) bf16 int8wo
pure model 157184 123393
cublas workspace size 1179648 131072
pure forward 79872 79872

and since in the case, bf16 is running ahead of int8wo and cublas workspace is not delete explicitly by torch._C._cuda_clearCublasWorkspaces, which will lead the case fail since the memory like bf16 1416704(157184+1179648+79872) vs int8wo 1382913(123393+1179648+79872). 1416704/1382913<2.0, and I do not have H100 env, I guess it could pass in H100 mainly because cublas implementation is different?

and in xpu. there's no cublas workspace.
the memory is like

memory(B) bf16 int8wo
pure model 158208 124416
pure forward 88064 88064

which is also reasonable. it's weight only quantization, memory increase in the pure forward should be same between bf16 and int8wo.

so in order to extend the case support in different hardware, could we adjust ratio, only measure the model only memory instead of model memory+runtime memory?

@sayakpaul
Copy link
Member

We could adjust the ratio, depending on the hardware type.

@sywangyi
Copy link
Contributor

#12768, please help review, thanks

@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jan 10, 2026
@yiyixuxu
Copy link
Collaborator

can we close this one now #12768 is merged?

@github-actions github-actions bot removed the stale Issues that haven't received updates label Jan 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants