allow TP to work in ND-parallel with fsdp cpu ram efficient loading #39999
winglian wants to merge 3 commits into huggingface:main
Conversation
```diff
  # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
  raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
- if tp_plan is not None and device_map is not None:
+ if tp_plan is not None and device_map is not None and device_mesh is not None:
```
```diff
- if tp_plan is not None and device_map is not None and device_mesh is not None:
+ if tp_plan is not None and device_map is not None and device_mesh is None:
```
We should check for `device_mesh is None` instead, no?
Also, maybe we can add `is_fsdp_enabled()` somewhere here to make it easier to understand, and add some comments.
The workflow here is:
- a TP plan is provided (we are applying tensor parallelism)
- `device_map` is set, and it's the meta device

We don't want to error out here, and we also don't want to infer the device map from the device mesh. I think a clearer check would be something like:
```diff
- if tp_plan is not None and device_map is not None and device_mesh is not None:
+ # device_map should be permitted if the user wishes to instantiate the model on meta device
+ if tp_plan is not None and device_map is not None and device_map != "meta":
```
What do you think? Is there a better check for meta device instantiation? @SunMarc @winglian
We still need a `device_mesh` check in there. Maybe:

```python
if tp_plan is not None and device_map is not None and device_map != "meta" and device_mesh is None:
```
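For readers skimming the thread, here is a self-contained sketch of how that combined guard might read, with a comment per clause. The helper name `_validate_tp_args` and the error message are illustrative only, not the merged transformers code:

```python
# Illustrative sketch only -- not the merged transformers code.
def _validate_tp_args(tp_plan, device_map, device_mesh):
    if (
        tp_plan is not None            # tensor parallelism was requested
        and device_map is not None     # a device_map was also given
        and device_map != "meta"       # "meta" stays allowed: non-rank-0 FSDP
                                       # cpu_ram_efficient_loading depends on it
        and device_mesh is None        # with an explicit mesh, placement is already decided
    ):
        raise ValueError("tp_plan and device_map cannot be combined unless device_map='meta'.")
```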
ArthurZucker
left a comment
Can we add what this enables somewhere?! 🤗 like a small snippet would be very nice, thanks for the PR!
Force-pushed from 0896ad6 to 8af2853
ArthurZucker
left a comment
LGTM otherwise but would be nice to have a snippet of how to run!
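For what it's worth, here is a minimal sketch of the kind of launch this PR enables, assuming transformers accepts `tp_plan="auto"` together with an explicit `device_mesh`, and that FSDP's cpu_ram_efficient_loading is configured on the accelerate/trainer side. The checkpoint name and `tp_size` are placeholders:

```python
# run_tp_fsdp.py -- a sketch, not an official example.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

tp_size = 2                      # tensor-parallel degree (placeholder; must divide world_size)
dp_size = world_size // tp_size  # FSDP (data-parallel) degree

# The "tp" dim name matters: the post-processing check discussed above keys on it.
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

# With cpu_ram_efficient_loading, only rank 0 materializes real weights;
# every other rank instantiates the model on the meta device.
device_map = None if rank == 0 else "meta"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",   # placeholder checkpoint
    tp_plan="auto",
    device_mesh=mesh,            # assumption: transformers picks out the "tp" dimension
    device_map=device_map,
    torch_dtype=torch.bfloat16,
)
```

Launched with something like `torchrun --nproc_per_node=4 run_tp_fsdp.py` on a four-GPU node, this gives 2-way FSDP x 2-way TP.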
```diff
  # Post-processing for tensor parallelism
- if device_mesh is not None:
+ if device_mesh is not None and "tp" in device_mesh.mesh_dim_names:
```
Not 100% sure we have to prevent cases where there is no "tp" in `mesh_dim_names`; it happens a lot in inference.
Agreed. We explicitly require "tp" in the mesh dim names only in the N-D parallelism case, so this check would skip the post-processing for every 1-D parallelism case.
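To make the distinction concrete, a small sketch (assuming recent PyTorch, run under torchrun so a process group is available):

```python
from torch.distributed.device_mesh import init_device_mesh

# 1-D inference mesh, often created without explicit dim names;
# mesh_dim_names is None here, so the new `"tp" in mesh_dim_names`
# membership test would need a guard for this case.
mesh_1d = init_device_mesh("cuda", (4,))

# N-D training mesh with a named "tp" dimension, where the check applies.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
```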
cc @winglian if this is still breaking!
What does this PR do?
For N-D parallelism, when using FSDP2 + TP with cpu_ram_efficient_loading, we have to specify the device_map as "meta" for non-rank-0 processes. Additionally, even though we already know through the device_mesh which device the model will ultimately end up on, we don't want to change the device_map, since we've already defined it as the meta device.

@SunMarc @S1ro1 @ArthurZucker
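As a quick sanity check of the resulting placement, one could inspect the loaded parameters; a sketch, assuming the TP plan shards weights as DTensors (the import path moved from `torch.distributed._tensor` to `torch.distributed.tensor` in recent PyTorch):

```python
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

# `model` as loaded above; a TP-sharded weight reports the mesh it lives on
# and how it is placed across the "tp" dimension.
for name, param in model.named_parameters():
    if isinstance(param.data, DTensor):
        print(name, param.data.device_mesh.mesh_dim_names, param.data.placements)
        break
```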
Fixes # (issue)
Before submitting

- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.