This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Load model before assigning submodule to device to save CPU memory #770

Merged
kwen2501 merged 6 commits into pytorch:main from jiqing-feng:low-mem on Apr 13, 2023

Conversation

@jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Mar 30, 2023

@kwen2501 @HamidShojanazeri Hi, related to this issue: #723.

I found a way to reduce CPU memory costs. If I load an empty model in HF_inference.py and load the submodule's weights before assigning the submodule to the device, it will save CPU memory because each rank will only load the submodule's weights instead of the whole model. I have tested my code on opt-13b and flan-t5-xxl, and it works well.
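Roughly, the idea looks like this (a minimal sketch only; the driver wiring is simplified, and the final load_checkpoint call stands for the helper this PR adds in pippy/LoadModule.py):

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("facebook/opt-13b")
    with init_empty_weights():
        # Parameters are created on the "meta" device: no CPU memory is allocated.
        model = AutoModelForCausalLM.from_config(config)

    # ... trace and split `model` into pipeline stages; each rank owns one stage ...
    # Each rank then loads only its own stage's weights from the sharded checkpoint:
    # load_checkpoint(stage_mod, index_filename, device)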

Would you please help me to review it? Thanks!

@jiqing-feng
Contributor Author

@kwen2501 @HamidShojanazeri Hi, it would be nice if you could have a look at this PR. It helps save CPU memory, which is really useful for large language models (like bloom-176b).

@kwen2501
Contributor

kwen2501 commented Apr 4, 2023

Hi @jiqing-feng thanks so much for the PR, and sorry about the delay. (We were busy implementing the HF generate support these past few days.) I am reviewing your PR now.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 9, 2023

Hi, @kwen2501 thanks for your support.

It is nice to have generation support in PiPPy, and I see that you set use_cache=False to make it work.

I have successfully run generation with use_cache in PiPPy. If you want to enable generation with use_cache, this PR of mine may help. We can mock the past_key_values data to avoid inconsistent inputs, since the traced model only accepts fixed inputs.
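The mocking idea, as a rough sketch (the function name, layer count, and shapes here are illustrative, not the exact code in that PR):

    import torch

    def mock_past_key_values(num_layers, batch, num_heads, head_dim, past_len=1):
        # One (key, value) pair per layer with a fixed past length, so the
        # traced model always sees the same input structure whether or not a
        # real cache exists yet.
        kv = torch.zeros(batch, num_heads, past_len, head_dim)
        return tuple((kv.clone(), kv.clone()) for _ in range(num_layers))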

Of course, I can submit another PR to fix it if you want. BTW, enabling use_cache in the generation task will reduce the complexity.

Thanks, I look forward to your response.

Contributor

@kwen2501 kwen2501 left a comment

LGTM!
Thank you SO MUCH for contributing to PiPPy!
I just added some minor comments.

Comment thread pippy/LoadModule.py Outdated
Comment on lines +52 to +53
param_name (`str`):
    The full name of the parameter/buffer.
Contributor

nit: replace param_name with tensor_name so that it is consistent with the API signature.

Comment thread pippy/LoadModule.py Outdated
Comment on lines +26 to +31
if hasattr(model, "lm_head"):
    model.lm_head.weight = torch.nn.Parameter((param.clone()).to(device))
if hasattr(model, "encoder_embed_tokens"):
    model.encoder_embed_tokens.weight = torch.nn.Parameter((param.clone()).to(device))
if hasattr(model, "decoder_embed_tokens"):
    model.decoder_embed_tokens.weight = torch.nn.Parameter((param.clone()).to(device))
Contributor

nit: do you mind putting a comment here on why it is desirable to make a clone of the parameter for this set of parameters?

@kwen2501
Contributor

kwen2501 commented Apr 11, 2023

For other readers, an index file may look like this:
https://huggingface.co/bigscience/bloom/blob/main/pytorch_model.bin.index.json
The PR uses the weight_map in the index file to locate the bin file that stores each weight, and loads the weights into the pipeline modules accordingly.
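For concreteness, the lookup is roughly the following (a sketch with illustrative names; it ignores the parameter-name remapping the PR also handles):

    import json
    import os

    def shards_for_stage(stage_mod, index_filename):
        # weight_map maps each parameter name to the .bin shard that stores it,
        # e.g. {"h.0.self_attention.query_key_value.weight": "pytorch_model_00002-of-00072.bin"}
        with open(index_filename) as f:
            weight_map = json.load(f)["weight_map"]
        ckpt_dir = os.path.dirname(index_filename)
        files = {weight_map[name] for name, _ in stage_mod.named_parameters()}
        return {os.path.join(ckpt_dir, f) for f in files}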

@kwen2501
Contributor

@jiqing-feng
nit: do you mind putting an example run command in the example file or in the README under the same directory?
That would help users get familiar with the important functionality.
Thank you!

@kwen2501
Contributor

Cc @wz337, PyTorch maintainer for distributed checkpointing.

@HamidShojanazeri HamidShojanazeri self-requested a review April 11, 2023 06:14
Contributor

@HamidShojanazeri HamidShojanazeri left a comment

Thanks @jiqing-feng for the contribution, added some inline comments.

Comment thread examples/inference/HF_inference.py Outdated
model = RegNetModel.from_pretrained("facebook/regnet-y-10b-seer")
args.feature_extractor = feature_extractor
if args.index_filename is not None:
    with init_empty_weights():
Contributor

@jiqing-feng thanks for the PR!

I wonder if we would be able to implement init_empty_weights as we implemented here; that way we would not need to add accelerate as a dependency.

We are in fact looking to add this to PT; it may take a bit of time though.
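For reference, a minimal sketch of what init_empty_weights does, with no accelerate dependency (accelerate's real implementation also patches buffer registration and handles more edge cases):

    from contextlib import contextmanager

    import torch
    from torch import nn

    @contextmanager
    def init_empty_weights_sketch():
        old_register = nn.Module.register_parameter

        def register_meta(module, name, param):
            old_register(module, name, param)
            if param is not None:
                # Re-create the parameter on the meta device: no storage allocated.
                module._parameters[name] = nn.Parameter(
                    module._parameters[name].to(torch.device("meta")),
                    requires_grad=param.requires_grad,
                )

        nn.Module.register_parameter = register_meta
        try:
            yield
        finally:
            nn.Module.register_parameter = old_register

    # Usage: parameters of modules built inside the context live on "meta".
    with init_empty_weights_sketch():
        big = nn.Linear(10000, 10000)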

Contributor Author

Thanks for your comments. It would be great if you could add init_empty_weights to PT. For now, I only import accelerate in the Hugging Face example, which should not be a problem since accelerate belongs to Hugging Face and many examples in transformers also use accelerate.

Comment thread pippy/LoadModule.py
import torch
from torch import nn

def load_checkpoint(
Contributor

@HamidShojanazeri HamidShojanazeri Apr 11, 2023


@jiqing-feng I wonder if this is specific to HF models? Wondering whether it could be generalized to cover other distributed checkpoints, like those from FSDP.

Contributor Author

For now, I have only tested it on HF models, but I think it should work for any model that has a weight map like pytorch_model.bin.index.json.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 11, 2023

@jiqing-feng nit: do you mind putting an example run command in the example file or the README under the same directory? That would help users get familiar with the important functionality. Thank you!

Hi @kwen2501, thanks for your comments. I have added the README. The model cannot be loaded from a URL since all weights are saved in the model's .bin files, so we need to download the model; also, this method only supports models with a pytorch_model.bin.index.json.

Comment thread examples/inference/hf_generate.py Outdated
parser.add_argument('--chunks', type=int, default=1)
parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
parser.add_argument('--index_filename', type=str, default=None, help="The directory or url of model's index.json file")
Contributor

@jiqing-feng can you please add an example of a "url of model's index.json file", or we might need a script to download the model checkpoint so we can run this end-to-end.

Contributor Author

Sorry, that was a mistake and I have fixed it. The model cannot be loaded from a URL since all weights are saved in the model's .bin files, so we need to download the model anyway. Thanks for the reminder.

@HamidShojanazeri
Contributor

@jiqing-feng following the readme, I am running into the dtype mismatch here, wondering if I am missing something?

Traceback (most recent call last):
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/PipelineDriver.py", line 485, in worker_loop
    out_val, flat_tensor_args = forward(
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/PipelineDriver.py", line 448, in forward
    out_val = forward_maybe_with_ddp(args, kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/PipelineDriver.py", line 432, in forward_maybe_with_ddp
    out_val = stage_executor.mod(*args, **kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/fx/graph_module.py", line 281, in __call__
    raise e
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+ae87843-py3.8.egg/pippy/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.12", line 344, in forward
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found Half

@jiqing-feng
Contributor Author

@jiqing-feng following the readme, I am running into the dtype mismatch here, wondering if I am missing something?


This is because it does not yet support float16, and the bloom model's tensors are saved as float16. I am fixing it by supporting user-customized data types.
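A sketch of what the user-customized dtype support could look like (the flag values match the --dtype option mentioned below; the wiring itself is illustrative):

    import torch

    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    def cast_loaded_tensor(tensor, dtype_flag):
        # Cast only floating-point tensors; integer buffers keep their dtype.
        dtype = DTYPE_MAP[dtype_flag]
        return tensor.to(dtype) if tensor.is_floating_point() else tensor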

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 12, 2023

@jiqing-feng following the readme, I am running into the dtype mismatch here, wondering if I am missing something?


Hi @HamidShojanazeri, thanks for the reminder. It should work now, and you can also try bfloat16 with --dtype bf16. Could you help approve and merge it if there are no problems? Thanks!

@HamidShojanazeri
Contributor

@jiqing-feng, thanks for the updates; it LGTM. Just a minor point: I found git-cloning the model very slow and occasionally ran into issues complaining about files not found in /tmp, and a download script like this one seems to work much faster.

@HamidShojanazeri
Contributor

@kwen2501 can you please do a final review, as per our offline discussion, so we can move forward and merge the PR?

@HamidShojanazeri
Contributor

@jiqing-feng just one more thing came up for me: trying the model cerebras/Cerebras-GPT-13B, it runs into an issue copying meta tensors. I wonder if you have run into such an issue before.

NotImplementedError('Cannot copy out of meta tensor; no data!')
Traceback (most recent call last):
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/distributed/rpc/internal.py", line 207, in _run_function
    result = python_udf.func(*python_udf.args, **python_udf.kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/distributed/rpc/rref_proxy.py", line 11, in _local_invoke
    return getattr(rref.local_value(), func_name)(*args, **kwargs)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+9075a99-py3.8.egg/pippy/PipelineDriver.py", line 282, in create_stage_executor
    mod=mod or Pipe.materialize_stage(mod_name),  # type: ignore[attr-defined]
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/pippy-0.1.0a0+9075a99-py3.8.egg/pippy/IR.py", line 1060, in materialize_stage
    return submodule.to(device)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1145, in to
    return self._apply(convert)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 820, in _apply
    param_applied = fn(param)
  File "/opt/conda/envs/TS-Py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1143, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

@kwen2501
Contributor

I just pushed a commit to fix the lint issues flagged by the CI.
I will let the CI run again. If things look good, I will merge into main.

@kwen2501 kwen2501 merged commit 500da19 into pytorch:main Apr 13, 2023
@kwen2501
Contributor

Hi @jiqing-feng, thanks so much for this important feature that enables large model loading. I merged this PR into main.

As a follow-up, I would like to check with you whether the following API shape would read as more composable:

    # all_compile today returns a submodule corresponding to the rank's pipeline stage
    pipe_driver, stage_mod = pippy.all_compile(
        model,
        num_ranks,
        args.chunks,
        split_policy=split_policy,
        tracer=PiPPyHFTracer(),
        concrete_args=concrete_args,
    )

    pippy.load_checkpoint(stage_mod, index_filename, device)

It may be just a matter of style, but it would also help all_compile focus on the compilation part.
I would appreciate your thoughts.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 14, 2023

@jiqing-feng just one more thing came up for me: trying the model cerebras/Cerebras-GPT-13B, it runs into an issue copying meta tensors. I wonder if you have run into such an issue before.


Thanks for your comment. I will look into it.

@jiqing-feng
Contributor Author

Hi @jiqing-feng, thanks so much for this important feature that enables large model loading. I merged this PR into main. As a follow-up, I would like to check with you whether the API shape proposed above, with a separate pippy.load_checkpoint(stage_mod, index_filename, device) call after all_compile, would read as more composable.

Hi @kwen2501, thanks for your comment. Loading the checkpoint after all_compile may not work.
[screenshot of the error]
All weights in the submodule must be loaded before submodule.to(device); otherwise this error occurs: NotImplementedError: Cannot copy out of meta tensor; no data!

@jiqing-feng jiqing-feng mentioned this pull request Apr 14, 2023
@kwen2501
Contributor

Thanks for the reply, @jiqing-feng.
I was thinking of separate APIs because that is how checkpoint loading is done in other distribution techniques. See for example: https://github.com/pytorch/pytorch/blob/master/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py

In this specific case, how about we delay submodule.to(device) when we find that the submodule is on the meta device?
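As a sketch of that idea (the function name echoes Pipe.materialize_stage from the traceback above; the logic shown is illustrative):

    def materialize_stage(submodule, device):
        # If any parameter is still on the meta device, the checkpoint has not
        # been loaded yet; defer the move instead of failing with
        # "Cannot copy out of meta tensor; no data!".
        if any(p.is_meta for p in submodule.parameters()):
            return submodule
        return submodule.to(device)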

@kwen2501
Contributor

I have to admit, though, that the separate API approach would only work for pippy.all_compile and not for pippy.compile, because the latter does not return a stage module.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 17, 2023

I have to admit, though, that the separate API approach would only work for pippy.all_compile and not for pippy.compile, because the latter does not return a stage module.

Yes, and if we want to keep the all_compile API clean, we can use an environment variable to pass the model directory.
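Something like this, for illustration (the variable name is hypothetical):

    import os

    # The index file location comes from the environment, so all_compile's
    # signature can stay unchanged.
    index_filename = os.environ.get("PIPPY_INDEX_FILENAME")
    if index_filename is not None:
        # Each rank then loads its own stage's weights, as elsewhere in this PR:
        # load_checkpoint(stage_mod, index_filename, device)
        pass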

Also, I was wondering if you could have a look at #777, which fixes loading for some models. Thanks! @kwen2501 @HamidShojanazeri

kwen2501 pushed a commit that referenced this pull request Apr 17, 2023
Related to [770](#770). This PR solves the problem of loading parameters saved in module._parameters by matching the parameters' names.

Hi @HamidShojanazeri, GPT models like cerebras/Cerebras-GPT-13B should work with this PR. BTW, I think we can keep git-cloning models for now, since it is officially recommended by Hugging Face. I will try the approach you recommended and integrate it if possible.

Hi @kwen2501 @HamidShojanazeri, could you help review it? Thanks!
@jiqing-feng jiqing-feng deleted the low-mem branch April 18, 2023 03:02