
Persimmon fa2 attention4d #27052

Closed
jeromeku wants to merge 13 commits into huggingface:main from jeromeku:persimmon-FA2-attention4d

Conversation

@jeromeku

What does this PR do?

• Adds Flash Attention 2 for Persimmon per #26350
• Adds 2d->4d attention mask per #26792
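
For reference, here is a minimal sketch of how the new path can be exercised. The checkpoint name and the use_flash_attention_2 loading flag are assumptions about the surrounding transformers API of this era, not part of this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: checkpoint and flag are assumed, not taken from this PR.
model_id = "adept/persimmon-8b-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,   # FA2 only supports fp16/bf16
    use_flash_attention_2=True,  # route attention through FlashAttention 2
    device_map="auto",
)
inputs = tokenizer("Persimmons are", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))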

Who can review?

@younesbelkada

jeromeku mentioned this pull request Oct 25, 2023
@jeromeku
Author

@younesbelkada

lmk if my implementation of 4d attention mask (#26792) + FA2 needs tweaking.

Regarding my previous comment, I'd like to understand HF's current strategy for integrating third-party / OSS libraries and components. Given the rapid pace of innovation in this space, I want to ensure that transformers and its sister libraries remain best-in-class w.r.t. usability and performance!

@younesbelkada
Contributor

Thanks very much for your great contribution @jeromeku! Sorry for the delay in responding to the PR; I will have an extensive look at the PR and your questions by the beginning of next week (from 30th October) 🙏

@pszemraj
Contributor

pszemraj commented Nov 10, 2023

Hi! Great work, and I don't mean to butt in here, but in case it helps take this home:

I was trying to get this to work and ran into some issues with the latest (4.36.dev0) version of transformers after cloning this PR and rebasing on main. I had to rebase because of the llama2 tokenizer/optimum import issue that I hit when using this PR's transformers version verbatim.

After scouring GitHub, I came across a fine-tuning repo for Fuyu, and the same author has a working version of FA2 for Persimmon (I was able to train Persimmon with FA2 on it):

pip install git+https://github.com/phillip-kravtsov/transformers.git@floating-updates

This is experimental and the work of one dude, so for the record the working SHA is: b8000bd8d619cbbedcb806b67faa68c2300b4bd0

hope this helps!

@jeromeku
Author

@younesbelkada

Let me know how I can improve the PR. Also, would appreciate thoughts on previous query when you get a chance.

Thanks!

@younesbelkada
Contributor

Great work, thank you @jeromeku! I left one comment.
Can you also add a few lines to the documentation? You can do something similar to #27400.



# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter
class AttnMaskConverter:
Contributor

instead you can import it from modeling_attn_mask_utils:

from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
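
For context, here is a sketch of the usual call site for this helper inside a model's forward pass. The variable names follow the llama-style modeling code the helper was factored out of and are assumptions, not the exact Persimmon diff:

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Expands a 2D padding mask of shape (batch, seq_len) into the 4D causal
# mask of shape (batch, 1, q_len, kv_len) consumed by the eager attention path.
attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask,            # 2D padding mask, or None for a pure causal mask
    (batch_size, seq_length),  # shape of the current input
    inputs_embeds,             # provides the dtype and device for the mask
    past_key_values_length,    # offset when a KV cache is present
)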

@jeromeku
Author

jeromeku commented Nov 17, 2023

@younesbelkada

I made the aforementioned changes but am getting odd errors with the test_flash_attn_2_generate_padding_right test (all other tests pass).

For example, here are the outputs with and without FA2:

out=tensor([[38, 17, 72, 39, 27, 98, 50, 53],
        [88, 83, 53, 72, 98, 43, 42, 90],
        [44, 38, 53, 70, 57, 77, 58, 87],
        [61, 57, 26, 71, 68,  5, 76, 60],
        [29,  8, 61, 62, 66, 12, 52, 63],
        [83, 71, 43, 61, 22, 83, 89, 89],
        [50, 38, 11, 22, 42, 13, 65, 14],
        [71, 71,  0,  9,  3, 37, 10, 84],
        [81, 19, 56, 62, 67, 18, 58, 87],
        [39,  9, 49, 48, 22, 13, 72, 20],
        [26, 96,  7, 26, 54, 46, 32,  6],
        [67,  9, 87, 93, 42,  8, 65, 14],
        [25, 83, 70, 30, 32, 92, 25, 64]], device='cuda:0')
out_fa=tensor([[38, 17, 72, 39, 27, 98, 50, 95],
        [88, 83, 53, 72, 98, 43, 42, 30],
        [44, 38, 53, 70, 57, 77, 58, 87],
        [61, 57, 26, 71, 68,  5, 76, 60],
        [29,  8, 61, 62, 66, 12, 52, 28],
        [83, 71, 43, 61, 22, 83, 89, 89],
        [50, 38, 11, 22, 42, 13, 65, 10],
        [71, 71,  0,  9,  3, 37, 10, 84],
        [81, 19, 56, 62, 67, 18, 58, 87],
        [39,  9, 49, 48, 22, 13, 72, 10],
        [26, 96,  7, 26, 54, 46, 32, 94],
        [67,  9, 87, 93, 42,  8, 65, 10],
        [25, 83, 70, 30, 32, 92, 25, 64]], device='cuda:0')

Seems like there is some unaccounted-for randomness, as some sequences match while others do not.

This is odd given that the flash_attn_2_inference_padding_{left,right} tests pass, as does the flash_attn_2_generate_left_padding test.

How should the generate_padding_right test be interpreted? The test pads the last token in each sequence of a batch of input_ids, then generates 1 token. Since the last token in the input is masked, does this mean the generated token should be the prediction at the last position where it attends only to prior tokens in the sequence (compared to the typical causal case, where the prediction is also based on the token attending to itself)?
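
For concreteness, here is a simplified sketch of what the test exercises. It is paraphrased, not copied, from the shared test; model and model_fa are assumed to be the same checkpoint loaded without and with FA2, and the exact sizes are illustrative:

import torch

batch_size, seq_len = 4, 8  # illustrative sizes, not the test's actual values
input_ids = torch.randint(0, 100, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
attention_mask[:, -1] = 0  # right padding: mask the last token of each sequence

# Greedy decoding should make both outputs deterministic and identical.
out = model.generate(input_ids, attention_mask=attention_mask,
                     max_new_tokens=1, do_sample=False)
out_fa = model_fa.generate(input_ids, attention_mask=attention_mask,
                           max_new_tokens=1, do_sample=False)
assert torch.equal(out, out_fa)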

Also, when I run the flash_attn_2 tests for Mistral, the generate_padding_right and inference_padding_right tests both fail because no ValueError is raised.

Any ideas what could be causing this? I've attached my venv.

my-env.txt

@jeromeku
Author

jeromeku commented Nov 21, 2023

@younesbelkada

Trying to get to the bottom of the issue:

  • Created a fresh venv, cloned the transformers repo as of 11/20/2023, and ran pip install -e .[dev]
  • Installed flash_attn:
Name: flash-attn
Version: 2.3.4
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: trid@cs.stanford.edu
License: 
Location: /notebooks/virtualenvs/transformers-test/lib/python3.9/site-packages
Requires: einops, ninja, packaging, torch
Required-by: 
Metadata-Version: 2.1
Installer: pip
Classifiers:
  Programming Language :: Python :: 3
  License :: OSI Approved :: BSD License
  Operating System :: Unix
Entry-points:
Project-URLs:
  • Ran attn_2 tests for the following models (only showing failures):

Whisper

FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_flash_attn_2_inference - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_padding_right - AssertionError: assert False
==============

Mistral

FAILED tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_flash_attn_2_inference_padding_right - AssertionError: ValueError not raised

Bark

FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_fp32_ln - RuntimeError: FlashAttention only support fp16 and bf16 data type
FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkSemanticConfig'> for this kind of AutoModel: ...
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_fp32_ln - RuntimeError: FlashAttention only support fp16 and bf16 data type
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkCoarseConfig'> for this kind of AutoModel: Au...

GPTNeo

FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true

The Llama, DistilBERT, and GPT BigCode tests all pass.
For Falcon, I get a CUDA device-side error, which might be because I'm running on an A6000 (48 GB memory), which may not be sufficient.

Thoughts?

FWIW, here's the output of pip freeze:

absl-py==2.0.0
accelerate==0.24.1
aiohttp==3.9.0
aiosignal==1.3.1
alembic==1.12.1
ansi2html==1.8.0
APScheduler==3.10.4
arrow==1.3.0
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.1.0
audioread==3.0.1
av==9.2.0
Babel==2.13.1
backoff==1.11.1
beautifulsoup4==4.12.2
binaryornot==0.4.4
bitsandbytes==0.41.2.post2
blinker==1.7.0
cachetools==5.3.2
certifi==2023.11.17
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.82
click==8.1.7
clldutils==3.20.0
codecarbon==1.2.0
colorama==0.4.6
colorlog==6.7.0
cookiecutter==1.7.3
csvw==3.2.1
dash==2.14.1
dash-bootstrap-components==1.5.0
dash-core-components==2.0.0
dash-html-components==2.0.0
dash-table==5.0.0
datasets==2.15.0
decorator==5.1.1
decord==0.6.0
dill==0.3.4
dlinfo==1.2.1
dm-tree==0.1.8
einops==0.7.0
etils==1.5.2
evaluate==0.4.1
exceptiongroup==1.1.3
execnet==2.0.2
faiss-cpu==1.7.4
fastjsonschema==2.19.0
filelock==3.13.1
fire==0.5.0
flash-attn==2.3.4
Flask==3.0.0
flatbuffers==23.5.26
flax==0.7.0
frozenlist==1.4.0
fsspec==2023.10.0
fugashi==1.3.0
gast==0.5.4
gitdb==4.0.11
GitPython==3.1.18
google-auth==2.23.4
google-auth-oauthlib==1.1.0
google-pasta==0.2.0
gql==3.4.1
graphql-core==3.2.3
greenlet==3.0.1
grpcio==1.59.3
h5py==3.10.0
hf-doc-builder==0.4.0
huggingface-hub==0.19.4
hypothesis==6.90.0
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.1.1
iniconfig==2.0.0
ipadic==1.0.0
isodate==0.6.1
isort==5.12.0
itsdangerous==2.1.2
jax==0.4.13
jaxlib==0.4.13
Jinja2==3.1.2
jinja2-time==0.2.0
joblib==1.3.2
jsonschema==4.20.0
jsonschema-specifications==2023.11.1
jupyter_core==5.5.0
kenlm==0.2.0
keras==2.15.0
keras-core==0.1.7
keras-nlp==0.6.3
language-tags==1.2.0
lazy_loader==0.3
libclang==16.0.6
librosa==0.10.1
llvmlite==0.41.1
lxml==4.9.3
Mako==1.3.0
Markdown==3.5.1
markdown-it-py==3.0.0
MarkupSafe==2.1.3
mdurl==0.1.2
ml-dtypes==0.2.0
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.12.2
namex==0.0.7
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
numba==0.58.1
numpy==1.26.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
onnx==1.15.0
onnxconverter-common==1.13.0
opt-einsum==3.3.0
optax==0.1.4
optuna==3.4.0
orbax-checkpoint==0.4.3
packaging==23.2
pandas==2.1.3
parameterized==0.9.0
phonemizer==3.2.1
Pillow==9.5.0
plac==1.4.1
platformdirs==4.0.0
plotly==5.18.0
pluggy==1.3.0
pooch==1.8.0
portalocker==2.0.0
poyo==0.5.0
protobuf==3.20.3
psutil==5.9.6
py-cpuinfo==9.0.0
pyarrow==14.0.1
pyarrow-hotfix==0.5
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pyctcdecode==0.5.0
pydantic==1.10.13
Pygments==2.17.1
pygtrie==2.5.0
pylatexenc==2.10
pynvml==11.5.0
pyparsing==3.1.1
pypng==0.20220715.0
pytest==7.4.3
pytest-timeout==2.2.0
pytest-xdist==3.4.0
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3.post1
PyYAML==6.0.1
ray==2.8.0
rdflib==7.0.0
referencing==0.31.0
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
requests-toolbelt==0.10.1
responses==0.18.0
retrying==1.3.4
rfc3986==1.5.0
rhoknp==1.3.0
rich==13.7.0
rjieba==0.1.11
rouge-score==0.1.2
rpds-py==0.13.1
rsa==4.9
ruff==0.1.6
sacrebleu==1.5.1
sacremoses==0.1.1
safetensors==0.4.0
scikit-learn==1.3.2
scipy==1.11.4
segments==2.2.1
sentencepiece==0.1.99
sigopt==8.8.2
six==1.16.0
smmap==5.0.1
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
SQLAlchemy==2.0.23
SudachiDict-core==20230927
SudachiPy==0.6.7
sympy==1.12
tabulate==0.9.0
tenacity==8.2.3
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-text==2.15.0
tensorstore==0.1.45
termcolor==2.3.0
text-unidecode==1.3
tf2onnx==1.15.1
threadpoolctl==3.2.0
timeout-decorator==0.5.0
timm==0.9.11
tokenizers==0.15.0
tomli==2.0.1
toolz==0.12.0
torch==2.1.1
torchaudio==2.1.1
torchvision==0.16.1
tqdm==4.66.1
traitlets==5.13.0
-e git+https://github.com/huggingface/transformers@38e2633f80a4924bf613b0240622492beee4cfcc#egg=transformers
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
tzdata==2023.3
tzlocal==5.2
unidic==1.1.0
unidic-lite==1.0.8
uritemplate==4.1.1
urllib3==1.26.18
wasabi==0.10.1
Werkzeug==3.0.1
wrapt==1.14.1
xxhash==3.4.1
yarl==1.9.3
zipp==3.17.0

@susnato
Contributor

susnato commented Nov 22, 2023

Hi @jeromeku, the test_flash_attn_2_generate_padding_right test for GPTNeo is quite flaky; most of the time it passes, but sometimes it fails.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot closed this on Dec 24, 2023.