
Avoid registering pytree when using FSDP #39325 (Closed)

kaixuanliu wants to merge 4 commits into huggingface:main from kaixuanliu:traceable-cache-fsdp

Conversation

@kaixuanliu (Contributor)

When using FSDP, this register_pytree_node operation costs a lot of extra memory. We found that after this PR: #35873, we can no longer fine-tune a 70B model with FSDP due to an OOM issue.

@kaixuanliu (Contributor, Author)

@SunMarc @ArthurZucker @IlyasMoutawwakil please help review, thanks!

@IlyasMoutawwakil (Member) commented Jul 10, 2025

@kaixuanliu do you mean this PR, #36311, where it was added?
Do you have exact measurements of how much it costs with and without FSDP? The operation itself has nothing to do with FSDP.

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@kaixuanliu (Contributor, Author)

@IlyasMoutawwakil, oh yes, it's #36311, not #35873. As for the extra memory, I ran an experiment: I used 4 processes to do FSDP fine-tuning of the llama2-7b model and compared the maximum memory consumption of the two configurations. The results show that with the register_pytree_node operation, each card uses about 12 GB of extra memory.

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@SunMarc (Member) left a comment


Thanks for discovering that! Left a comment.


```diff
-if is_torch_greater_or_equal("2.3"):
+# Register pytree node for DynamicCache if torch version is >= 2.3 and FSDP is not imported,
+# FSDP will need more extra memory when using pytree node
+if is_torch_greater_or_equal("2.3") and "torch.distributed.fsdp" not in sys.modules:
```
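The proposed guard can be sketched as a small stdlib-only helper; `is_torch_greater_or_equal` is stubbed here (the real helper lives in transformers), so only the `sys.modules` check is exercised:

```python
import sys


def is_torch_greater_or_equal(version: str) -> bool:
    # Stub standing in for transformers' real version helper;
    # assumed True here purely for illustration.
    return True


def should_register_dynamic_cache_pytree() -> bool:
    """Mirror of the proposed guard: only register the pytree node when
    torch >= 2.3 and torch.distributed.fsdp has not been imported."""
    return (
        is_torch_greater_or_equal("2.3")
        and "torch.distributed.fsdp" not in sys.modules
    )
```

Note the fragility discussed below: the check depends on import order, since FSDP imported *after* this module would not be seen.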
Member

Can we register the pytree node for DynamicCache somewhere else, so that we can perform a better check than just testing whether "torch.distributed.fsdp" is in sys.modules? cc @IlyasMoutawwakil @gante
We also have the is_fsdp_enabled function in modeling utils that could be used to perform the check.
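For reference, `is_fsdp_enabled` in transformers' modeling utils consults torch.distributed state together with the `ACCELERATE_USE_FSDP` environment variable set by accelerate. A stdlib-only approximation (the distributed state is passed in as a parameter here, and the exact condition set should be treated as an assumption):

```python
import os


def strtobool(val: str) -> int:
    # Tiny stand-in for distutils.util.strtobool (removed in Python 3.12).
    return 1 if val.strip().lower() in ("y", "yes", "t", "true", "on", "1") else 0


def is_fsdp_enabled_sketch(dist_initialized: bool) -> bool:
    """Approximation of transformers' is_fsdp_enabled: the real helper
    queries torch.distributed directly; here that state is a parameter
    so the sketch stays stdlib-only."""
    return dist_initialized and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
```

One caveat with this route: the env var is only set once an accelerate FSDP run is configured, so the check would have to run later than import time.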

Contributor Author

I ran more experiments on both A100 and XPU, and found that this is an XPU-specific issue: it occurs consistently on XPU but not on A100 (I will try to figure out what is happening on XPU later). Based on this finding, just checking whether torch.distributed.fsdp is in sys.modules does not seem like a good choice. Can we add an environment variable here to let people disable the register_pytree_node action for DynamicCache? WDYT? @SunMarc @gante @IlyasMoutawwakil
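The environment-variable opt-out being proposed could look like the following; note that `TRANSFORMERS_NO_DYNAMIC_CACHE_PYTREE` is a hypothetical name for illustration, not an existing transformers variable:

```python
import os


def pytree_registration_enabled() -> bool:
    """Registration stays on by default; users opt out explicitly via an
    environment variable (name is hypothetical, not part of transformers)."""
    flag = os.environ.get("TRANSFORMERS_NO_DYNAMIC_CACHE_PYTREE", "0")
    return flag.strip().lower() not in ("1", "true", "yes")
```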

Member

Why not disable it for XPU only, until XPU fixes it?

Contributor Author

I suppose this piece of code is meant to support model export/tracing; XPU also needs this.

@IlyasMoutawwakil (Member) commented Jul 14, 2025

Can we move it to the dynamic cache's init, with some registration check? A user will only export/compile it if it was already instantiated.
And there we skip registration for XPU + FSDP.

@kaixuanliu (Contributor, Author)

After consideration, I think we actually do not need the KV cache at all during the fine-tuning/training stage. So I will close this PR and set use_cache=False explicitly in the application code.
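A minimal sketch of the chosen workaround; `DummyModel` is a stand-in for a transformers causal LM (real code would set `use_cache=False` on the model config or pass it to the forward call):

```python
class DummyModel:
    """Stand-in for a transformers model, used only to show where
    use_cache=False is applied; not a real transformers class."""

    def __init__(self):
        self.use_cache = True  # transformers exposes this on model.config

    def forward(self, input_ids, use_cache=None):
        effective = self.use_cache if use_cache is None else use_cache
        # A real model returns past_key_values only when caching is enabled.
        return {"past_key_values": tuple() if effective else None}


model = DummyModel()
out = model.forward([1, 2, 3], use_cache=False)  # per-call override for training
```

With caching disabled, no past key/value tensors are retained between forward passes, which is what avoids the extra memory during training, where generation-style caching is not needed anyway.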


3 participants