
Add Cambricon MLU accelerator support #2552

Merged
muellerzr merged 8 commits into huggingface:main from huismiling:main on Mar 20, 2024

Conversation

@huismiling
Contributor

@huismiling huismiling commented Mar 13, 2024

What does this PR do?

To use Cambricon MLUs to train 🤗 Transformers models, support must first be added to Accelerate; it will then come to the Trainer for free.
This PR adds support for the Cambricon MLU accelerator:

  1. Sample config after running the accelerate config command:
debug: false
distributed_type: MULTI_MLU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  2. Run nlp_example.py with MLUs:
    accelerate launch nlp_example.py
    Below are the output logs:
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
epoch 0: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
epoch 1: {'accuracy': 0.7058823529411765, 'f1': 0.8170731707317073}
epoch 2: {'accuracy': 0.7598039215686274, 'f1': 0.8398692810457516}
  3. About Cambricon MLU:
    The Cambricon MLU is an AI processor that supports AI frameworks such as PyTorch and TensorFlow, so Transformers/Accelerate can run on MLUs to train foundation models (a minimal availability check is sketched below). Website: https://www.cambricon.com
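
As a quick sanity check before running accelerate config, one can verify that the Cambricon PyTorch extension is installed and that MLU devices are visible. A minimal sketch, assuming the torch_mlu package name and the torch.mlu namespace provided by Cambricon's PyTorch port (neither is shown in this PR description):

import torch

try:
    import torch_mlu  # Cambricon's PyTorch extension (assumed package name)
except ImportError:
    torch_mlu = None

if torch_mlu is not None and hasattr(torch, "mlu") and torch.mlu.is_available():
    # Report how many MLU devices this process can see.
    print(f"Found {torch.mlu.device_count()} MLU device(s)")
else:
    print("No MLU devices available")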

@huismiling
Contributor Author

@sgugger Hi, good day. Could you please review this PR? Thanks!

@muellerzr
Contributor

muellerzr commented Mar 13, 2024

Sylvain is no longer on this project/at Hugging Face, I’ll review this today. Thanks for your contribution!

@muellerzr muellerzr self-requested a review March 13, 2024 11:05
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@muellerzr muellerzr left a comment


Thanks! Overall this is a very sound PR, and it's exciting to support yet another hardware platform! 🚀

Left some suggestions and a few questions I'd like addressed before moving forward. Thanks!

Three review comment threads on src/accelerate/utils/imports.py (outdated)
@muellerzr
Contributor

Also, for the quality checks, please run `pip install -e .[quality]` followed by `make style; make quality`.

Commit: "it's beautiful !" (Co-authored-by: Zach Mueller <muellerzr@gmail.com>)
@huismiling
Contributor Author

Also, for the quality checks, please run `pip install -e .[quality]` followed by `make style; make quality`.

Below is the output.

ruff format .
126 files left unchanged
doc-builder style src/accelerate docs/source --max_len 119
ruff check .
ruff format --check .
126 files already formatted
doc-builder style src/accelerate docs/source --max_len 119 --check_only

@huismiling
Contributor Author

@muellerzr Thanks for your advice. Both changes are done:

  1. Deleted the torch check.
  2. Deleted `torch.cuda = torch.mlu`; see the device-agnostic sketch below.
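
Rather than monkey-patching torch.cuda to point at torch.mlu, user code can stay device-agnostic by letting Accelerate pick the device. A minimal sketch of that pattern, assuming only that Accelerator.device resolves to an MLU device on an MLU host; it is not code from this PR.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# accelerator.device resolves to the detected backend (an "mlu" device on MLU
# hosts, "cuda" or "cpu" elsewhere), so no torch.cuda monkey-patching is needed.
model = torch.nn.Linear(8, 2).to(accelerator.device)
batch = torch.randn(4, 8).to(accelerator.device)
print(accelerator.device, model(batch).shape)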

@huismiling huismiling requested a review from muellerzr March 14, 2024 10:12
Contributor

@muellerzr muellerzr left a comment


Thanks! Overall this looks like a straightforward and easy integration!

cc @SunMarc for the big model inference stuff it touches :)

@muellerzr muellerzr requested a review from SunMarc March 14, 2024 11:29
Member

@SunMarc SunMarc left a comment


Awesome! Thanks for the clean integration of MLU with big model inference @huismiling! Can you confirm that you are able to load a model on multi-MLU when using the transformers library (by passing `device_map="auto"` when loading a model such as Llama 2 or Mistral)?

@huismiling
Contributor Author

cc @SunMarc @muellerzr
Hi, I tried the Llama-2-7b-chat-hf model with the code below on 8 MLUs.

from transformers import AutoTokenizer
import transformers
import torch

model = "/llm/models/Llama-2-7b-chat-hf/"

tokenizer = AutoTokenizer.from_pretrained(model)
# device_map="auto" lets Accelerate place model shards across all visible MLUs.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

sequences = pipeline(
    'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=100,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

Below is the output.

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:06<00:06,  6.38s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.52s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.95s/it]
/llm/transformers/src/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/llm/transformers/src/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/llm/transformers/src/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
  warnings.warn(
/llm/transformers/src/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
  warnings.warn(
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?

I'm open to different genres and topics, but I prefer shows with complex characters and compelling storytelling.

Thanks!
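
To double-check that the checkpoint really is sharded across several MLUs rather than loaded onto a single device, the device map produced by Accelerate can be printed. A minimal sketch, assuming the pipeline object from the snippet above; hf_device_map is the placement mapping that transformers attaches to models loaded with device_map="auto".

# Show where each module of the model was placed (e.g. "mlu:0", "mlu:1", ...).
for module_name, device in pipeline.model.hf_device_map.items():
    print(f"{module_name} -> {device}")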

@huismiling huismiling requested review from SunMarc and muellerzr March 18, 2024 01:18
@muellerzr
Contributor

Fantastic! Thanks for verifying! I’ll merge once the CI finishes :)

@huismiling
Contributor Author

huismiling commented Mar 19, 2024

`torch._dynamo` has a conflict with `lru_cache`:
E torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(_lru_cache_wrapper) __call__ [] {}

Switched to using `device.type` for the MLU device check (see the sketch below).

Local tests pass!
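
For illustration only, the idea looks roughly like the sketch below: a plain device.type comparison is a simple attribute check that TorchDynamo can trace, whereas calling an lru_cache-wrapped helper from compiled code raises the Unsupported error above. This is a hedged sketch of the pattern, not the exact code in the PR.

import torch

def is_mlu_device(device: torch.device) -> bool:
    # Plain attribute comparison; nothing here is wrapped in functools.lru_cache,
    # so torch._dynamo can trace through it without raising Unsupported.
    return device.type == "mlu"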

@muellerzr
Contributor

Great work! Thanks for verifying! (failing test is unrelated)
