Fix Mistakes with FA Padding Free#62

Merged
fabianlim merged 3 commits into main from fixup/fa
Aug 2, 2024
Conversation

Contributor

@fabianlim fabianlim commented Aug 1, 2024

The PR #57 had a couple of mistakes that needed to be fixed. This is because of two things:

  1. the flash_attention_forward function was moved out earlier
  2. the actual padding-free fix was done later, and is not yet released (probably 4.44)

The strategy now is simple:

  • if we can import DataCollatorWithFlattening successfully, it means the padding-free fix is in
  • if we can import _flash_attention_forward, it means the function has been separated out
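The detection strategy above can be sketched as try/except imports. The import paths are the transformers ones mentioned in this PR; the two boolean flag names are illustrative, not part of the library:

```python
# Sketch of the import-based version detection described above.
try:
    # present only once the padding-free fix has landed in transformers
    from transformers import DataCollatorWithFlattening  # noqa: F401
    HAS_PADDING_FREE_FIX = True
except ImportError:
    HAS_PADDING_FREE_FIX = False

try:
    # present once flash_attention_forward was separated into its own module
    from transformers.modeling_flash_attention_utils import _flash_attention_forward  # noqa: F401
    HAS_SEPARATED_FA_FORWARD = True
except ImportError:
    HAS_SEPARATED_FA_FORWARD = False
```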

Augmentation

  1. If the padding-free fix is in, then there is nothing to do; otherwise some patching is required.
  2. Patch _flash_attention_forward, as either a static function or a method, depending on the version.

Some redesign was done: since _flash_attention_forward could be either a method or a function, the previous approach of binding _flash_attention_forward by closure no longer holds. So we need to install a method on the backbone to intercept the position ids, then modify _flash_attention_forward to be able to access the position ids, and bind them.
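A minimal sketch of that interception pattern, assuming two wrapper helpers (all names here are hypothetical, not the actual plugin code): the backbone forward is wrapped to stash position_ids in shared state, and the patched _flash_attention_forward reads them back:

```python
# Hypothetical sketch: intercept position_ids at the backbone and
# bind them into the patched _flash_attention_forward.
_captured = {}

def intercept_position_ids(backbone_forward):
    # wrap the backbone forward so the position ids can be stashed
    def forward(*args, position_ids=None, **kwargs):
        _captured["position_ids"] = position_ids
        return backbone_forward(*args, position_ids=position_ids, **kwargs)
    return forward

def bind_position_ids(fa_forward):
    # patched attention forward: bind the intercepted position ids
    def patched(*args, **kwargs):
        kwargs.setdefault("position_ids", _captured.get("position_ids"))
        return fa_forward(*args, **kwargs)
    return patched
```

This sidesteps the method-vs-function ambiguity because the patch only depends on the shared state, not on how _flash_attention_forward is bound.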

The bad news is that once this was done properly, the speed dropped. However, we verified that the speed is consistent when we upgrade transformers to the latest main, which means our implementation is correct.

{'loss': 0.8762, 'grad_norm': 69.0, 'learning_rate': 2e-05, 'epoch': 0.0}
{'loss': 0.9877, 'grad_norm': 29.40625, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.0}
{'loss': 1.0518, 'grad_norm': 38.90625, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.0}
{'loss': 1.1429, 'grad_norm': 85.625, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.0}
{'loss': 1.0771, 'grad_norm': 22.890625, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.0}
{'loss': 0.9842, 'grad_norm': 33.5, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.0}
{'loss': 2.4449, 'grad_norm': 19.9375, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.01}
{'loss': 0.9717, 'grad_norm': 35.5625, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.01}
{'loss': 0.8958, 'grad_norm': 25.203125, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.01}
{'loss': 0.9145, 'grad_norm': 18.296875, 'learning_rate': 0.0, 'epoch': 0.01}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:07<00:00,  1.41it/s]

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 67.8947, 'train_samples_per_second': 5.891, 'train_steps_per_second': 1.473, 'train_tokens_per_second': 2029.615, 'train_loss': 1.1346958923339843, 'init_mem_cpu_alloc_delta': -14387679232, 'init_mem_gpu_alloc_delta': 14483611648, 'init_mem_cpu_peaked_delta': 14483382272, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 691978240, 'train_mem_gpu_alloc_delta': 28984245248, 'train_mem_cpu_peaked_delta': 0, 'train_mem_gpu_peaked_delta': 28990169600, 'before_init_mem_cpu': 15096680448, 'before_init_mem_gpu': 0, 'epoch': 0.01}

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
@fabianlim fabianlim requested a review from achew010 August 1, 2024 16:10
@fabianlim
Contributor Author

Potentially, this can be improved by having the backbone function compute the cumsum once for all layers.
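For illustration, a pure-Python sketch of what computing the cumulative sequence lengths once from the flattened position ids could look like; the function name is hypothetical, and a real implementation would operate on tensors (e.g. with torch.cumsum):

```python
def cu_seqlens_from_position_ids(position_ids):
    # in a padding-free (flattened) batch, the position ids restart at 0
    # at the start of each packed sequence, so every 0 marks a boundary
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    # cu_seqlens = sequence boundaries plus the total token count
    return starts + [len(position_ids)]
```

For example, `cu_seqlens_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3])` gives `[0, 3, 5, 9]`, which each attention layer could then reuse instead of recomputing.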
