System Info
Hi,
I have been doing some PEFT tuning with Mistral/Mixtral and have observed a training slowdown since the release of version 4.40.0. I narrowed it down to the fix in 40eb6d6, where the sliding window is now specified in `_prepare_4d_causal_attention_mask_for_sdpa`.
I ran a simple training job on both releases, and the training statistics show two different sets of throughputs:
| Sequence Length | release 4.39.3 (toks/s) | release 4.40.0 (toks/s) |
|---|---|---|
| 4096 | 3247 | 2483 |
| 8192 | 3083 | 1918 |
When my training sequence length is within or exactly at the sliding window threshold (e.g. seqlen = 4096, window = 4096), the windowed mask is no different from a plain causal mask, so it should fall back to letting the SDPA kernel handle the causal mask. I also don't see any computation savings at sequence length 8192 from the introduction of sliding window attention, compared to having no windowed causal mask at all (computing attention across all 8192 tokens).
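To illustrate the first point, here is a minimal sketch of the two masks. The helper names and the window convention (query `i` sees keys `j` with `i - window < j <= i`) are my own assumptions for illustration; the exact boundary used inside transformers may differ by one position.

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask, True where query i may attend to key j (j <= i)."""
    return torch.ones(seq_len, seq_len, dtype=torch.bool).tril()


def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Causal mask limited to the most recent `window` keys (i - window < j <= i)."""
    idx = torch.arange(seq_len)
    distance = idx[:, None] - idx[None, :]  # i - j for every (query, key) pair
    return (distance >= 0) & (distance < window)


# seq_len <= window: the window never cuts anything off.
same = torch.equal(sliding_window_causal_mask(8, 8), causal_mask(8))
# seq_len > window: older keys are masked out, so the masks differ.
differs = not torch.equal(sliding_window_causal_mask(16, 8), causal_mask(16))
print(same, differs)
```

Under this convention the explicit mask carries no extra information when seqlen <= window, which is why I would expect the plain causal path to be usable in that case.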
Below is a dummy example showing that simply not passing the causal mask to PyTorch's SDPA function (letting the kernel handle the causal mask itself), as opposed to passing a mask that specifies the sliding window, has a significant impact on the kernel's processing speed.
| Causal mask passed to Torch SDPA as attention mask | Causal mask handled internally in Torch SDPA |
|---|---|
| (screenshot not preserved) | (screenshot not preserved) |
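The comparison can be approximated with a small micro-benchmark along these lines. This is a sketch rather than the exact code behind the screenshots: the tensor shapes, iteration counts, and helper names (`sliding_window_mask`, `time_sdpa`, `compare`) are assumptions, and absolute timings will depend on the GPU and PyTorch version.

```python
import time

import torch
import torch.nn.functional as F


def sliding_window_mask(seq_len, window, device):
    """Boolean mask, True where attention is allowed (i - window < j <= i)."""
    idx = torch.arange(seq_len, device=device)
    distance = idx[:, None] - idx[None, :]
    return (distance >= 0) & (distance < window)


def time_sdpa(q, k, v, n_iters=10, **kwargs):
    """Average seconds per scaled_dot_product_attention call."""
    for _ in range(2):  # warm-up iterations, not timed
        F.scaled_dot_product_attention(q, k, v, **kwargs)
    if q.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        F.scaled_dot_product_attention(q, k, v, **kwargs)
    if q.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters


def compare(batch=1, heads=8, seq_len=1024, head_dim=64, window=1024):
    """Time SDPA with an explicit mask vs. with is_causal=True."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    q, k, v = (
        torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
        for _ in range(3)
    )
    mask = sliding_window_mask(seq_len, window, device)
    t_mask = time_sdpa(q, k, v, attn_mask=mask)    # explicit mask tensor
    t_causal = time_sdpa(q, k, v, is_causal=True)  # kernel applies causality
    return t_mask, t_causal


t_mask, t_causal = compare()
print(f"explicit attn_mask: {t_mask * 1e3:.3f} ms/iter")
print(f"is_causal=True:     {t_causal * 1e3:.3f} ms/iter")
```

With `window >= seq_len` both calls compute the same attention, so any timing gap comes only from how the mask is supplied. As far as I understand, passing an explicit `attn_mask` can prevent PyTorch from dispatching to its fused FlashAttention kernel, which would be consistent with the slowdown.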
Is this slowdown something we should expect from using the SDPA module with the current fix?
I have attached a simple script to reproduce the issue.
System Info
- `transformers` version: 4.40.0
- Platform: Linux-4.18.0-372.71.1.el8_6.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: False
Who can help?
No response
Reproduction
- Script to reproduce
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, __version__ as transformer_version
from datasets import load_dataset
from trl import SFTTrainer
print(f"transformers version: {transformer_version}")
dataset = load_dataset("yahma/alpaca-cleaned", split="train")
model_name = 'mistralai/Mistral-7B-v0.1'
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The Mistral tokenizer ships without a pad token, so reuse the unk token for padding.
tokenizer.add_special_tokens({"pad_token": tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token
model = AutoModelForCausalLM.from_config(
    config,
    torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
    attn_implementation="sdpa",
)
print(model.model._attn_implementation)
args = {
    'batch_size': 4,
    'gradient_accumulation_steps': 1,
    'use_gradient_checkpointing': 1,
    'warmup_steps': 10,
    'lr': 2e-4,
    'logging_steps': 10,
    'output_dir': './results',
    'optimizer': 'adamw_torch',
    'weight_decay': 0.0,
    'lr_scheduler': 'linear',
    'seed': 42,
    'max_steps': 100,
    'context_length': 4096,
}
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n\n"
    ),
}
def formatting_prompts_func(example):
    # Pick the template depending on whether the example has a non-empty "input" field.
    if example.get("input", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format_map(example)
    else:
        prompt = PROMPT_DICT["prompt_input"].format_map(example)
    return prompt + example["output"]
training_args = TrainingArguments(
    per_device_train_batch_size=args['batch_size'],
    gradient_accumulation_steps=args['gradient_accumulation_steps'],
    gradient_checkpointing=bool(args['use_gradient_checkpointing']),
    warmup_steps=args['warmup_steps'],
    max_steps=args['max_steps'],
    learning_rate=args['lr'],
    logging_strategy='steps',
    logging_steps=args['logging_steps'],
    output_dir=args['output_dir'],
    optim=args['optimizer'],
    weight_decay=args['weight_decay'],
    lr_scheduler_type=args['lr_scheduler'],
    seed=args['seed'],
    include_tokens_per_second=True,  # report tokens/s in the training stats
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    max_seq_length=args['context_length'],
    args=training_args,
    formatting_func=formatting_prompts_func,
    packing=True,
)
stats = trainer.train()
Expected behavior
- Throughput should remain the same for sequence lengths lower than the window size with SDPA.
- Throughput should be slightly higher (from fewer computations in local attention) than with regular attention (when no sliding window is specified in the causal mask) for longer sequence lengths.
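As a rough back-of-the-envelope check of the second expectation, the number of (query, key) pairs that actually need to be scored can be counted for both cases. This is only a sketch: it assumes each query attends to at most `window` keys including itself, and it ignores everything outside the attention score computation, so it gives an upper bound on the possible saving, not a throughput prediction.

```python
def causal_pairs(seq_len: int) -> int:
    """Number of (query, key) pairs under a full causal mask."""
    return seq_len * (seq_len + 1) // 2


def sliding_window_pairs(seq_len: int, window: int) -> int:
    """Number of (query, key) pairs when each query sees at most `window` keys."""
    if seq_len <= window:
        return causal_pairs(seq_len)
    # The first `window` queries form a full causal triangle;
    # every later query sees exactly `window` keys.
    return causal_pairs(window) + (seq_len - window) * window


for n in (4096, 8192):
    full = causal_pairs(n)
    local = sliding_window_pairs(n, window=4096)
    print(f"seqlen={n}: full={full}, windowed={local}, ratio={local / full:.2f}")
```

At seqlen = 4096 the two counts are identical, and at seqlen = 8192 the windowed mask needs about 75% of the pairs of full causal attention, which is the saving I would hope a kernel could exploit.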