diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index b5f59b4bb1f9..c6149e101d89 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -533,6 +533,21 @@ def _process_flash_attention_kwargs(
         flash_kwargs (`dict`):
             A dict of kwargs that are requested and supported.
     """
+
+    user_kwargs = {
+        "dropout_p": dropout,
+        "window_size": sliding_window,
+        "deterministic": deterministic,
+        "softcap": softcap,
+        "s_aux": s_aux,
+    }
+    # Note: 'window_size' in supports_mapping maps to our 'sliding_window' param
+    for k, v in user_kwargs.items():
+        if not supports_mapping[k] and v is not None:
+            raise ValueError(
+                f"Parameter `{k}` is not supported by this Flash Attention implementation but was set, please use a different attention implementation."
+            )
+
     flash_kwargs = {
         "causal": is_causal and not (use_top_left_mask and query_length == 1),
         "softmax_scale": softmax_scale,
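
A minimal standalone sketch of the validation pattern this hunk introduces (not the transformers API itself): user-supplied kwargs are checked against a `supports_mapping` describing what the resolved flash-attention kernel accepts, and explicitly setting an unsupported one raises instead of being silently dropped. The `validate_flash_kwargs` helper and the example `supports` dict below are illustrative assumptions.

```python
def validate_flash_kwargs(supports_mapping: dict, **user_kwargs) -> None:
    """Raise if a kwarg was explicitly set but the kernel does not support it."""
    for name, value in user_kwargs.items():
        if not supports_mapping.get(name, False) and value is not None:
            raise ValueError(
                f"Parameter `{name}` is not supported by this Flash Attention "
                "implementation but was set, please use a different attention implementation."
            )


# Example: a kernel without softcap support rejects an explicit `softcap`.
supports = {"dropout_p": True, "window_size": True, "deterministic": True, "softcap": False}
validate_flash_kwargs(supports, dropout_p=0.0, softcap=None)  # OK: softcap left unset
try:
    validate_flash_kwargs(supports, softcap=30.0)  # raises ValueError
except ValueError as err:
    print(err)
```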