[CP] Add attention_mask to the buffer when the mask is causal #40619
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @S1ro1 for visibility
It's useless to add this to buffers, no? CP doesn't work with an explicitly passed attention mask, so in accelerate we attach a hook that pops it out.
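For context, a minimal sketch of what such a hook could look like (illustrative names only; not the exact accelerate implementation):

```python
import torch.nn as nn

def _drop_attention_mask(module: nn.Module, args: tuple, kwargs: dict):
    # Pop any explicitly passed attention_mask before the forward runs,
    # since context parallelism cannot handle an explicit mask.
    kwargs.pop("attention_mask", None)
    return args, kwargs

# Hypothetical registration: a forward pre-hook that also receives kwargs.
# model.register_forward_pre_hook(_drop_attention_mask, with_kwargs=True)
```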
I was getting errors in the SFT Trainer since we use the attention mask for metrics and its shape was not matching.
@S1ro1 the entropy metric, for example, is erroring out: https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L1041-L1056
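Roughly what that metric does, paraphrased as a sketch (not the verbatim TRL code): per-token entropy is computed from the logits and then masked with attention_mask, so once CP shards the logits along the sequence dimension but leaves the mask full-length, the shapes no longer line up:

```python
import torch
import torch.nn.functional as F

def masked_mean_entropy(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, local_seq_len, vocab) after CP sharding;
    # attention_mask: (batch, full_seq_len) if it was never sharded.
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)  # (batch, local_seq_len)
    mask = attention_mask.bool()
    # Raises a shape error when local_seq_len != full_seq_len.
    return (entropy * mask).sum() / mask.sum()
```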
Okay, I guess it makes sense. But in this case it's not entirely correct to shard the mask across dim(1), as the mask should be sharded across both dim(1) and dim(2), which is not expressible in torch DTensor placements, so just keep an eye on that.
@S1ro1 good point! How about we only split the 2D masks?
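A hedged sketch of that idea (hypothetical helper, not the actual transformers code): only a 2D (batch, seq_len) padding mask gets sharded along its single sequence dimension, while higher-dimensional masks are skipped because both their query and key dimensions would need sharding:

```python
import torch

def maybe_buffer_attention_mask(attention_mask: torch.Tensor | None,
                                buffers: list[torch.Tensor],
                                buffer_seq_dims: list[int]) -> None:
    # Only a 2D mask of shape (batch, seq_len) can be sharded along one
    # sequence dimension; a (q_len, kv_len)-shaped mask would need two
    # sharded dims, which DTensor placements cannot express.
    if attention_mask is not None and attention_mask.ndim == 2:
        buffers.append(attention_mask)
        buffer_seq_dims.append(1)  # shard along the sequence dimension
```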
What does this PR do?
- attention_mask is always checked for causality before being added to the buffer, and this validation is performed only once for performance.
- attention_mask is appended to the buffers only after successful validation, preventing non-causal masks from being used.
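As a rough illustration of that flow (hypothetical names; "causal" is assumed here to mean a 2D mask with no padding zeros, which is only an assumption for the sketch):

```python
import torch

def add_mask_to_buffers_if_causal(attention_mask: torch.Tensor | None,
                                  buffers: list[torch.Tensor],
                                  buffer_seq_dims: list[int]) -> bool:
    # Validate once up front: treat the mask as "causal" when it is a 2D
    # (batch, seq_len) mask containing only ones, i.e. no padding.
    is_causal = (
        attention_mask is not None
        and attention_mask.ndim == 2
        and bool(attention_mask.all())
    )
    if is_causal:
        # Append only after the validation succeeded, so non-causal masks
        # never end up in the context-parallel buffers.
        buffers.append(attention_mask)
        buffer_seq_dims.append(1)
    return is_causal
```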