docs/source/ja/perf_train_gpu_many.md (0 additions & 2 deletions)

@@ -472,8 +472,6 @@ FlexFlow performs a kind of 4D parallelization over Sample-Operator-Attribute-Parameter.
 
 So the promise of this framework is very appealing: it runs a 30-minute simulation on the cluster of your choice and produces the best strategy for making optimal use of that specific environment. If you add/remove/replace parts, it runs again and re-optimizes the plan, after which you can train. A different setup gets its own optimization.
 
-🤗 Transformers status: not yet integrated. Our models are already FX-traceable via [transformers.utils.fx](https://github.com/huggingface/transformers/blob/master/src/transformers/utils/fx.py), which is a prerequisite for FlexFlow, so someone needs to figure out what needs to be done to make FlexFlow work with our models.
-
 ## Which Strategy To Use When
 
 Here is a very rough outline of which parallelization strategy to use when. The first entry in each list is generally the faster one.
docs/source/ko/perf_train_gpu_many.md (0 additions & 2 deletions)

@@ -476,8 +476,6 @@ https://huggingface.co/papers/2201.11990)
 
 So the promise of this framework is that it runs a 30-minute simulation on the cluster of your choice and comes up with the best strategy to utilize this specific environment optimally. If you add/remove/replace parts, it runs and re-optimizes the plan for the new setup, after which you can train. A different setup can have its own custom optimization.
 
-🤗 Transformers status: not yet integrated. Our models are already FX-traceable via [transformers.utils.fx](https://github.com/huggingface/transformers/blob/master/src/transformers/utils/fx.py), which is a prerequisite for FlexFlow, so we need to figure out what needs to be done to make FlexFlow work with our models.
-
 
 ## Which Strategy To Use When [[which-strategy-to-use-when]]
 
src/transformers/masking_utils.py (3 additions & 6 deletions)

@@ -23,7 +23,7 @@
 from .configuration_utils import PreTrainedConfig
 from .utils import is_torch_xpu_available, logging
 from .utils.generic import GeneralInterface
-from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_torchdynamo_compiling
+from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing
 
 
 if is_torch_flex_attn_available():
@@ -239,7 +239,6 @@ def _ignore_causal_mask_sdpa(
     allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
     passed).
     """
-    is_tracing = torch.jit.is_tracing() or isinstance(padding_mask, torch.fx.Proxy) or is_torchdynamo_compiling()
     if padding_mask is not None and padding_mask.shape[-1] > kv_length:
         mask_indices = torch.arange(kv_length, device=padding_mask.device)
         mask_indices += kv_offset
@@ -250,7 +249,7 @@ def _ignore_causal_mask_sdpa(
     # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
     # `ignore_causal_mask = True` if we are not tracing
     if (
-        not is_tracing
+        not is_tracing(padding_mask)
         # only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
         and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
         # in this case we need to add special patterns to the mask so cannot be skipped otherwise
@@ -275,11 +274,9 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> bool:
     Detects whether the bidirectional mask can be ignored in case PyTorch's SDPA is used, i.e. when there is full
     attention with no padding.
     """
-    is_tracing = torch.jit.is_tracing() or isinstance(padding_mask, torch.fx.Proxy) or is_torchdynamo_compiling()
-
     # When using `torch.export` or `torch.onnx.dynamo_export`, we need to avoid to check the contents of the mask;
     # otherwise, we will encounter dynamic control flows
-    if not is_tracing and (padding_mask is None or padding_mask.all()):
+    if not is_tracing(padding_mask) and (padding_mask is None or padding_mask.all()):
        return True
 
     return False
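Note: the per-call-site triple check is now centralized in a single `is_tracing` helper imported from `transformers.utils.import_utils`. A minimal sketch of what that helper presumably folds together, reconstructed from the three conditions deleted above (the actual implementation, and in particular its dynamo check, may differ):

```python
import torch
import torch.fx


def is_tracing(tensor=None) -> bool:
    """True under torch.jit.trace, FX symbolic tracing, or torch.compile.

    Assumes torch.compiler.is_compiling() as the dynamo check; the real
    helper in transformers.utils.import_utils may guard this differently.
    """
    return (
        torch.jit.is_tracing()
        or isinstance(tensor, torch.fx.Proxy)
        or torch.compiler.is_compiling()
    )
```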
src/transformers/modeling_attn_mask_utils.py (7 additions & 9 deletions)

@@ -22,7 +22,7 @@
 
 import torch
 
-from .utils.import_utils import is_torchdynamo_compiling
+from .utils.import_utils import is_torchdynamo_compiling, is_tracing
 
 
 @dataclass
@@ -267,7 +267,7 @@ def _ignore_causal_mask_sdpa(
         _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
         key_value_length = query_length + past_key_values_length
 
-        is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+        is_tracing_ = is_tracing(inputs_embeds)
 
         ignore_causal_mask = False
 
@@ -283,15 +283,15 @@
         # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
         # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
         if (
-            (is_training or not is_tracing)
+            (is_training or not is_tracing_)
             and (query_length == 1 or key_value_length == query_length)
             and (sliding_window is None or key_value_length < sliding_window)
         ):
             ignore_causal_mask = True
         elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
                 return False
-            elif not is_tracing and torch.all(attention_mask == 1):
+            elif not is_tracing_ and torch.all(attention_mask == 1):
                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
@@ -379,7 +379,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
     # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+    is_tracing_ = is_tracing(inputs_embeds)
 
     ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
         attention_mask=attention_mask,
@@ -408,7 +408,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
     # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
     # Details: https://github.com/pytorch/pytorch/issues/110213
-    if not is_tracing and expanded_4d_mask.device.type == "cuda":
+    if not is_tracing_ and expanded_4d_mask.device.type == "cuda":
         expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
             expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
         )
@@ -448,10 +448,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
     _, key_value_length = mask.shape
     tgt_len = tgt_len if tgt_len is not None else key_value_length
 
-    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
-
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
-    if not is_tracing and torch.all(mask == 1):
+    if not is_tracing(mask) and torch.all(mask == 1):
         return None
     else:
         return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
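For context on why these helpers try to drop the mask entirely: passing `attn_mask=None` together with `is_causal=True` lets `torch.nn.functional.scaled_dot_product_attention` dispatch to its fused flash/memory-efficient kernels, which an explicit mask can rule out. A standalone illustration (not transformers code) that the two formulations compute the same thing:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64)

# Explicit causal mask (True = attend); can block fused kernel dispatch.
causal_mask = torch.ones(128, 128, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# No mask + is_causal=True: same math, fused-kernel eligible.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(out_masked, out_causal)
```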
src/transformers/modeling_utils.py (2 additions & 3 deletions)

@@ -120,8 +120,7 @@
     ENV_VARS_TRUE_VALUES,
     is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
-    is_torch_fx_proxy,
-    is_torchdynamo_compiling,
+    is_tracing,
 )
 from .utils.quantization_config import QuantizationMethod
 
@@ -4946,7 +4945,7 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
         """
 
         # Skip the check during tracing.
-        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
+        if is_tracing(input_ids):
             return
 
         if (attention_mask is not None) or (self.config.pad_token_id is None):
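As a reminder of what the guarded method does: it warns when `input_ids` contain the pad token but no `attention_mask` was passed. That condition inspects tensor contents, which is exactly the kind of data-dependent control flow tracing cannot capture, hence the early return. A toy illustration with hypothetical ids:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0]])  # right-padded batch
attention_mask = None

# The data-dependent check the method performs when not tracing:
if attention_mask is None and (input_ids == pad_token_id).any():
    print("warning: pad tokens found but no attention_mask was given")
```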
src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py

@@ -22,7 +22,7 @@
 from ....cache_utils import Cache
 from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions
 from ....modeling_utils import PreTrainedModel
-from ....utils import DUMMY_INPUTS, DUMMY_MASK, auto_docstring, is_torch_fx_proxy
+from ....utils import DUMMY_INPUTS, DUMMY_MASK, auto_docstring
 from .configuration_gptsan_japanese import GPTSanJapaneseConfig
 
 
@@ -593,15 +593,9 @@ def _shift_right(self, input_ids):
                 "See T5 docs for more information."
             )
 
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
 
         if pad_token_id is None:
             raise ValueError("self.model.config.pad_token_id has to be defined.")
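For intuition, the now-unconditional branch is the standard teacher-forcing shift: prepend `decoder_start_token_id` and drop the last token. The deleted `torch.full` + `torch.cat` branch existed only because FX proxies do not support item assignment; the same hunk recurs in longt5, mt5, pix2struct, and pop2piano below. A quick standalone check:

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13]])
decoder_start_token_id = 0

shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id

print(shifted_input_ids)  # tensor([[ 0, 10, 11, 12]])
```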
src/transformers/models/gpt_neo/modeling_gpt_neo.py (1 addition & 6 deletions)

@@ -23,7 +23,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -54,11 +54,6 @@
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
 
-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
-
-
 logger = logging.get_logger(__name__)
 
 
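For background, `torch.fx.wrap` registers a function as a leaf: symbolic tracing records one `call_function` node instead of tracing through the body, which is how `_prepare_4d_causal_attention_mask` survived its data-dependent control flow. With FX tracing support removed, the wrapper is dead code. A minimal standalone illustration (toy function, not transformers code):

```python
import torch
import torch.fx


def clamp_positive(x):
    # Data-dependent branch; symbolic tracing would fail inside here.
    return x if x.sum() > 0 else -x


torch.fx.wrap("clamp_positive")  # record calls as a single opaque node


class M(torch.nn.Module):
    def forward(self, x):
        return clamp_positive(x) * 2


# clamp_positive shows up as one graph node; its body is never traced.
print(torch.fx.symbolic_trace(M()).graph)
```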
src/transformers/models/gptj/modeling_gptj.py (3 additions & 20 deletions)

@@ -17,7 +17,6 @@
 from typing import Optional, Union
 
 import torch
-import torch.fx
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -34,12 +33,7 @@
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    auto_docstring,
-    is_torch_flex_attn_available,
-    is_torch_fx_proxy,
-    logging,
-)
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
 from .configuration_gptj import GPTJConfig
 
 
@@ -62,7 +56,6 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
 
 
-@torch.fx.wrap
 def get_embed_positions(embed_positions, position_ids):
     return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
 
@@ -198,12 +191,7 @@ def forward(
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
 
-        if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
-            # The logic to conditionally copy to GPU could not be traced, so we do this
-            # every time in the torch.fx case
-            embed_positions = get_embed_positions(self.embed_positions, position_ids)
-        else:
-            embed_positions = self._get_embed_positions(position_ids)
+        embed_positions = self._get_embed_positions(position_ids)
 
         repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
         sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
@@ -283,12 +271,7 @@ def forward(
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
 
-        if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
-            # The logic to conditionally copy to GPU could not be traced, so we do this
-            # every time in the torch.fx case
-            embed_positions = get_embed_positions(self.embed_positions, position_ids)
-        else:
-            embed_positions = self._get_embed_positions(position_ids)
+        embed_positions = self._get_embed_positions(position_ids)
 
         repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
         sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
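For intuition, the deleted `get_embed_positions` and the kept `self._get_embed_positions` both end up broadcasting the cached sinusoidal table across the batch before the `torch.gather` that follows; the FX path only existed because the conditional device copy could not be traced. A standalone sketch of the gather pattern with illustrative shapes:

```python
import torch

batch, seq_len, num_pos, dim = 2, 5, 16, 8
embed_positions = torch.randn(num_pos, dim)                # cached sin/cos table
position_ids = torch.randint(0, num_pos, (batch, seq_len))

# Broadcast the table to the batch (what get_embed_positions did) ...
table = embed_positions.to(position_ids.device).repeat(batch, 1, 1)

# ... then gather one row of sin/cos values per position id.
idx = position_ids.unsqueeze(-1).repeat(1, 1, dim)         # [batch, seq_len, dim]
sincos = torch.gather(table, 1, idx)
print(sincos.shape)  # torch.Size([2, 5, 8])
```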
src/transformers/models/hiera/modeling_hiera.py (0 additions & 1 deletion)

@@ -1041,7 +1041,6 @@ def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
         if isinstance(head, nn.Identity):
             return hidden_states
 
-        # Doing explicit to avoid problems with torch.fx
         batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size = hidden_states.shape
         # From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
         # To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
src/transformers/models/longt5/modeling_longt5.py (3 additions & 10 deletions)

@@ -39,7 +39,6 @@
     DUMMY_MASK,
     auto_docstring,
     is_torch_flex_attn_available,
-    is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
 )
@@ -1259,15 +1258,9 @@ def _shift_right(self, input_ids):
                 "See LongT5 docs for more information."
             )
 
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
 
         if pad_token_id is None:
             raise ValueError("self.model.config.pad_token_id has to be defined.")
src/transformers/models/mt5/modeling_mt5.py (3 additions & 10 deletions)

@@ -42,7 +42,6 @@
     DUMMY_MASK,
     auto_docstring,
     is_torch_flex_attn_available,
-    is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
 )
@@ -631,15 +630,9 @@ def _shift_right(self, input_ids):
                 "See MT5 docs for more information."
             )
 
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
 
         if pad_token_id is None:
             raise ValueError("self.model.config.pad_token_id has to be defined.")
src/transformers/models/pix2struct/modeling_pix2struct.py (3 additions & 10 deletions)

@@ -38,7 +38,6 @@
     DUMMY_MASK,
     auto_docstring,
     is_torch_flex_attn_available,
-    is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
 )
@@ -440,15 +439,9 @@ def _shift_right(self, input_ids):
                 "See Pix2Struct docs for more information."
             )
 
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
 
         if pad_token_id is None:
             raise ValueError("self.model.config.pad_token_id has to be defined.")
src/transformers/models/pop2piano/modeling_pop2piano.py (4 additions & 10 deletions)

@@ -31,7 +31,7 @@
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging
+from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
 from .configuration_pop2piano import Pop2PianoConfig
 
 
@@ -593,15 +593,9 @@ def _shift_right(self, input_ids):
                 "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id."
             )
 
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
 
         if pad_token_id is None:
             raise ValueError("self.model.config.pad_token_id has to be defined.")
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -30,7 +30,7 @@
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, logging
-from ...utils.import_utils import is_torchdynamo_compiling
+from ...utils.import_utils import is_tracing
 from .configuration_recurrent_gemma import RecurrentGemmaConfig
 
 
@@ -362,8 +362,7 @@ def forward(
         # Apply gamma normalization to the input. We need to clip the derivatives of
         # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying
         multiplier = 1
-        tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling()
-        if not torch.jit.is_tracing() and not tracing:
+        if not is_tracing(activations):
             multiplier = SqrtBoundDerivative.apply(1 - a_square)
         multiplier = reset + ~reset * multiplier
         normalized_x = gated_inputs * multiplier.type(activations.dtype)
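For background, `SqrtBoundDerivative` is a custom autograd op: plain `sqrt` in the forward pass, but with the `1/(2*sqrt(x))` derivative kept bounded so it cannot overflow to NaN/Inf near zero in bfloat16 training. A rough sketch of the pattern (the clamp bound below is a placeholder, not the model's actual constant):

```python
import torch


class SqrtBoundDerivative(torch.autograd.Function):
    """sqrt with a clipped derivative, to avoid NaNs in low precision."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return torch.sqrt(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # d/dx sqrt(x) = 1 / sqrt(4x); clamp the denominator away from zero.
        clipped_4x = torch.clamp(4.0 * x, min=1e-6)  # placeholder bound
        return grad_output / torch.sqrt(clipped_4x)
```

`SqrtBoundDerivative.apply(1 - a_square)` then matches `torch.sqrt` in the forward pass while keeping gradients finite; the `is_tracing` guard presumably exists because custom autograd functions do not trace or export cleanly.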