
allow TP to work in ND-parallel with fsdp cpu ram efficient loading #39999

Draft

winglian wants to merge 3 commits into huggingface:main from winglian:tp-with-device-mesh

Conversation

winglian (Collaborator) commented Aug 7, 2025

What does this PR do?

For N-D parallelism, when using FSDP2+TP with cpu_ram_efficient_loading, we have to specify the device_map as "meta" for non-rank0 processes. Additionally, even though the device_mesh already tells us which device the model will ultimately end up on, we don't want to overwrite the device_map, since we've deliberately set it to the meta device.
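As an illustration of what this enables, here is a minimal loading sketch (not part of the PR itself; the checkpoint name and mesh shape are placeholders, and it assumes `from_pretrained` accepts `device_mesh` alongside `tp_plan`, which is the code path this PR touches):

```python
# Illustrative sketch only: FSDP2 + TP loading with cpu_ram_efficient_loading.
# Launch with e.g. `torchrun --nproc_per_node 8 load.py`.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

rank = int(os.environ.get("RANK", "0"))

# 2-D mesh: 4-way FSDP ("dp") x 2-way tensor parallel ("tp").
device_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder checkpoint
    tp_plan="auto",             # apply the model's built-in tensor-parallel plan
    device_mesh=device_mesh,    # the "tp" dim of the mesh drives the sharding
    # With cpu_ram_efficient_loading, only rank 0 materializes weights;
    # every other rank builds the model on the meta device.
    device_map=None if rank == 0 else "meta",
    torch_dtype=torch.bfloat16,
)
```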

@SunMarc @S1ro1 @ArthurZucker

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

SunMarc (Member) left a comment


Thanks, left a comment

Comment thread: src/transformers/modeling_utils.py (Outdated)

```diff
     # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
     raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
-if tp_plan is not None and device_map is not None:
+if tp_plan is not None and device_map is not None and device_mesh is not None:
```
SunMarc (Member) commented:

Suggested change:

```diff
-if tp_plan is not None and device_map is not None and device_mesh is not None:
+if tp_plan is not None and device_map is not None and device_mesh is None:
```

we should check for device_mesh is None instead, no?
Also, maybe we can add is_fsdp_enabled() somewhere here to make it easier to understand, and add some comments

salmanmohammadi (Contributor) commented Aug 7, 2025:

The workflow here is:

  • a TP plan is provided (we are applying tensor parallelism)
  • device_map is set, and it's the meta device

We don't want to error out here, and we also don't want to infer the device map from the device mesh. I think a clearer check would be something like:

Suggested change:

```diff
-if tp_plan is not None and device_map is not None and device_mesh is not None:
+# device_map should be permitted if the user wishes to instantiate the model on meta device
+if tp_plan is not None and device_map is not None and device_map != "meta":
```

What do you think? Is there a better check for meta device instantiation? @SunMarc @winglian

winglian (Collaborator, Author) commented:

We still need a device_mesh check in there. Maybe:

```python
if tp_plan is not None and device_map is not None and device_map != "meta" and device_mesh is None:
```
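Putting the pieces of this thread together, the guard might read something like the sketch below (illustrative only; the error message is invented, and this is not necessarily the exact merged diff):

```python
# Sketch of the combined guard discussed above (illustrative, not the merged code).
# A user-supplied device_map normally conflicts with tp_plan, but two cases are fine:
#   - device_map == "meta": cpu_ram_efficient_loading keeps non-rank0 processes on meta
#   - device_mesh is not None: the target device is already known from the mesh
if tp_plan is not None and device_map is not None and device_map != "meta" and device_mesh is None:
    raise ValueError("tp_plan and an explicit device_map are mutually exclusive.")
```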

ArthurZucker (Collaborator) left a comment:

Can we add what this enables somewhere? 🤗 A small snippet would be very nice. Thanks for the PR!

winglian marked this pull request as draft on August 8, 2025 02:49
winglian force-pushed the tp-with-device-mesh branch from 0896ad6 to 8af2853 on August 9, 2025 12:22
ArthurZucker (Collaborator) left a comment:

LGTM otherwise, but it would be nice to have a snippet of how to run!


```diff
 # Post-processing for tensor parallelism
-if device_mesh is not None:
+if device_mesh is not None and "tp" in device_mesh.mesh_dim_names:
```
ArthurZucker (Collaborator) commented:

Not 100% sure we have to prevent cases where there is no "tp" in mesh_dim_names; it happens a lot in inference.

A Contributor commented:

Agreed; we explicitly require "tp" in the mesh dim names only in the n-d parallelism case, so this would skip the post-processing for every 1-d parallelism case.
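For reference, a minimal sketch of the two mesh shapes under discussion (the dim names like "dp" and "tp" are conventions chosen by the caller, and an 8-GPU torchrun launch is assumed):

```python
# Minimal sketch: the check `"tp" in device_mesh.mesh_dim_names` passes for the
# n-d mesh below, but dim names are caller-chosen (and None when omitted).
from torch.distributed.device_mesh import init_device_mesh

# n-d parallelism: 4-way data parallel x 2-way tensor parallel
nd_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
assert "tp" in nd_mesh.mesh_dim_names

# 1-d tensor parallelism, common for inference: this mesh passes the check only
# because its single dim happens to be named "tp"; otherwise the guarded
# post-processing would be skipped.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
```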

ArthurZucker (Collaborator) commented:

cc @winglian if this is still breaking!

