From 078878046f47bb61d569c30222fe938265e0df31 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 18 Oct 2021 17:35:21 +0200 Subject: [PATCH 01/18] changed the way tracing happens, enabling dynamic axes out of the box --- src/transformers/modeling_utils.py | 66 +++++--- .../models/albert/modeling_albert.py | 2 +- src/transformers/models/bert/modeling_bert.py | 4 +- .../models/electra/modeling_electra.py | 4 +- src/transformers/models/gpt2/modeling_gpt2.py | 8 +- .../models/gpt_neo/modeling_gpt_neo.py | 4 +- .../models/layoutlm/modeling_layoutlm.py | 4 +- .../megatron_bert/modeling_megatron_bert.py | 4 +- .../models/mobilebert/modeling_mobilebert.py | 4 +- .../models/roberta/modeling_roberta.py | 4 +- .../models/splinter/modeling_splinter.py | 4 +- src/transformers/utils/fx.py | 153 +++++++++++++++--- tests/test_modeling_common.py | 6 +- 13 files changed, 199 insertions(+), 68 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 324046cc6a53..837a014f2096 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -240,6 +240,27 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: return encoder_extended_attention_mask + def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device): + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. @@ -264,26 +285,29 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - # in case past_key_values are used we need to add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - causal_mask = torch.cat( - [ - torch.ones( - (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype - ), - causal_mask, - ], - axis=-1, - ) - - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + # batch_size, seq_length = input_shape + # seq_ids = torch.arange(seq_length, device=device) + # causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # # causal and attention masks must have same type with pytorch version < 1.3 + # causal_mask = causal_mask.to(attention_mask.dtype) + + # if causal_mask.shape[1] < attention_mask.shape[1]: + # prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + # causal_mask = torch.cat( + # [ + # torch.ones( + # (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + # ), + # causal_mask, + # ], + # axis=-1, + # ) + + # extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + extended_attention_mask = self.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) else: extended_attention_mask = attention_mask[:, None, None, :] else: @@ -1761,7 +1785,7 @@ def __init__(self, nf, nx): def forward(self, x): size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(*size_out) + x = x.view(size_out) return x diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 9c7ccccc8c6c..92b46509dfce 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -293,7 +293,7 @@ def __init__(self, config): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def prune_heads(self, heads): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 66df84d47a0a..fab5b5a37614 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -252,7 +252,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -341,7 +341,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 8ad939d5f6b9..0f7f6a5b3e2c 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -245,7 +245,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -334,7 +334,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8bb8590a8b45..40a7cec3c5b8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -193,7 +193,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -281,7 +281,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -915,7 +915,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1410,7 +1410,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7176cfa790b2..a2c9be7f9dfa 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -173,7 +173,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -637,7 +637,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 29a8c071eece..d9e5ac304b23 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -160,7 +160,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -249,7 +249,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index e0f9f1191f21..101a5e94c353 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -223,7 +223,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -312,7 +312,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index a37bd829839b..2d8ea6f6f6f5 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -237,7 +237,7 @@ def __init__(self, config): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -274,7 +274,7 @@ def forward( context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index dc5d717f8434..6f8e257af021 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -187,7 +187,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -276,7 +276,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index b982a38b62f4..3d15cc682556 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -127,7 +127,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -216,7 +216,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 23a2eb4c1fab..da0219fd1a2e 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -134,6 +134,7 @@ def __init__(self, node: Node, tracer: Optional[Tracer] = None): if hasattr(self, "tracer") and self.tracer is not None: self.device = self.tracer.root.device self.dtype = next(self.tracer.root.parameters()).dtype + self.cache = None @property def shape(self): @@ -145,6 +146,42 @@ def __setitem__(self, key, value): def __contains__(self, key): return False + def __eq__(self, other): + if self.cache is not None: + return self.cache == other + return super().__eq__(other) + + def __len__(self): + if self.cache is not None: + if isinstance(self.cache, int): + return self.cache + elif isinstance(self.cache, (torch.Size, list, tuple)): + return len(self.cache) + else: + return super().__len__(self) + return super().__len__(self) + + def __torch_function__(self, orig_method, types, args=None, kwargs=None): + proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs) + proxy.cache = self.cache + return proxy + + +def _function_to_leaf(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +def _function_leaf_getter(func_name, mapping): + @functools.wraps(mapping[func_name]) + def wrapper(*args, **kwargs): + return mapping[func_name](*args, **kwargs) + + return wrapper + def _wrap_method_for_model_recording(model, method_name, cache_name): """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" @@ -168,10 +205,20 @@ def _create_recorded_proxy_method(proxy, method_name, cache_name): during symbolic tracing. """ - def method(self, *args, **kwargs): - cache = getattr(self.tracer.root, cache_name) + original_method = getattr(torch.Tensor, method_name) + + @functools.wraps(original_method) + def method(*args, **kwargs): + cache = getattr(args[0].tracer.root, cache_name) res = cache.pop(0) - return res + proxy = args[0].__torch_function__( + original_method, + None, + args=args, + kwargs=kwargs, + ) + proxy.cache = res + return proxy method.__name__ = method_name bound_method = method.__get__(proxy, proxy.__class__) @@ -198,6 +245,28 @@ def method(*args, **kwargs): setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) +def _create_proxy_method_for_model_tracing(model, method_name, cache_name): + """ + Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values + during symbolic tracing. + """ + + original_method = getattr(torch.Tensor, method_name) + + @functools.wraps(original_method) + def method(*args, **kwargs): + cache = getattr(model, cache_name) + res = cache.pop(0) + args[0].cache_value = res + + return original_method(*args, **kwargs) + + setattr(HFProxy, method_name, method) + + if method_name == "size": + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + def _monkey_patch_tensor_methods_for_model_recording(model, method_names): """ Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference @@ -233,11 +302,23 @@ class HFTracer(Tracer): regular PyTorch torch.fx.Proxy. """ - default_methods_to_record = {"__bool__", "size", "dim"} + _DEFAULT_METHODS_TO_RECORD = {"__bool__", "size", "dim"} + from transformers import modeling_utils + + _FUNCTIONS_TO_AUTOWRAP = { + torch: {"arange", "zeros", "ones", "full_like"}, + # modeling_utils.ModuleUtilsMixin: {"get_extended_attention_mask"}, + modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, + } def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): super().__init__() + self._leaf_functions_register = {} + for module, names in self._FUNCTIONS_TO_AUTOWRAP.items(): + for name in names: + self._register_leaf_function(module, name) + if not is_torch_fx_available(): torch_version = version.parse(importlib_metadata.version("torch")) raise ImportError( @@ -261,6 +342,24 @@ def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): self.prev_module = None self.recorded_methods = None + def _register_leaf_function(self, module, name): + orig_func = getattr(module, name) + patched_func = _function_to_leaf(orig_func) + patched_func.__module__ = __name__ + self._leaf_functions_register[name] = (module, orig_func, patched_func) + + def _patch_leaf_functions_for_root(self, root, restore=False): + for name in self._leaf_functions_register: + module, orig_func, patched_func = self._leaf_functions_register[name] + if restore: + root.__class__.forward.__globals__.pop(name) + setattr(module, name, orig_func) + else: + root.__class__.forward.__globals__[name] = patched_func + leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__) + leaf_getter.__module__ = __name__ + setattr(module, name, leaf_getter) + def proxy(self, node: Node): p = HFProxy(node, self) if self.recorded_methods: @@ -277,7 +376,7 @@ def _generate_dummy_input(self, model, input_name): if input_name in ["labels", "start_positions", "end_positions"]: batch_size = self.encoder_shape[0] if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device) + inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) @@ -302,11 +401,11 @@ def _generate_dummy_input(self, model, input_name): elif "mask" in input_name or "ids" in input_name: shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device) + inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape shape += [model.config.hidden_size] - inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device) + inputs_dict[input_name] = torch.zeros(shape, dtype=torch.float, device=device) return inputs_dict @@ -316,7 +415,7 @@ def record(self, model, input_names, method_names=None): tracing. """ if method_names is None: - method_names = self.default_methods_to_record + method_names = self._DEFAULT_METHODS_TO_RECORD inputs = {} for input_name in input_names: @@ -341,6 +440,8 @@ def record(self, model, input_names, method_names=None): for cache_name in self.recorded_methods.values(): setattr(model, cache_name, getattr(clone, cache_name)) + # _restore_tensor_creators(self.original_tensor_creators) + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if isinstance(attr_val, torch.nn.Parameter): for n, p in self.root.named_parameters(): @@ -366,11 +467,13 @@ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = self.record(root, input_names, method_names=method_names) - for method_name, cache_name in self.recorded_methods.items(): - _wrap_method_for_model_tracing(root, method_name, cache_name) + autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()] + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + self._patch_leaf_functions_for_root(root) graph = super().trace(root, concrete_args=concrete_args) + self._patch_leaf_functions_for_root(root, restore=True) _reset_tensor_methods(self.original_methods) # TODO: keep this until necessary. @@ -434,6 +537,10 @@ def path_of_module(self, mod: nn.Module) -> str: self.prev_module = path return path + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + is_loss_module = m.__module__.startswith("torch.nn.modules.loss") + return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name) + def create_arg(self, a: Any) -> Argument: if isinstance(a, range): return super().create_arg(list(a)) @@ -562,18 +669,18 @@ def symbolic_trace( decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) sequence_length = [encoder_sequence_length, decoder_sequence_length] - if not isinstance(model, _SUPPORTED_MODELS): - supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) - raise NotImplementedError( - f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" - ) - if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( - model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES - ): - supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) - raise NotImplementedError( - f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" - ) + # if not isinstance(model, _SUPPORTED_MODELS): + # supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) + # raise NotImplementedError( + # f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" + # ) + # if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( + # model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES + # ): + # supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) + # raise NotImplementedError( + # f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" + # ) # Tracing. tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) @@ -593,6 +700,6 @@ def symbolic_trace( traced.static_batch_size = batch_size traced.static_sequence_length = sequence_length - transform_to_dynamic_input_(traced) + # transform_to_dynamic_input_(traced) return traced diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1df5e9e0f061..513d83e3a6e9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -654,9 +654,9 @@ def test_torch_fx(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() self._create_and_check_torch_fx_tracing(config, inputs_dict) - def test_torch_fx_output_loss(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) + # def test_torch_fx_output_loss(self): + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) def test_torch_fx_dynamic_axes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From f7306089428739fc0267cb0e6d21d90b1cee26b0 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 8 Nov 2021 10:01:55 +0100 Subject: [PATCH 02/18] Updated the tests and modeling xlnet --- .../models/xlnet/modeling_xlnet.py | 10 ++--- tests/test_modeling_albert.py | 1 - tests/test_modeling_bart.py | 1 + tests/test_modeling_bert.py | 1 - tests/test_modeling_common.py | 37 ++++--------------- tests/test_modeling_distilbert.py | 1 - tests/test_modeling_electra.py | 1 - tests/test_modeling_layoutlm.py | 1 + tests/test_modeling_megatron_bert.py | 1 - tests/test_modeling_mobilebert.py | 1 - tests/test_modeling_roberta.py | 1 + tests/test_modeling_xlnet.py | 2 + 12 files changed, 18 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 278320a6b41f..71c4d788d34c 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1464,7 +1464,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs, + # **kwargs, ) logits = self.lm_loss(transformer_outputs[0]) @@ -1564,7 +1564,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs, + # **kwargs, ) output = transformer_outputs[0] @@ -1781,7 +1781,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs, + # **kwargs, ) output = transformer_outputs[0] @@ -1878,7 +1878,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs, + # **kwargs, ) sequence_output = outputs[0] @@ -2018,7 +2018,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - **kwargs, + # **kwargs, ) hidden_states = transformer_outputs[0] start_logits = self.start_logits(hidden_states, p_mask=p_mask) diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index d16dcadd5e6a..b01b459d86a5 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -232,7 +232,6 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 957350b824ce..319b4c38de41 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -411,6 +411,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): if is_torch_available() else () ) + # fx_ready_model_classes = all_model_classes all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 7a6628509799..e010cbaf180e 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -445,7 +445,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 513d83e3a6e9..64c62ebdf30b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -109,7 +109,6 @@ class ModelTesterMixin: all_model_classes = () all_generative_model_classes = () fx_ready_model_classes = () - fx_dynamic_ready_model_classes = () test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -654,22 +653,18 @@ def test_torch_fx(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() self._create_and_check_torch_fx_tracing(config, inputs_dict) - # def test_torch_fx_output_loss(self): - # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) - - def test_torch_fx_dynamic_axes(self): + def test_torch_fx_output_loss(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True) + self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False): + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): if not is_torch_fx_available(): return configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.return_dict = False - model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes + model_classes = self.fx_ready_model_classes for model_class in model_classes: model = model_class(config=configs_no_init) model.to(torch_device) @@ -679,8 +674,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa try: if model.config.is_encoder_decoder: model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - input_ids = inputs["input_ids"] - decoder_attention_mask = inputs["decoder_attention_mask"] labels = inputs.get("labels", None) input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] if labels is not None: @@ -689,17 +682,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model_output = model(**filtered_inputs) - batch_size = input_ids.shape[0] - encoder_sequence_length = input_ids.shape[1] - decoder_sequence_length = decoder_attention_mask.shape[1] - - traced_model = symbolic_trace( - model, - input_names, - batch_size=batch_size if not dynamic_axes else -1, - sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1, - ) - + traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: input_names = ["input_ids", "attention_mask", "token_type_ids"] @@ -731,14 +714,10 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}." ) - traced_model = symbolic_trace( - model, - input_names, - batch_size=batch_size if not dynamic_axes else -1, - sequence_length=sequence_length if not dynamic_axes else -1, - num_choices=num_choices, - ) + # import pytest; pytest.set_trace() + traced_model = symbolic_trace(model, input_names, num_choices=num_choices) traced_output = traced_model(**filtered_inputs) + # import pytest; pytest.set_trace() except RuntimeError: self.fail("Couldn't trace module.") diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 8026f92db604..4a86d9a74a97 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -210,7 +210,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): else None ) fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index be19f8d610db..3dd94223c726 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -372,7 +372,6 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else () fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_layoutlm.py b/tests/test_modeling_layoutlm.py index 67423fe21fd1..acbcfa0df9d5 100644 --- a/tests/test_modeling_layoutlm.py +++ b/tests/test_modeling_layoutlm.py @@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) + # fx_ready_model_classes = all_model_classes def setUp(self): self.model_tester = LayoutLMModelTester(self) diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index a7f47ddea322..5a06d57a9ec3 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -284,7 +284,6 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index 716714157a76..abf64e416f0a 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -270,7 +270,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 031263cf6df9..11d826f00148 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 5516b28e17e1..46fed1019717 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -526,6 +526,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) all_generative_model_classes = ( (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + + fx_ready_model_classes = all_model_classes test_pruning = False # XLNet has 2 QA models -> need to manually set the correct labels for one of them here From a146048bafb40f334581d58804426a0bef7ada5f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 8 Nov 2021 19:09:44 +0100 Subject: [PATCH 03/18] Added the non recoding of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors). --- src/transformers/utils/fx.py | 286 ++++++++++++++--------------------- 1 file changed, 111 insertions(+), 175 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index da0219fd1a2e..546def8b9145 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -2,7 +2,7 @@ import functools import inspect import random -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, ModuleType, List, Optional, Tuple, Type, Union import torch from packaging import version @@ -24,6 +24,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_MAPPING, GPT2DoubleHeadsModel, + XLNetForQuestionAnswering, PretrainedConfig, PreTrainedModel, logging, @@ -149,7 +150,13 @@ def __contains__(self, key): def __eq__(self, other): if self.cache is not None: return self.cache == other - return super().__eq__(other) + elif isinstance(other, HFProxy): + return True + else: + return super().__eq__(other) + + def __ne__(self, other): + return not self == other def __len__(self): if self.cache is not None: @@ -167,7 +174,8 @@ def __torch_function__(self, orig_method, types, args=None, kwargs=None): return proxy -def _function_to_leaf(func): +def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]: + """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer.""" @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -183,23 +191,7 @@ def wrapper(*args, **kwargs): return wrapper -def _wrap_method_for_model_recording(model, method_name, cache_name): - """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" - method = getattr(torch.Tensor, method_name) - - @functools.wraps(method) - def wrapped(*args, **kwargs): - if not hasattr(model, cache_name): - setattr(model, cache_name, []) - cache = getattr(model, cache_name) - res = method(*args, **kwargs) - cache.append(res) - return res - - return wrapped - - -def _create_recorded_proxy_method(proxy, method_name, cache_name): +def _create_recorded_proxy_method(proxy, method_name, cache_name, return_proxy): """ Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values during symbolic tracing. @@ -211,85 +203,22 @@ def _create_recorded_proxy_method(proxy, method_name, cache_name): def method(*args, **kwargs): cache = getattr(args[0].tracer.root, cache_name) res = cache.pop(0) - proxy = args[0].__torch_function__( - original_method, - None, - args=args, - kwargs=kwargs, - ) - proxy.cache = res - return proxy + if return_proxy: + proxy = args[0].__torch_function__( + original_method, + None, + args=args, + kwargs=kwargs, + ) + proxy.cache = res + return proxy + return res method.__name__ = method_name bound_method = method.__get__(proxy, proxy.__class__) setattr(proxy, method_name, bound_method) -def _wrap_method_for_model_tracing(model, method_name, cache_name): - """ - Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values - during symbolic tracing. - """ - - original_method = getattr(torch.Tensor, method_name) - - @functools.wraps(original_method) - def method(*args, **kwargs): - cache = getattr(model, cache_name) - res = cache.pop(0) - return res - - setattr(torch.Tensor, method_name, method) - - if method_name == "size": - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - -def _create_proxy_method_for_model_tracing(model, method_name, cache_name): - """ - Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values - during symbolic tracing. - """ - - original_method = getattr(torch.Tensor, method_name) - - @functools.wraps(original_method) - def method(*args, **kwargs): - cache = getattr(model, cache_name) - res = cache.pop(0) - args[0].cache_value = res - - return original_method(*args, **kwargs) - - setattr(HFProxy, method_name, method) - - if method_name == "size": - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - -def _monkey_patch_tensor_methods_for_model_recording(model, method_names): - """ - Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference - before symbolic tracing. - """ - cache_names = dict() - original_methods = dict() - for method_name in method_names: - cache_name = f"cache_{method_name}" - cache_names[method_name] = cache_name - if not hasattr(torch.Tensor, method_name): - logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") - continue - original_methods[method_name] = getattr(torch.Tensor, method_name) - setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name)) - - if method_name == "size": - original_methods["shape"] = torch.Tensor.shape - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - return cache_names, original_methods - - def _reset_tensor_methods(original_methods): """Helper function that resets the monkey patched torch.Tensor methods to their original values.""" for name, method in original_methods.items(): @@ -302,16 +231,15 @@ class HFTracer(Tracer): regular PyTorch torch.fx.Proxy. """ - _DEFAULT_METHODS_TO_RECORD = {"__bool__", "size", "dim"} + _DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False} from transformers import modeling_utils _FUNCTIONS_TO_AUTOWRAP = { - torch: {"arange", "zeros", "ones", "full_like"}, - # modeling_utils.ModuleUtilsMixin: {"get_extended_attention_mask"}, + torch: {"arange", "zeros", "ones", "full_like", "eye"}, modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, } - def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): + def __init__(self, batch_size=1, sequence_length=128, num_choices=-1): super().__init__() self._leaf_functions_register = {} @@ -326,29 +254,18 @@ def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): f"{TORCH_FX_REQUIRED_VERSION} is supported." ) - encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length - decoder_sequence_length = ( - sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length - ) - self.encoder_shape = [batch_size, encoder_sequence_length] - self.decoder_shape = ( - [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape) - ) - self.num_choices = num_choices - if self.num_choices > 0: - self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length] - self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length] - + self.shape = [batch_size, sequence_length] self.prev_module = None self.recorded_methods = None - def _register_leaf_function(self, module, name): + def _register_leaf_function(self, module: ModuleType, name: str): + """Registers the function called name in module as a leaf function.""" orig_func = getattr(module, name) patched_func = _function_to_leaf(orig_func) patched_func.__module__ = __name__ self._leaf_functions_register[name] = (module, orig_func, patched_func) - def _patch_leaf_functions_for_root(self, root, restore=False): + def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore=False): for name in self._leaf_functions_register: module, orig_func, patched_func = self._leaf_functions_register[name] if restore: @@ -360,24 +277,71 @@ def _patch_leaf_functions_for_root(self, root, restore=False): leaf_getter.__module__ = __name__ setattr(module, name, leaf_getter) - def proxy(self, node: Node): - p = HFProxy(node, self) - if self.recorded_methods: - for method_name, cache_name in self.recorded_methods.items(): - _create_recorded_proxy_method(p, method_name, cache_name) - return p + def _method_is_called_in_leaf_module(self, module_ids): + currentframe = inspect.currentframe() + while currentframe: + if currentframe is None: + return False + module = currentframe.f_locals.get("self", None) + if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"): + return True + currentframe = currentframe.f_back + return False + + def _wrap_method_for_model_recording(self, model, method_name, cache_name, module_ids): + """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" + method = getattr(torch.Tensor, method_name) + + @functools.wraps(method) + def wrapped(*args, **kwargs): + if self._method_is_called_in_leaf_module(module_ids): + return method(*args, **kwargs) + if not hasattr(model, cache_name): + setattr(model, cache_name, []) + cache = getattr(model, cache_name) + res = method(*args, **kwargs) + cache.append(res) + return res + + return wrapped + + def _monkey_patch_tensor_methods_for_model_recording(self, model, method_names): + """ + Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference + before symbolic tracing. + """ + cache_names = {} + original_methods = {} + module_ids = set(id(mod) for mod in model.modules()) + for method_name in method_names: + cache_name = f"cache_{method_name}" + cache_names[method_name] = cache_name + if not hasattr(torch.Tensor, method_name): + logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") + continue + original_methods[method_name] = getattr(torch.Tensor, method_name) + setattr(torch.Tensor, method_name, self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids)) + + if method_name == "size": + original_methods["shape"] = torch.Tensor.shape + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + return cache_names, original_methods def _generate_dummy_input(self, model, input_name): """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device - inputs_dict = dict() + inputs_dict = {} if input_name in ["labels", "start_positions", "end_positions"]: - batch_size = self.encoder_shape[0] + batch_size = self.shape[0] if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + elif model_class in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), + XLNetForQuestionAnswering, + ]: inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ @@ -387,24 +351,21 @@ def _generate_dummy_input(self, model, input_name): ]: inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ + *get_values(MODEL_FOR_PRETRAINING_MAPPING), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), *get_values(MODEL_FOR_MASKED_LM_MAPPING), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), GPT2DoubleHeadsModel, ]: - inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): - inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device) + inputs_dict["labels"] = torch.zeros(self.shape, dtype=torch.long, device=device) else: raise NotImplementedError(f"{model_class} not supported yet.") elif "mask" in input_name or "ids" in input_name: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) + inputs_dict[input_name] = torch.zeros(self.shape, dtype=torch.long, device=device) else: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - shape += [model.config.hidden_size] + shape = self.shape + [model.config.hidden_size] inputs_dict[input_name] = torch.zeros(shape, dtype=torch.float, device=device) return inputs_dict @@ -417,30 +378,31 @@ def record(self, model, input_names, method_names=None): if method_names is None: method_names = self._DEFAULT_METHODS_TO_RECORD + num_choices = _generate_random_int(low=2, high=5) + if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + self.shape.insert(1, num_choices) + inputs = {} for input_name in input_names: inputs.update(self._generate_dummy_input(model, input_name)) - clone = copy.deepcopy(model) - cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names) + cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names) self.original_methods = original_methods - clone(**inputs) + model(**inputs) # Useful because sometime the config is changed at inference time, for instance for # classification tasks where config.problem_type can be set. - model.config = clone.config + # model.config = clone.config _reset_tensor_methods(original_methods) self.recorded_methods = { - method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name) + method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name) } - for cache_name in self.recorded_methods.values(): - setattr(model, cache_name, getattr(clone, cache_name)) - - # _restore_tensor_creators(self.original_tensor_creators) + # for cache_name in self.recorded_methods.values(): + # setattr(model, cache_name, getattr(clone, cache_name)) def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if isinstance(attr_val, torch.nn.Parameter): @@ -458,6 +420,14 @@ def _module_getattr(self, attr, attr_val, parameter_proxy_cache): return parameter_proxy_cache[n] return attr_val + def proxy(self, node: Node): + p = HFProxy(node, self) + if self.recorded_methods: + for method_name, cache_name in self.recorded_methods.items(): + return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name] + _create_recorded_proxy_method(p, method_name, cache_name, return_proxy) + return p + def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: if concrete_args is None: concrete_args = {} @@ -604,8 +574,6 @@ def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Option def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, - batch_size: int = 1, - sequence_length: Union[int, List[int], Tuple[int]] = (128, 128), num_choices: int = -1, ) -> GraphModule: @@ -647,27 +615,10 @@ def symbolic_trace( sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - # Preparing HFTracer batch_size and sequence_lenght values for potential dynamic axes. - use_dynamic_batch_size = batch_size <= 0 - if isinstance(sequence_length, (list, tuple)): - use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0 - else: - use_dynamic_sequence_length = sequence_length <= 0 - - if use_dynamic_batch_size or use_dynamic_sequence_length: - forbidden_values = [ - model.config.num_attention_heads, - model.config.hidden_size, - model.config.hidden_size // model.config.num_attention_heads, - ] - if use_dynamic_batch_size: - batch_size = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(batch_size) - if use_dynamic_sequence_length: - encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(encoder_sequence_length) - decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - sequence_length = [encoder_sequence_length, decoder_sequence_length] + forbidden_values = [] + batch_size = _generate_random_int(forbidden_values=forbidden_values) + forbidden_values.append(batch_size) + sequence_length = _generate_random_int(forbidden_values=forbidden_values) # if not isinstance(model, _SUPPORTED_MODELS): # supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) @@ -684,22 +635,7 @@ def symbolic_trace( # Tracing. tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) - traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) - traced.config = copy.deepcopy(model.config) - traced.num_choices = num_choices - traced.dummy_inputs = {} - - for name in input_names: - traced.dummy_inputs.update(tracer._generate_dummy_input(model, name)) - - traced.use_dynamic_batch_size = use_dynamic_batch_size - traced.use_dynamic_sequence_length = use_dynamic_sequence_length - traced.static_batch_size = batch_size - traced.static_sequence_length = sequence_length - - # transform_to_dynamic_input_(traced) - return traced From aef4d00a0f60ca854429a5b2534efc0e85d5ea88 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 11:10:58 +0100 Subject: [PATCH 04/18] Comments and making tracing work for gpt-j and xlnet --- src/transformers/models/gptj/modeling_gptj.py | 6 +- .../models/xlnet/modeling_xlnet.py | 6 +- src/transformers/utils/fx.py | 175 +++++++++--------- 3 files changed, 93 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 869014bee626..93329e2a5410 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -107,7 +107,7 @@ def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): Splits hidden dim into attn_head_size and num_attention_heads """ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) if rotary: return tensor if len(tensor.shape) == 5: @@ -665,7 +665,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,7 +945,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 71c4d788d34c..1f4245e74984 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1564,7 +1564,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + # **kwargs, ) output = transformer_outputs[0] @@ -1781,7 +1781,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + # **kwargs, ) output = transformer_outputs[0] @@ -2018,7 +2018,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + # **kwargs, ) hidden_states = transformer_outputs[0] start_logits = self.start_logits(hidden_states, p_mask=p_mask) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 546def8b9145..266326fdc255 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,8 +1,8 @@ -import copy import functools import inspect import random -from typing import Any, Callable, Dict, ModuleType, List, Optional, Tuple, Type, Union +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union import torch from packaging import version @@ -24,20 +24,13 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_MAPPING, GPT2DoubleHeadsModel, - XLNetForQuestionAnswering, PretrainedConfig, PreTrainedModel, + XLNetForQuestionAnswering, logging, ) from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available from ..models.auto import get_values -from .fx_transformations import ( - _cache_attributes, - _patch_arguments_, - _restore_attributes_, - transform_to_dynamic_input_, - transformation, -) logger = logging.get_logger(__name__) @@ -176,6 +169,7 @@ def __torch_function__(self, orig_method, types, args=None, kwargs=None): def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]: """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer.""" + @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -183,7 +177,7 @@ def wrapper(*args, **kwargs): return wrapper -def _function_leaf_getter(func_name, mapping): +def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]: @functools.wraps(mapping[func_name]) def wrapper(*args, **kwargs): return mapping[func_name](*args, **kwargs) @@ -191,7 +185,7 @@ def wrapper(*args, **kwargs): return wrapper -def _create_recorded_proxy_method(proxy, method_name, cache_name, return_proxy): +def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool): """ Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values during symbolic tracing. @@ -219,7 +213,7 @@ def method(*args, **kwargs): setattr(proxy, method_name, bound_method) -def _reset_tensor_methods(original_methods): +def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]): """Helper function that resets the monkey patched torch.Tensor methods to their original values.""" for name, method in original_methods.items(): setattr(torch.Tensor, name, method) @@ -239,7 +233,7 @@ class HFTracer(Tracer): modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, } - def __init__(self, batch_size=1, sequence_length=128, num_choices=-1): + def __init__(self, batch_size: int = 1, sequence_length: int = 128, num_choices: int = -1): super().__init__() self._leaf_functions_register = {} @@ -265,7 +259,8 @@ def _register_leaf_function(self, module: ModuleType, name: str): patched_func.__module__ = __name__ self._leaf_functions_register[name] = (module, orig_func, patched_func) - def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore=False): + def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False): + """Patches leaf functions specifically for root.""" for name in self._leaf_functions_register: module, orig_func, patched_func = self._leaf_functions_register[name] if restore: @@ -277,7 +272,12 @@ def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore=False): leaf_getter.__module__ = __name__ setattr(module, name, leaf_getter) - def _method_is_called_in_leaf_module(self, module_ids): + def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool: + """ + Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record + outputs that will not be encountered by the tracer. + """ + currentframe = inspect.currentframe() while currentframe: if currentframe is None: @@ -288,7 +288,9 @@ def _method_is_called_in_leaf_module(self, module_ids): currentframe = currentframe.f_back return False - def _wrap_method_for_model_recording(self, model, method_name, cache_name, module_ids): + def _wrap_method_for_model_recording( + self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int] + ): """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" method = getattr(torch.Tensor, method_name) @@ -305,7 +307,7 @@ def wrapped(*args, **kwargs): return wrapped - def _monkey_patch_tensor_methods_for_model_recording(self, model, method_names): + def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]): """ Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference before symbolic tracing. @@ -320,7 +322,11 @@ def _monkey_patch_tensor_methods_for_model_recording(self, model, method_names): logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") continue original_methods[method_name] = getattr(torch.Tensor, method_name) - setattr(torch.Tensor, method_name, self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids)) + setattr( + torch.Tensor, + method_name, + self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids), + ) if method_name == "size": original_methods["shape"] = torch.Tensor.shape @@ -328,7 +334,7 @@ def _monkey_patch_tensor_methods_for_model_recording(self, model, method_names): return cache_names, original_methods - def _generate_dummy_input(self, model, input_name): + def _generate_dummy_input(self, model: PreTrainedModel, input_name: str) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device @@ -370,9 +376,9 @@ def _generate_dummy_input(self, model, input_name): return inputs_dict - def record(self, model, input_names, method_names=None): + def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None): """ - Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic + Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing. """ if method_names is None: @@ -391,19 +397,12 @@ def record(self, model, input_names, method_names=None): model(**inputs) - # Useful because sometime the config is changed at inference time, for instance for - # classification tasks where config.problem_type can be set. - # model.config = clone.config - _reset_tensor_methods(original_methods) self.recorded_methods = { method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name) } - # for cache_name in self.recorded_methods.values(): - # setattr(model, cache_name, getattr(clone, cache_name)) - def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if isinstance(attr_val, torch.nn.Parameter): for n, p in self.root.named_parameters(): @@ -428,7 +427,12 @@ def proxy(self, node: Node): _create_recorded_proxy_method(p, method_name, cache_name, return_proxy) return p - def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: + def trace( + self, + root: PreTrainedModel, + concrete_args: Optional[Dict[str, Any]] = None, + method_names: Optional[Iterable[str]] = None, + ) -> Graph: if concrete_args is None: concrete_args = {} @@ -461,7 +465,7 @@ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = return graph - def _insert_module_as_submodule(self, mod): + def _insert_module_as_submodule(self, mod: nn.Module) -> str: """ Helper method which tries to insert a module that was not declared as submodule. """ @@ -507,7 +511,7 @@ def path_of_module(self, mod: nn.Module) -> str: self.prev_module = path return path - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: is_loss_module = m.__module__.startswith("torch.nn.modules.loss") return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name) @@ -517,52 +521,52 @@ def create_arg(self, a: Any) -> Argument: return super().create_arg(a) -@transformation -def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: - """ - Prepares a GraphModule produced by symbolic_trace for retracing by: - - - Caching all the attributes specific to the way the model was initially traced - - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes - For instance, the need to retrace a GraphModule can happen when applying quantization. - """ - attributes = _cache_attributes(gm) - _patch_arguments_(gm, gm.dynamic2static) - - return gm, attributes - - -def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): - """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" - _restore_attributes_(gm, attributes) - # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired - # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. - transform_to_dynamic_input_(gm, is_retracing=True) - _patch_arguments_(gm, gm.static2dynamic) - return gm - - -def retrace_graph_with( - gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None -) -> GraphModule: - """ - Retraces a GraphModule by either using a tracer or a function using a tracer (for instance - torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and - restoring anything necessary after the retrace. - """ - if tracer is None and func is None: - raise ValueError("Either a tracer or a function using a tracer must be provided.") - elif tracer is not None and func is not None: - raise ValueError("Either provide a tracer or a function using a tracer, but not both.") - else: - gm, attributes = prepare_for_retracing(gm) - tracing_func = tracer.trace if tracer else func - traced = tracing_func(gm) - restore_after_retracing_(traced, attributes) - return traced - - -def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): +# @transformation +# def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: +# """ +# Prepares a GraphModule produced by symbolic_trace for retracing by: +# +# - Caching all the attributes specific to the way the model was initially traced +# - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes +# For instance, the need to retrace a GraphModule can happen when applying quantization. +# """ +# attributes = _cache_attributes(gm) +# _patch_arguments_(gm, gm.dynamic2static) +# +# return gm, attributes + + +# def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): +# """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" +# _restore_attributes_(gm, attributes) +# # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired +# # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. +# transform_to_dynamic_input_(gm, is_retracing=True) +# _patch_arguments_(gm, gm.static2dynamic) +# return gm + + +# def retrace_graph_with( +# gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None +# ) -> GraphModule: +# """ +# Retraces a GraphModule by either using a tracer or a function using a tracer (for instance +# torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and +# restoring anything necessary after the retrace. +# """ +# if tracer is None and func is None: +# raise ValueError("Either a tracer or a function using a tracer must be provided.") +# elif tracer is not None and func is not None: +# raise ValueError("Either provide a tracer or a function using a tracer, but not both.") +# else: +# gm, attributes = prepare_for_retracing(gm) +# tracing_func = tracer.trace if tracer else func +# traced = tracing_func(gm) +# restore_after_retracing_(traced, attributes) +# return traced + + +def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None) -> int: if forbidden_values is None: forbidden_values = [] value = random.randint(low, high) @@ -599,16 +603,11 @@ def symbolic_trace( Example: - ```python - from transformers.utils.fx import symbolic_trace - - traced_model = symbolic_trace( - model, - input_names=["input_ids", "attention_mask", "token_type_ids"], - batch_size=1, - sequence_length=128, - ) - ```""" + ```python + from transformers.utils.fx import symbolic_trace + traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"]) + ``` + """ if input_names is None: input_names = model.dummy_inputs.keys() From f7a69eb68f0a55af90818baa377f173beb2b680c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 11:42:23 +0100 Subject: [PATCH 05/18] Refactored things related to num_choices (and batch_size, sequence_length) --- src/transformers/utils/fx.py | 84 +++++++++++------------------------ tests/test_modeling_common.py | 11 +---- 2 files changed, 28 insertions(+), 67 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 266326fdc255..825ca22151c4 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -40,6 +40,7 @@ def _generate_supported_model_classes( model_name: Type[PretrainedConfig], supported_tasks: Optional[Union[str, List[str]]] = None, ) -> List[Type[PreTrainedModel]]: + model_config_class = CONFIG_MAPPING[model_name] task_mapping = { "default": MODEL_MAPPING, @@ -80,15 +81,10 @@ def _generate_supported_model_classes( "gptj", "gpt_neo", "t5", -] + "roberta", + "layoutlm", + "xlnet", -_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [ - "albert", - "bert", - "distilbert", - "mobilebert", - "electra", - "megatron-bert", ] _REGULAR_SUPPORTED_MODELS = [] @@ -100,21 +96,10 @@ def _generate_supported_model_classes( _SPECIAL_SUPPORTED_MODELS = [ GPT2DoubleHeadsModel, + XLNetForQuestionAnswering, ] _SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS) -_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES: - if isinstance(item, dict): - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item)) - else: - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item)) - -_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple( - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES -) - class HFProxy(Proxy): """ @@ -233,7 +218,7 @@ class HFTracer(Tracer): modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, } - def __init__(self, batch_size: int = 1, sequence_length: int = 128, num_choices: int = -1): + def __init__(self): super().__init__() self._leaf_functions_register = {} @@ -248,7 +233,6 @@ def __init__(self, batch_size: int = 1, sequence_length: int = 128, num_choices: f"{TORCH_FX_REQUIRED_VERSION} is supported." ) - self.shape = [batch_size, sequence_length] self.prev_module = None self.recorded_methods = None @@ -334,14 +318,14 @@ def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedMode return cache_names, original_methods - def _generate_dummy_input(self, model: PreTrainedModel, input_name: str) -> Dict[str, torch.Tensor]: + def _generate_dummy_input(self, model: PreTrainedModel, input_name: str, shape: List[int]) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device inputs_dict = {} if input_name in ["labels", "start_positions", "end_positions"]: - batch_size = self.shape[0] + batch_size = shape[0] if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ @@ -364,15 +348,15 @@ def _generate_dummy_input(self, model: PreTrainedModel, input_name: str) -> Dict *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), GPT2DoubleHeadsModel, ]: - inputs_dict["labels"] = torch.zeros(self.shape, dtype=torch.long, device=device) + inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: raise NotImplementedError(f"{model_class} not supported yet.") elif "mask" in input_name or "ids" in input_name: - inputs_dict[input_name] = torch.zeros(self.shape, dtype=torch.long, device=device) + inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: - shape = self.shape + [model.config.hidden_size] - inputs_dict[input_name] = torch.zeros(shape, dtype=torch.float, device=device) + shape_with_hidden_size = shape + [model.config.hidden_size] + inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device) return inputs_dict @@ -384,13 +368,18 @@ def record(self, model: PreTrainedModel, input_names: List[str], method_names: O if method_names is None: method_names = self._DEFAULT_METHODS_TO_RECORD - num_choices = _generate_random_int(low=2, high=5) + # Creating a random input shape to generate dummy inputs. + batch_size = _generate_random_int() + sequence_length = _generate_random_int() + shape = [batch_size, sequence_length] + if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - self.shape.insert(1, num_choices) + num_choices = _generate_random_int(low=2, high=5) + shape.insert(1, num_choices) inputs = {} for input_name in input_names: - inputs.update(self._generate_dummy_input(model, input_name)) + inputs.update(self._generate_dummy_input(model, input_name, shape)) cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names) self.original_methods = original_methods @@ -578,7 +567,6 @@ def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Option def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, - num_choices: int = -1, ) -> GraphModule: """ @@ -589,14 +577,6 @@ def symbolic_trace( The model to trace. input_names (`List[str]`, *optional*): The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead. - batch_size (`int`, *optional*, defaults to 1): - The batch size of the traced model inputs. - sequence_length (`int` or `List[int]]`): - The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence - lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length, - decoder_sequence_length]`. - num_choices (`int`, *optional*, defaults to -1): - The number of possible choices for a multiple choice task. Returns: `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. @@ -614,26 +594,14 @@ def symbolic_trace( sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - forbidden_values = [] - batch_size = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(batch_size) - sequence_length = _generate_random_int(forbidden_values=forbidden_values) - - # if not isinstance(model, _SUPPORTED_MODELS): - # supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) - # raise NotImplementedError( - # f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" - # ) - # if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( - # model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES - # ): - # supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) - # raise NotImplementedError( - # f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" - # ) + if not isinstance(model, _SUPPORTED_MODELS): + supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) + raise NotImplementedError( + f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" + ) # Tracing. - tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) + tracer = HFTracer() traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 64c62ebdf30b..606ae1ee9916 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -704,20 +704,13 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model_output = model(**filtered_inputs) rank = len(input_ids.shape) - if rank == 2: - batch_size, sequence_length = input_ids.shape - num_choices = -1 - elif rank == 3: - batch_size, num_choices, sequence_length = input_ids.shape - else: + if rank not in [2, 3]: raise NotImplementedError( f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}." ) - # import pytest; pytest.set_trace() - traced_model = symbolic_trace(model, input_names, num_choices=num_choices) + traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) - # import pytest; pytest.set_trace() except RuntimeError: self.fail("Couldn't trace module.") From b0e3d968e7ac2c30e150e2ff0942af784a1c5a62 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 11:49:26 +0100 Subject: [PATCH 06/18] style fix --- src/transformers/utils/fx.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 825ca22151c4..a864bc925c36 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -84,7 +84,6 @@ def _generate_supported_model_classes( "roberta", "layoutlm", "xlnet", - ] _REGULAR_SUPPORTED_MODELS = [] @@ -293,8 +292,8 @@ def wrapped(*args, **kwargs): def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]): """ - Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference - before symbolic tracing. + Helper function that patches torch.Tensor methods (specified by the method_names list) to record model + inference before symbolic tracing. """ cache_names = {} original_methods = {} @@ -318,7 +317,9 @@ def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedMode return cache_names, original_methods - def _generate_dummy_input(self, model: PreTrainedModel, input_name: str, shape: List[int]) -> Dict[str, torch.Tensor]: + def _generate_dummy_input( + self, model: PreTrainedModel, input_name: str, shape: List[int] + ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device @@ -362,8 +363,7 @@ def _generate_dummy_input(self, model: PreTrainedModel, input_name: str, shape: def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None): """ - Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic - tracing. + Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing. """ if method_names is None: method_names = self._DEFAULT_METHODS_TO_RECORD @@ -513,12 +513,10 @@ def create_arg(self, a: Any) -> Argument: # @transformation # def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: # """ -# Prepares a GraphModule produced by symbolic_trace for retracing by: -# -# - Caching all the attributes specific to the way the model was initially traced -# - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes -# For instance, the need to retrace a GraphModule can happen when applying quantization. -# """ +# Prepares a GraphModule produced by symbolic_trace for retracing by: # # - Caching all the attributes specific to the +way the model was initially traced # - Patching back the model to a "static input shapes" version if it was traced to +accept dynamic input shapes # For instance, the need to retrace a GraphModule can happen when applying quantization. # +""" # attributes = _cache_attributes(gm) # _patch_arguments_(gm, gm.dynamic2static) # @@ -539,10 +537,10 @@ def create_arg(self, a: Any) -> Argument: # gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None # ) -> GraphModule: # """ -# Retraces a GraphModule by either using a tracer or a function using a tracer (for instance -# torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and -# restoring anything necessary after the retrace. -# """ +# Retraces a GraphModule by either using a tracer or a function using a tracer (for instance # +torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and # +restoring anything necessary after the retrace. # +""" # if tracer is None and func is None: # raise ValueError("Either a tracer or a function using a tracer must be provided.") # elif tracer is not None and func is not None: From 636099b6ce610b7b5744a10e6d2e1d69746658a2 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 15:43:10 +0100 Subject: [PATCH 07/18] Updated fx to work on PyTorch 1.10 --- src/transformers/file_utils.py | 2 +- src/transformers/utils/fx.py | 76 +++++++++------------------------- 2 files changed, 21 insertions(+), 57 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 8a809a92e1a6..c2d94a518e8e 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -322,7 +322,7 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. -TORCH_FX_REQUIRED_VERSION = version.parse("1.9") +TORCH_FX_REQUIRED_VERSION = version.parse("1.10") TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index a864bc925c36..f4dce3826cb2 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,5 +1,6 @@ import functools import inspect +import math import random from types import ModuleType from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union @@ -203,6 +204,15 @@ def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]): setattr(torch.Tensor, name, method) +def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): + if forbidden_values is None: + forbidden_values = [] + value = random.randint(low, high) + while value in forbidden_values: + value = random.randint(low, high) + return value + + class HFTracer(Tracer): """ Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the @@ -217,14 +227,22 @@ class HFTracer(Tracer): modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, } - def __init__(self): - super().__init__() + def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False): + # Loading the leaf functions register self._leaf_functions_register = {} for module, names in self._FUNCTIONS_TO_AUTOWRAP.items(): for name in names: self._register_leaf_function(module, name) + autowrap_functions = autowrap_functions + tuple( + patched for (_, _, patched) in self._leaf_functions_register.values() + ) + + super().__init__( + autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching + ) + if not is_torch_fx_available(): torch_version = version.parse(importlib_metadata.version("torch")) raise ImportError( @@ -430,8 +448,6 @@ def trace( self.record(root, input_names, method_names=method_names) - autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()] - self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) self._patch_leaf_functions_for_root(root) graph = super().trace(root, concrete_args=concrete_args) @@ -510,58 +526,6 @@ def create_arg(self, a: Any) -> Argument: return super().create_arg(a) -# @transformation -# def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: -# """ -# Prepares a GraphModule produced by symbolic_trace for retracing by: # # - Caching all the attributes specific to the -way the model was initially traced # - Patching back the model to a "static input shapes" version if it was traced to -accept dynamic input shapes # For instance, the need to retrace a GraphModule can happen when applying quantization. # -""" -# attributes = _cache_attributes(gm) -# _patch_arguments_(gm, gm.dynamic2static) -# -# return gm, attributes - - -# def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): -# """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" -# _restore_attributes_(gm, attributes) -# # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired -# # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. -# transform_to_dynamic_input_(gm, is_retracing=True) -# _patch_arguments_(gm, gm.static2dynamic) -# return gm - - -# def retrace_graph_with( -# gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None -# ) -> GraphModule: -# """ -# Retraces a GraphModule by either using a tracer or a function using a tracer (for instance # -torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and # -restoring anything necessary after the retrace. # -""" -# if tracer is None and func is None: -# raise ValueError("Either a tracer or a function using a tracer must be provided.") -# elif tracer is not None and func is not None: -# raise ValueError("Either provide a tracer or a function using a tracer, but not both.") -# else: -# gm, attributes = prepare_for_retracing(gm) -# tracing_func = tracer.trace if tracer else func -# traced = tracing_func(gm) -# restore_after_retracing_(traced, attributes) -# return traced - - -def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None) -> int: - if forbidden_values is None: - forbidden_values = [] - value = random.randint(low, high) - while value in forbidden_values: - value = random.randint(low, high) - return value - - def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, From 505333dedff0dd01763ea11b907b588e08632788 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 15:55:23 +0100 Subject: [PATCH 08/18] Postponed autowrap_function feature usage for later --- src/transformers/utils/fx.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index f4dce3826cb2..4aa82f7620e7 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -163,6 +163,7 @@ def wrapper(*args, **kwargs): def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]: + @functools.wraps(mapping[func_name]) def wrapper(*args, **kwargs): return mapping[func_name](*args, **kwargs) @@ -235,9 +236,10 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatc for name in names: self._register_leaf_function(module, name) - autowrap_functions = autowrap_functions + tuple( - patched for (_, _, patched) in self._leaf_functions_register.values() - ) + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + # autowrap_functions = autowrap_functions + tuple( + # patched for (_, _, patched) in self._leaf_functions_register.values() + # ) super().__init__( autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching @@ -453,6 +455,7 @@ def trace( graph = super().trace(root, concrete_args=concrete_args) self._patch_leaf_functions_for_root(root, restore=True) + _reset_tensor_methods(self.original_methods) # TODO: keep this until necessary. From 74f74d7892daf9934ab5bc2980faf76e616f73df Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 16:19:59 +0100 Subject: [PATCH 09/18] style fix --- src/transformers/utils/fx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 4aa82f7620e7..e41c38a3e684 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -163,7 +163,6 @@ def wrapper(*args, **kwargs): def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]: - @functools.wraps(mapping[func_name]) def wrapper(*args, **kwargs): return mapping[func_name](*args, **kwargs) From e171ed294755666fc08f6d537b6b422a49154194 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Nov 2021 17:29:03 +0100 Subject: [PATCH 10/18] fixed issue --- src/transformers/utils/fx.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index e41c38a3e684..37f5341e41e0 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -449,6 +449,10 @@ def trace( self.record(root, input_names, method_names=method_names) + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()] + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + self._patch_leaf_functions_for_root(root) graph = super().trace(root, concrete_args=concrete_args) From 6faf2638da1e471859f24d47dfee5ae4b650007d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 10 Nov 2021 10:26:32 +0100 Subject: [PATCH 11/18] implemented suggestions --- src/transformers/modeling_utils.py | 20 ------------------- .../models/xlnet/modeling_xlnet.py | 10 +++++----- src/transformers/utils/fx.py | 5 +++-- tests/test_modeling_albert.py | 2 +- tests/test_modeling_bart.py | 1 - tests/test_modeling_bert.py | 2 +- tests/test_modeling_common.py | 7 +++---- tests/test_modeling_distilbert.py | 2 +- tests/test_modeling_electra.py | 4 +--- tests/test_modeling_gpt2.py | 2 +- tests/test_modeling_gpt_neo.py | 2 +- tests/test_modeling_gptj.py | 2 +- tests/test_modeling_layoutlm.py | 1 - tests/test_modeling_megatron_bert.py | 3 +-- tests/test_modeling_mobilebert.py | 2 +- tests/test_modeling_roberta.py | 2 +- tests/test_modeling_t5.py | 2 +- tests/test_modeling_xlnet.py | 1 - 18 files changed, 22 insertions(+), 48 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 837a014f2096..997d8818cde2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -285,26 +285,6 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder: - # batch_size, seq_length = input_shape - # seq_ids = torch.arange(seq_length, device=device) - # causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - # # in case past_key_values are used we need to add a prefix ones mask to the causal mask - # # causal and attention masks must have same type with pytorch version < 1.3 - # causal_mask = causal_mask.to(attention_mask.dtype) - - # if causal_mask.shape[1] < attention_mask.shape[1]: - # prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - # causal_mask = torch.cat( - # [ - # torch.ones( - # (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype - # ), - # causal_mask, - # ], - # axis=-1, - # ) - - # extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] extended_attention_mask = self.create_extended_attention_mask_for_decoder( input_shape, attention_mask, device ) diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 1f4245e74984..278320a6b41f 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1464,7 +1464,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + **kwargs, ) logits = self.lm_loss(transformer_outputs[0]) @@ -1564,7 +1564,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + **kwargs, ) output = transformer_outputs[0] @@ -1781,7 +1781,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + **kwargs, ) output = transformer_outputs[0] @@ -1878,7 +1878,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + **kwargs, ) sequence_output = outputs[0] @@ -2018,7 +2018,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # **kwargs, + **kwargs, ) hidden_states = transformer_outputs[0] start_logits = self.start_logits(hidden_states, p_mask=p_mask) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 37f5341e41e0..c851f93bc8ea 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -83,8 +83,9 @@ def _generate_supported_model_classes( "gpt_neo", "t5", "roberta", - "layoutlm", - "xlnet", + # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # "layoutlm", + # "xlnet", ] _REGULAR_SUPPORTED_MODELS = [] diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index b01b459d86a5..b3bac4e955a0 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -231,7 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes + fx_ready = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 319b4c38de41..957350b824ce 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -411,7 +411,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): if is_torch_available() else () ) - # fx_ready_model_classes = all_model_classes all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index e010cbaf180e..c12f32e605bc 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -444,7 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 606ae1ee9916..6118727bea37 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -108,7 +108,7 @@ class ModelTesterMixin: model_tester = None all_model_classes = () all_generative_model_classes = () - fx_ready_model_classes = () + fx_ready = False test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -658,14 +658,13 @@ def test_torch_fx_output_loss(self): self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available(): + if not is_torch_fx_available() or not self.fx_ready: return configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.return_dict = False - model_classes = self.fx_ready_model_classes - for model_class in model_classes: + for model_class in self.all_model_classes: model = model_class(config=configs_no_init) model.to(torch_device) model.eval() diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 4a86d9a74a97..2b5c803302ad 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -209,7 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) - fx_ready_model_classes = all_model_classes + fx_ready = True test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 3dd94223c726..eaa70162c771 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -369,9 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else () - - fx_ready_model_classes = all_model_classes + fx_ready = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index ef51c815e455..46bc2683a73b 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True test_missing_keys = False test_model_parallel = True diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index a8e5b4babc57..d50a82950646 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else () ) all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True test_missing_keys = False test_pruning = False test_model_parallel = False diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index dd743b80d76a..38fdca3aa7b5 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True test_pruning = False test_missing_keys = False test_model_parallel = False diff --git a/tests/test_modeling_layoutlm.py b/tests/test_modeling_layoutlm.py index acbcfa0df9d5..67423fe21fd1 100644 --- a/tests/test_modeling_layoutlm.py +++ b/tests/test_modeling_layoutlm.py @@ -215,7 +215,6 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) - # fx_ready_model_classes = all_model_classes def setUp(self): self.model_tester = LayoutLMModelTester(self) diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index 5a06d57a9ec3..71296f9f9e2d 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -283,8 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes - + fx_ready = True # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index abf64e416f0a..57066ebd4789 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -269,7 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes + fx_ready = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 11d826f00148..8f311faabd7f 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -356,7 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 575850aa9014..d3b9268372bb 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_ready = True all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 46fed1019717..f4e90fbe7749 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -527,7 +527,6 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable - fx_ready_model_classes = all_model_classes test_pruning = False # XLNet has 2 QA models -> need to manually set the correct labels for one of them here From 83aedfcaf8822f99a24ebe44840c000acdc826ed Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 6 Jan 2022 11:09:12 +0100 Subject: [PATCH 12/18] Add copyrights --- src/transformers/utils/fx.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index c851f93bc8ea..d91460ffb994 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,3 +1,18 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import functools import inspect import math @@ -97,7 +112,8 @@ def _generate_supported_model_classes( _SPECIAL_SUPPORTED_MODELS = [ GPT2DoubleHeadsModel, - XLNetForQuestionAnswering, + # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # XLNetForQuestionAnswering, ] _SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS) From ee75d025af6344e8767a79e3e9913fef8314d2ee Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 10 Jan 2022 16:13:14 +0100 Subject: [PATCH 13/18] Remove unnecessary file --- src/transformers/utils/fx.py | 1 + src/transformers/utils/fx_transformations.py | 321 ------------------- 2 files changed, 1 insertion(+), 321 deletions(-) delete mode 100644 src/transformers/utils/fx_transformations.py diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index d91460ffb994..f9cdc407aeb6 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -570,6 +570,7 @@ def symbolic_trace( ```python from transformers.utils.fx import symbolic_trace + traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"]) ``` """ diff --git a/src/transformers/utils/fx_transformations.py b/src/transformers/utils/fx_transformations.py deleted file mode 100644 index 3e181617af10..000000000000 --- a/src/transformers/utils/fx_transformations.py +++ /dev/null @@ -1,321 +0,0 @@ -import copy -import functools -import operator -from inspect import signature -from typing import Any, Callable, Dict, Optional, Union - -import torch -from torch.fx import Graph, GraphModule, Node - - -# Torch FX transformation convention: -# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation -# - transformations that are inplace have a name ending with "_" - - -def _cache_attributes(gm: GraphModule) -> Dict[str, Any]: - attributes_to_keep = [ - "config", - "num_choices", - "dummy_inputs", - "use_dynamic_batch_size", - "use_dynamic_sequence_length", - "static_batch_size", - "static_sequence_length", - "static2dynamic", - "dynamic2static", - ] - attributes = {k: getattr(gm, k, None) for k in attributes_to_keep} - return attributes - - -def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]): - for name, attr in attributes.items(): - setattr(gm, name, attr) - - -def deepcopy_graph(gm: GraphModule) -> GraphModule: - """ - Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was - traced with dynamic axes, and what were the values if that is the case. - """ - - # First, create a copy of the module without the graph. - graph = gm.__dict__.pop("_graph") - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(gm.__dict__) - gm.__dict__["_graph"] = graph - - # Then, copy the graph. - val_map = {} - graph_clone = Graph() - output_val = graph_clone.graph_copy(graph, val_map=val_map) - graph_clone.output(output_val) - - # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies. - # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule. - clone = gm.__class__(fake_mod, graph_clone) - - # Restore the dynamic axes related attributes to the clone. - attributes = _cache_attributes(gm) - attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()} - attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()} - _restore_attributes_(clone, attributes) - - return clone - - -def transformation(func): - """ - Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the - original. - """ - - def map_fn(arg): - if isinstance(arg, GraphModule): - return deepcopy_graph(arg) - return arg - - @functools.wraps(func) - def wrapper(*args, **kwargs): - new_args = tuple(map_fn(arg) for arg in args) - new_kwargs = {k: map_fn(v) for k, v in kwargs.items()} - return func(*new_args, **new_kwargs) - - wrapper._is_transformation = True - - return wrapper - - -def compose_transformations( - *args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False -) -> GraphModule: - """ - Allows to compose transformations together and takes of: - - 1. Performing the transformations on a copy of the GraphModule if inplace is set to False, transformations that - are decorated with @transformation (which means that they are not modifying the original GraphModule) are - unwrapped to make them inplace. - 2. Linting and recompiling only at the end of the composition for performance purposes. - """ - args = list(args) - if not inplace: - args.insert(0, deepcopy_graph) - - for i, transformation in enumerate(args[:-1]): - sig = signature(transformation) - - # Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is - # already handled by this function. - if getattr(transformation, "_is_transformation", False): - transformation = transformation.__wrapped__ - - # Linting and recompiling only after the last transformation applied to make composition efficient. - if "lint_and_recompile" in sig.parameters: - args[i] = functools.partial(transformation, lint_and_recompile=False) - - def reduce_func(f, g): - def compose_f_and_g(gm): - output_g = g(gm) - if output_g is None: - output_g = gm - output_f = f(output_g) - if output_f is None: - output_f = gm - return output_f - - return compose_f_and_g - - return functools.reduce(reduce_func, reversed(args), lambda x: x) - - -def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True): - """Removes all the unused nodes in a GraphModule.""" - graph = gm.graph - for node in graph.nodes: - if not node.users and node.op not in ["placeholder", "output"]: - graph.erase_node(node) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the batch size dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - batch_size_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names: - with graph.inserting_after(node): - batch_size_node = graph.call_method("size", args=(node, 0)) - - if batch_size_node is None: - raise ValueError("Could not insert the node that computes the batch size") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[batch_size_node.name] = None - - return batch_size_node - - -def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - encoder_sequence_length_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name: - with graph.inserting_after(node): - # There are two cases to handle: - # 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the - # input shapes is [batch_size, sequence_length] => index 1 - # 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input - # shape is [batch_size, num_choices, sequence_length] => index 2 - encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2)) - - if encoder_sequence_length_node is None: - raise ValueError("Could not insert the node that computes the encoder sequence length") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[encoder_sequence_length_node.name] = None - - return encoder_sequence_length_node - - -def _change_view_methods_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the - batch_size / sequence_length nodes. - """ - graph = gm.graph - for node in graph.nodes: - if node.op == "call_method" and node.target == "view": - if isinstance(node.args[1], tuple): - node.args = (node.args[0], *node.args[1]) - node.args = tuple((mapping.get(arg, arg) for arg in node.args)) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_getitem_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """Patches getitem nodes by replacing current arguments to their corresponding values in mapping.""" - # TODO: combine this with the patch_argument function which seems to do almost the same thing. - graph = gm.graph - for node in graph.nodes: - if node.op == "call_function" and node.target == operator.getitem: - indices = node.args[1] - if isinstance(indices, tuple): - new_indices = [] - for idx in indices: - if isinstance(idx, slice): - new_indices.append( - slice( - mapping.get(idx.start, idx.start), - mapping.get(idx.stop, idx.stop), - mapping.get(idx.step, idx.step), - ) - ) - elif isinstance(idx, int): - new_indices.append(mapping.get(idx, idx)) - else: - new_indices.append(idx) - - node.args = (node.args[0], tuple(new_indices)) - else: - node.args = (node.args[0], mapping.get(node.args[1], node.args[1])) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_arguments_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Patches node by replacing their argument to their corresponding values in mapping (supports regular types, tuples - and slices). - """ - - def _patch_slice(s, mapping): - return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step)) - - graph = gm.graph - supported_types = (Node, str, int, float) - for node in graph.nodes: - new_args = [] - for arg in node.args: - if isinstance(arg, tuple): - new_arg = [] - for a in arg: - if isinstance(a, slice): - new_arg.append(_patch_slice(a, mapping)) - else: - new_arg.append(mapping.get(a, a)) - new_args.append(tuple(new_arg)) - elif isinstance(arg, slice): - new_args.append(_patch_slice(arg, mapping)) - elif isinstance(arg, supported_types): - new_args.append(mapping.get(arg, arg)) - else: - new_args.append(arg) - node.args = tuple(new_args) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False): - """Transformation that enables traced models to perform inference on dynamic input shapes.""" - graph = gm.graph - static2dynamic = {} - - # Inserting the nodes that will fetch the batch size and sequence lengths dynamically. - if gm.use_dynamic_batch_size: - batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_batch_size] = batch_size_node - if gm.num_choices > 0: - with graph.inserting_after(batch_size_node): - static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function( - operator.mul, args=(batch_size_node, gm.num_choices) - ) - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None - - if gm.use_dynamic_sequence_length: - encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node - - # TODO: do the same for the decoder. - pass - - _change_view_methods_(gm, static2dynamic, lint_and_recompile=False) - _patch_getitem_(gm, static2dynamic, lint_and_recompile=False) - - remove_unused_nodes_(gm, lint_and_recompile=False) - - graph.lint() - gm.recompile() - - gm.static2dynamic = static2dynamic - gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()} From 4f22de3caf423e5ec32d82a6c148bbc569b8ca44 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Feb 2022 10:53:46 +0100 Subject: [PATCH 14/18] Fix copies --- src/transformers/models/realm/modeling_realm.py | 4 ++-- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 165e62c0ef6c..118916413863 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -260,7 +260,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -349,7 +349,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index cdea06ac57b6..cfeb788ec62e 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -181,7 +181,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -270,7 +270,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) From e566f161dc69c882278ee718ba37cf65f2cfa21d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Feb 2022 12:26:09 +0100 Subject: [PATCH 15/18] Fix issue with add_new_model_like --- src/transformers/commands/add_new_model_like.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index e443a235c42e..6b9c39b29351 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1189,6 +1189,13 @@ def create_new_model_like( if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f) ] + def disable_fx_test(filename: Path): + with open(filename) as fp: + content = fp.read() + with open(filename, "w") as fp: + new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content) + fp.write(new_content) + for test_file in files_to_adapt: new_test_file_name = test_file.name.replace( old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased @@ -1201,6 +1208,11 @@ def create_new_model_like( dest_file=dest_file, add_copied_from=False, ) + disable_fx_test(dest_file) + print( + "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works " + "for your new model." + ) # 4. Add model to auto classes add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes) From a90d319e9ecb99de3a15233719fcef91a316053f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Feb 2022 13:06:39 +0100 Subject: [PATCH 16/18] Fix issue with add_new_model_like --- src/transformers/commands/add_new_model_like.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index 6b9c39b29351..38de77f809fe 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1189,12 +1189,15 @@ def create_new_model_like( if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f) ] - def disable_fx_test(filename: Path): + def disable_fx_test(filename: Path) -> bool: with open(filename) as fp: content = fp.read() with open(filename, "w") as fp: new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content) fp.write(new_content) + return content != new_content + + disabled_fx_test = False for test_file in files_to_adapt: new_test_file_name = test_file.name.replace( @@ -1208,7 +1211,9 @@ def disable_fx_test(filename: Path): dest_file=dest_file, add_copied_from=False, ) - disable_fx_test(dest_file) + disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file) + + if disabled_fx_test: print( "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works " "for your new model." From ae60baf13874196db205427bf75ff80359cbcc02 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Feb 2022 16:32:55 +0100 Subject: [PATCH 17/18] Apply suggestions --- src/transformers/commands/add_new_model_like.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index 38de77f809fe..eb5c5d295bed 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1192,8 +1192,8 @@ def create_new_model_like( def disable_fx_test(filename: Path) -> bool: with open(filename) as fp: content = fp.read() + new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content) with open(filename, "w") as fp: - new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content) fp.write(new_content) return content != new_content From 9ef5813d73b2af238e8820e117dd32415ce4c173 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 7 Feb 2022 21:39:08 +0100 Subject: [PATCH 18/18] Apply suggestions --- src/transformers/commands/add_new_model_like.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- tests/test_modeling_albert.py | 2 +- tests/test_modeling_bert.py | 2 +- tests/test_modeling_common.py | 4 ++-- tests/test_modeling_distilbert.py | 2 +- tests/test_modeling_electra.py | 2 +- tests/test_modeling_gpt2.py | 2 +- tests/test_modeling_gpt_neo.py | 2 +- tests/test_modeling_gptj.py | 2 +- tests/test_modeling_megatron_bert.py | 2 +- tests/test_modeling_mobilebert.py | 2 +- tests/test_modeling_roberta.py | 2 +- tests/test_modeling_t5.py | 2 +- 16 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index eb5c5d295bed..3ba5d71099de 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1192,7 +1192,7 @@ def create_new_model_like( def disable_fx_test(filename: Path) -> bool: with open(filename) as fp: content = fp.read() - new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content) + new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content) with open(filename, "w") as fp: fp.write(new_content) return content != new_content diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 978939f38996..59df99e8ab91 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1410,7 +1410,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index a2c9be7f9dfa..c516ca57a1e5 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -891,7 +891,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 93329e2a5410..66163ad49fd0 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -945,7 +945,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index b3bac4e955a0..ab5595f4b6f8 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -231,7 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready = True + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index c12f32e605bc..7b8738fd60f3 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -444,7 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () - fx_ready = True + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6118727bea37..2ca59c3f0c9e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -108,7 +108,7 @@ class ModelTesterMixin: model_tester = None all_model_classes = () all_generative_model_classes = () - fx_ready = False + fx_compatible = False test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -658,7 +658,7 @@ def test_torch_fx_output_loss(self): self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available() or not self.fx_ready: + if not is_torch_fx_available() or not self.fx_compatible: return configs_no_init = _config_zero_init(config) # To be sure we have no Nan diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 8d9b575cd354..b81e42bcf175 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -209,7 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) - fx_ready = True + fx_compatible = True test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index eaa70162c771..065d59682693 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -369,7 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready = True + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 46bc2683a73b..cd13be27bbc3 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - fx_ready = True + fx_compatible = True test_missing_keys = False test_model_parallel = True diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index d50a82950646..b8f942ef1786 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else () ) all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () - fx_ready = True + fx_compatible = True test_missing_keys = False test_pruning = False test_model_parallel = False diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index 38fdca3aa7b5..d6b9f9292621 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () - fx_ready = True + fx_compatible = True test_pruning = False test_missing_keys = False test_model_parallel = False diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index 71296f9f9e2d..7ac507988fe0 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -283,7 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready = True + fx_compatible = True # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index 57066ebd4789..6ca14526a6dc 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -269,7 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready = True + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 6b4246321692..1a55fda15292 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -356,7 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () - fx_ready = True + fx_compatible = True def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index d3b9268372bb..c0b5739bca7f 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - fx_ready = True + fx_compatible = True all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True