From ee0ce083d885ddb21ded776a0bb3ca3148887992 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 19 Mar 2020 12:05:38 +0100 Subject: [PATCH 001/162] first copy & past commit from Bert and morgans LSH code --- src/transformers/modeling_reformer.py | 1004 +++++++++++++++++++++++++ 1 file changed, 1004 insertions(+) create mode 100644 src/transformers/modeling_reformer.py diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py new file mode 100644 index 000000000000..9f75be1af2d8 --- /dev/null +++ b/src/transformers/modeling_reformer.py @@ -0,0 +1,1004 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model. """ + + +import logging +import math +import os + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from .activations import gelu, gelu_new, swish +from .configuration_reformer import ReformerConfig +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, prune_linear_layer + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} +# TODO: fill with pretrained model weights +# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", +# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", +# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", +# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", + + +def load_tf_weights_in_reformer(model, config, tf_checkpoint_path): + return NotImplementedError('check code below to implement') + """ Load tf checkpoints in a pytorch model. + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} + + +ReformerLayerNorm = torch.nn.LayerNorm + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class LSHAttention(nn.Module): + def __init__(self, bucket_size: int, nb_hashs: int = 1): + super().__init__() + + if (bucket_size % 2) != 0: + raise ValueError("bucket_size should be multiple of 2") + + if nb_hashs < 1: + raise ValueError("nb_hashs should be >= 1") + + self.bucket_size = bucket_size + self.nb_hashes = nb_hashs + + def forward(self, x: torch.Tensor) -> torch.Tensor: + nb_buckets = max(1, x.size(1) // self.bucket_size) + + # Random projection matrices + if self.nb_hashes == 1: + h = torch.randn(1, x.size(-1), nb_buckets, device=x.device) + + # Hash + rotated = torch.bmm(x, h) + rotated = torch.cat((rotated, -rotated), dim=-1) + bucket_range = torch.arange(rotated.shape[-1], device=x.device).expand_as(rotated) + + # Extract bucket + rotated_sorted, idx = rotated.sort(dim=-1) + rotated_buckets = bucket_range.gather(-1, idx)[:, -self.nb_hashes:] + + h, *_ = rotated_buckets.shape + buckets = torch.reshape(rotated_buckets.permute((*_, h)), (-1,)) + else: + h = torch.randn(1, x.size(-1), self.nb_hashes, nb_buckets, device=x.device) + offsets = torch.arange(self.nb_hashes, device=x.device) + offsets = (offsets * nb_buckets).view(1, 1, -1) + + # Hash + rotated = torch.matmul(x, h.flatten(2)).view(x.size()[:2] + (self.nb_hashes, nb_buckets)) + rotated = torch.cat((rotated, -rotated), dim=-1) + buckets = torch.argmax(rotated, dim=-1) + buckets = (buckets + offsets).view(x.size(0), -1) + + +class ReformerSelfAttention(nn.Module): + def __init__(self, config): + raise NotImplementedError('This has be changed -> insert LSHAttention') + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + if encoder_hidden_states is not None: + mixed_key_layer = self.key(encoder_hidden_states) + mixed_value_layer = self.value(encoder_hidden_states) + attention_mask = encoder_attention_mask + else: + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ReformerModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ReformerAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = ReformerSelfAttention(config) + self.output = ReformerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_outputs = self.self( + hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ReformerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ReformerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ReformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = ReformerAttention(config) + self.is_decoder = config.is_decoder + if self.is_decoder: + self.crossattention = ReformerAttention(config) + self.intermediate = ReformerIntermediate(config) + self.output = ReformerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask + ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class ReformerPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ReformerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class ReformerLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = ReformerPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class ReformerOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class ReformerPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ReformerLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class ReformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = ReformerConfig + pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + load_tf_weights = load_tf_weights_in_reformer + base_model_prefix = "reformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, ReformerLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +REFORMER_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +REFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.ReformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. +""" + + +@add_start_docstrings( + "The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.", + REFORMER_START_DOCSTRING, +) +class ReformerModel(ReformerPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + :obj:`is_decoder` argument of the configuration set to :obj:`True`; an + :obj:`encoder_hidden_states` is expected as an input to the forward pass. + + .. _`Attention is all you need`: + https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + self.pooler = ReformerPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during pre-training. + + This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import ReformerModel, ReformerTokenizer + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerModel.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + causal_mask = causal_mask.to( + attention_mask.dtype + ) # causal and attention masks must have same type with pytorch version < 1.3 + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + elif encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( + encoder_hidden_shape, encoder_attention_mask.shape + ) + ) + + encoder_extended_attention_mask = encoder_extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +@add_start_docstrings( + """Reformer Model with two heads on top as done during the pre-training: a `masked language modeling` head and + a `next sentence prediction (classification)` head. """, + REFORMER_START_DOCSTRING, +) +class ReformerForPreTraining(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.cls = ReformerPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + masked_lm_labels=None, + next_sentence_label=None, + ): + r""" + masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False + continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + + Examples:: + + from transformers import ReformerTokenizer, ReformerForPreTraining + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerForPreTraining.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + prediction_scores, seq_relationship_scores = outputs[:2] + + """ + + outputs = self.reformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + outputs = (prediction_scores, seq_relationship_score,) + outputs[ + 2: + ] # add hidden states and attention if they are here + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + outputs = (total_loss,) + outputs + + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Reformer Model with a `next sentence prediction (classification)` head on top. """, REFORMER_START_DOCSTRING, +) +class ReformerForNextSentencePrediction(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.cls = ReformerOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + ): + r""" + next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): + Next sequence prediction (classification) loss. + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import ReformerTokenizer, ReformerForNextSentencePrediction + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerForNextSentencePrediction.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + seq_relationship_scores = outputs[0] + + """ + + outputs = self.reformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + pooled_output = outputs[1] + + seq_relationship_score = self.cls(pooled_output) + + outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + outputs = (next_sentence_loss,) + outputs + + return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) From 32591155e19e2d03baba67bd676cb69846eb2869 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 20 Mar 2020 19:26:00 +0100 Subject: [PATCH 002/162] add easy way to compare to trax original code --- src/transformers/modeling_reformer.py | 2 + tests/test_modeling_reformer.py | 155 ++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 tests/test_modeling_reformer.py diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 9f75be1af2d8..718adeae2248 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -205,6 +205,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: buckets = torch.argmax(rotated, dim=-1) buckets = (buckets + offsets).view(x.size(0), -1) + return buckets + class ReformerSelfAttention(nn.Module): def __init__(self, config): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py new file mode 100644 index 000000000000..f928c7d00b80 --- /dev/null +++ b/tests/test_modeling_reformer.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2020 Huggingface +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +import numpy as np + +from transformers import is_torch_available # noqa: F401 +from .utils import require_torch, torch_device # noqa: F401 + + +if is_torch_available(): + import torch # noqa: F401 +# from transformers.modeling_reformer import () + + +class TraxUtils(object): + """ class that will help for testing in the beginning + should be deleted step-by-step + + README (HOW-TO-INSTALL TRAX): + 1) git clone https://github.com/patrickvonplaten/trax.git + + - I had to do one tiny change to make the imports work, + see: https://github.com/patrickvonplaten/trax/commit/6c23e88afe7f1c57b0c38eeaa4d450e5f912590c) + 2) link your PYTHON_PATH to ~/trax/trax + 3) pip install all the missing packages HINT: the package gin is installed + + - HINT: the package gin is installed with pip install gin-config==0.1.4 + and not pip install gin. + - The other packages can just be installed with pip install form + error message " missing" + """ + + def __init__(self, shape=(3, 32, 8)): + from trax import math + from trax.math import numpy as trax_np + from trax.shapes import ShapeDtype + import jax + from trax.layers.research.efficient_attention_v2 import LSHSelfAttention + + self.math = math + self.np_trax = trax_np + self.ShapeDtype = ShapeDtype + self.jax = jax + self.LSHSelfAttention = LSHSelfAttention + self._shape = shape + + def convert_to_jax_array(self, np_array): + return self.jax.numpy.asarray(np_array) + + def get_input_signature(self, shape=None): + with self.math.use_backend("jax"): + if shape is None: + shape = self._shape + input_signature = self.ShapeDtype(shape) + return input_signature + + def get_lsh_self_attention_layer( + self, + shape=None, + n_heads=5, + d_qk=7, + d_v=17, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ): + + with self.math.use_backend("jax"): + if shape is None: + shape = self._shape + layer = self.LSHSelfAttention( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + chunk_len=chunk_len, + n_chunks_before=n_chunks_before, + n_chunks_after=n_chunks_after, + n_hashes=n_hashes, + n_buckets=n_buckets, + use_reference_code=use_reference_code, + attention_dropout=attention_dropout, + mode=mode, + ) + + return layer + + def forward_layer( + self, + np_input_data, + layer=None, + input_signature=None, + random_number_generator=None, + ): + with self.math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + + if layer is None: + layer = self.get_lsh_self_attention_layer() + + if input_signature is None: + input_signature = self.get_input_signature() + + import ipdb + ipdb.set_trace() + + weights, state = layer.init(input_signature) + + if random_number_generator is None: + random_number_generator = layer.new_rngs(1)[0] + + output = layer( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output + + +@require_torch +class ReformerIntegrationTests(unittest.TestCase): + def _get_random_input(self, shape): + return np.random.rand(*shape) + + def _get_trax_utils(self, shape): + return TraxUtils(shape) + + def test_lsh_hashing(self): + shape = (3, 32, 8) + + np_input = self._get_random_input(shape) + trax_utils = self._get_trax_utils(shape) + + lsh_trax_output = trax_utils.forward_layer(np_input) # noqa: F841 + + pass From 25d162e200a735596e6c2c92409cd3d237f0353d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Mar 2020 12:46:57 +0100 Subject: [PATCH 003/162] translate most of function --- src/transformers/modeling_reformer.py | 152 ++++++++++++++++++++++++++ tests/test_modeling_reformer.py | 30 ++--- 2 files changed, 164 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 718adeae2248..59ecebe22d73 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -163,6 +163,158 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs return embeddings +class LSHSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() +# def __init__(self, n_heads=2, dim_qk=64, dim_v=64, share_qk='unused', +# causal=False, chunk_len=None, n_chunks_before=1, n_chunks_after=0, +# n_hashes=1, n_buckets=256, mode='train', +# predict_mem_len=2048, predict_drop_len=256, +# attention_dropout=0.0, n_parallel_heads=1, +# use_python_loop=False, use_reference_code=False +# ): + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query_key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + + def _hash_vectors(self, vectors): + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.n_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for num_bucket in self.num_buckets: + assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) + rotation_size += num_bucket + num_buckets *= num_bucket + + rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + + random_rotations = torch.randn(rotations_shape, device=vectors.device) + # TODO: check with dimensions here! + rotated_vectors = torch.einsum('td,dhb->htb', vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for num_bucket in self.num_buckets: + rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] + cur_sum += num_bucket // 2 + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + if buckets is None: + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + cur_product *= num_bucket + + # buckets is now (self.n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(self.num_hashes, device=vectors.device) + offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + buckets = torch.reshape(buckets + offsets, (-1,)) + + return buckets + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + ): + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_key_vectors = self.transpose_for_scores(mixed_query_key_vectors) + value_vectors = self.transpose_for_scores(mixed_value_vectors) + + buckets = self._hash_vectors(query_key_vectors) + + sequence_len = hidden_states.shape[1] + + # TODO: check shapes here + assert int(buckets.shape[0]) == self.num_hashes * sequence_len + + # TODO: what is ticker? + ticker = torch.arange(self.num_hashes * sequence_len, device=buckets.device) + buckets_and_t = sequence_len * buckets + (ticker % sequence_len) + + # Hash-based sort + sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) + _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_len) + sorted_query_key_vectors = torch.gather(query_key_vectors, sorted_ticker, dim=0) + sorted_value_vectors = torch.gather(value_vectors, sorted_ticker, dim=0) + + # Hash-based sort +# _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1) +# sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t) +# sticker = jax.lax.stop_gradient(sticker) +# undo_sort = jax.lax.stop_gradient(undo_sort) +# +# st = (sticker % seqlen) +# sq = np.take(q, st, axis=0) +# sv = np.take(v, st, axis=0) +# mask_fn = functools.partial( +# mask_self_attention, causal=self.causal, exclude_self=True) +# q_info = st + + query_info = sorted_ticker + + so, slogits = attend( + sq, k=None, v=sv, + q_chunk_len=self.chunk_len, + n_chunks_before=self.n_chunks_before, + n_chunks_after=self.n_chunks_after, + mask_fn=mask_fn, q_info=q_info, + dropout=self.attention_dropout, rng=attend_rng, + ) + + # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would + # also work, but these helpers include performance optimizations for TPU. + o = permute_via_gather(so, undo_sort, sticker, axis=0) + logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) + + if self.n_hashes > 1: + o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self.n_hashes, seqlen, 1)) + probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True)) + o = np.sum(o * probs, axis=0) + + assert o.shape == (seqlen, w_v.shape[-1]) + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self.output_dropout, output_rng) + return out, state + + class LSHAttention(nn.Module): def __init__(self, bucket_size: int, nb_hashs: int = 1): super().__init__() diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index f928c7d00b80..dfaf0a5fcb5d 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -17,6 +17,13 @@ import unittest import numpy as np +# trax imports - to be deleted later +from trax import math as trax_math +from trax.shapes import ShapeDtype as trax_ShapeDtype +import jax +from trax.layers.research.efficient_attention_v2 import LSHSelfAttention + + from transformers import is_torch_available # noqa: F401 from .utils import require_torch, torch_device # noqa: F401 @@ -45,27 +52,17 @@ class TraxUtils(object): """ def __init__(self, shape=(3, 32, 8)): - from trax import math - from trax.math import numpy as trax_np - from trax.shapes import ShapeDtype - import jax - from trax.layers.research.efficient_attention_v2 import LSHSelfAttention - - self.math = math - self.np_trax = trax_np - self.ShapeDtype = ShapeDtype - self.jax = jax self.LSHSelfAttention = LSHSelfAttention self._shape = shape def convert_to_jax_array(self, np_array): - return self.jax.numpy.asarray(np_array) + return jax.numpy.asarray(np_array) def get_input_signature(self, shape=None): - with self.math.use_backend("jax"): + with trax_math.use_backend("jax"): if shape is None: shape = self._shape - input_signature = self.ShapeDtype(shape) + input_signature = trax_ShapeDtype(shape) return input_signature def get_lsh_self_attention_layer( @@ -85,7 +82,7 @@ def get_lsh_self_attention_layer( mode="train", ): - with self.math.use_backend("jax"): + with trax_math.use_backend("jax"): if shape is None: shape = self._shape layer = self.LSHSelfAttention( @@ -112,7 +109,7 @@ def forward_layer( input_signature=None, random_number_generator=None, ): - with self.math.use_backend("jax"): + with trax_math.use_backend("jax"): input_data = self.convert_to_jax_array(np_input_data) if layer is None: @@ -121,9 +118,6 @@ def forward_layer( if input_signature is None: input_signature = self.get_input_signature() - import ipdb - ipdb.set_trace() - weights, state = layer.init(input_signature) if random_number_generator is None: From dc07c08e7d46af533d29dc2e96e5c4d09cdd39da Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Mar 2020 15:30:11 +0100 Subject: [PATCH 004/162] make trax lsh self attention deterministic with numpy seed + copy paste code --- src/transformers/modeling_reformer.py | 176 +++++++++++++++----------- tests/test_modeling_reformer.py | 1 + 2 files changed, 106 insertions(+), 71 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 59ecebe22d73..87e32f3f8697 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -186,12 +186,90 @@ def __init__(self, config): self.query_key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + ): + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_key_vectors = self.transpose_for_scores(mixed_query_key_vectors) + value_vectors = self.transpose_for_scores(mixed_value_vectors) + + buckets = self._hash_vectors(query_key_vectors) + + sequence_len = hidden_states.shape[1] + + # TODO: check shapes here + assert int(buckets.shape[0]) == self.num_hashes * sequence_len + + # TODO: what is ticker? + ticker = torch.arange(self.num_hashes * sequence_len, device=buckets.device) + buckets_and_t = sequence_len * buckets + (ticker % sequence_len) + + # Hash-based sort + sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) + _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_len) + sorted_query_key_vectors = torch.gather(query_key_vectors, sorted_ticker, dim=0) + sorted_value_vectors = torch.gather(value_vectors, sorted_ticker, dim=0) + + query_info = sorted_ticker + + sorted_query_key_vectors = torch.reshape(sorted_query_key_vectors, (-1, self.query_key_chunk_len, sorted_query_key_vectors.shape[-1])) + query_info = torch.reshape(query_info, (-1, self.query_key_chunk_len)) + + key_vectors = sorted_query_key_vectors + key_value_chunk_len = self.query_key_chunk_len + key_value_info = query_info + + sorted_value_vectors = torch.reshape(sorted_value_vectors, (-1, key_value_chunk_len, value_vectors.shape[-1])) + + key_vectors = self._length_normalize(key_vectors) + + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + sorted_value_vectors = self._look_adjacent(sorted_value_vectors, self.num_chunks_before, self.num_chunks_after) + key_value_info = self._look_adjacent(key_value_info, self.num_chunks_before, self.num_chunks_after) + + dots = torch.matmul(sorted_query_key_vectors, key_vectors.transpose(-1, -2)) + + # TODO: add masking here + + dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True) + dots = torch.exp(dots - dots_logsumexp) + + # TODO: add dropout here + + sorted_out = torch.matmul(dots, sorted_value_vectors) + sorted_out = torch.reshape(sorted_out, (-1, sorted_out.shape[-1])) + dots_logsumexp = torch.reshape(dots_logsumexp, (-1,)) + + out = torch.gather(sorted_out, undo_sort_indices, dim=0) + logits = torch.gather(dots_logsumexp, undo_sort_indices, dim=0) + + if self.num_hashes > 1: + out = torch.reshape(out, (self.num_hashes, sequence_len, out.shape[-1])) + logits = torch.reshape(logits, (self.num_hashes, sequence_len, 1)) + probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True)) + out = torch.sum(out * probs, dim=0) + + assert out.shape == (sequence_len, self.value[-1]) + out = self.dense(out) + + # TODO: apply broadcasted dropout + + return out + def _hash_vectors(self, vectors): # See https://arxiv.org/pdf/1509.02897.pdf # We sample a different random rotation for each round of hashing to @@ -238,81 +316,37 @@ def _hash_vectors(self, vectors): return buckets - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - ): - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) - - query_key_vectors = self.transpose_for_scores(mixed_query_key_vectors) - value_vectors = self.transpose_for_scores(mixed_value_vectors) - - buckets = self._hash_vectors(query_key_vectors) - - sequence_len = hidden_states.shape[1] - - # TODO: check shapes here - assert int(buckets.shape[0]) == self.num_hashes * sequence_len + def _look_adjacent(self, x, n_chunks_before, n_chunks_after): + """ Used to implement attention between consecutive chunks. - # TODO: what is ticker? - ticker = torch.arange(self.num_hashes * sequence_len, device=buckets.device) - buckets_and_t = sequence_len * buckets + (ticker % sequence_len) - - # Hash-based sort - sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) - _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) + Args: + x: array of shape [n_chunks, chunk_len, ...] + n_chunks_before: Number of previous chunks to attend to. + n_chunks_after: Number of subsequent chunks to attend to. + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if n_chunks_before == 0 and n_chunks_after == 0: + return x - sorted_ticker = (sorted_ticker % sequence_len) - sorted_query_key_vectors = torch.gather(query_key_vectors, sorted_ticker, dim=0) - sorted_value_vectors = torch.gather(value_vectors, sorted_ticker, dim=0) + slices = [] + for i in range(-n_chunks_before, n_chunks_after + 1): + if i == 0: + slices.append(x) + else: + slices.append(torch.cat([x[i:, ...], x[:i, ...]], dim=0)) + return torch.cat(slices, dim=1) - # Hash-based sort -# _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1) -# sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t) -# sticker = jax.lax.stop_gradient(sticker) -# undo_sort = jax.lax.stop_gradient(undo_sort) -# -# st = (sticker % seqlen) -# sq = np.take(q, st, axis=0) -# sv = np.take(v, st, axis=0) -# mask_fn = functools.partial( -# mask_self_attention, causal=self.causal, exclude_self=True) -# q_info = st - - query_info = sorted_ticker - - so, slogits = attend( - sq, k=None, v=sv, - q_chunk_len=self.chunk_len, - n_chunks_before=self.n_chunks_before, - n_chunks_after=self.n_chunks_after, - mask_fn=mask_fn, q_info=q_info, - dropout=self.attention_dropout, rng=attend_rng, - ) + def _length_normalize(self, x, epsilon=1e-6): + variance = torch.mean(x**2, -1, keepdim=True) + norm_x = x / torch.sqrt(variance + epsilon) + return norm_x - # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would - # also work, but these helpers include performance optimizations for TPU. - o = permute_via_gather(so, undo_sort, sticker, axis=0) - logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - - if self.n_hashes > 1: - o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1])) - logits = np.reshape(logits, (self.n_hashes, seqlen, 1)) - probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True)) - o = np.sum(o * probs, axis=0) - - assert o.shape == (seqlen, w_v.shape[-1]) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self.output_dropout, output_rng) - return out, state + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) class LSHAttention(nn.Module): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index dfaf0a5fcb5d..2a748403f354 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -80,6 +80,7 @@ def get_lsh_self_attention_layer( use_reference_code=True, attention_dropout=0.0, mode="train", + hash_seed=0 ): with trax_math.use_backend("jax"): From 09d42309215db1d98791cfd634462a9c41242319 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Mar 2020 16:41:32 +0100 Subject: [PATCH 005/162] add same config --- tests/test_modeling_reformer.py | 53 ++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 2a748403f354..6df019682857 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -68,37 +68,36 @@ def get_input_signature(self, shape=None): def get_lsh_self_attention_layer( self, shape=None, - n_heads=5, - d_qk=7, - d_v=17, - causal=True, - chunk_len=8, - n_chunks_before=1, - n_chunks_after=0, - n_hashes=2, - n_buckets=4, + num_attention_heads=None, + hidden_size=None, + query_key_chunk_len=None, + num_chunks_before=None, + num_chunks_after=None, + num_hashes=None, + num_buckets=None, use_reference_code=True, attention_dropout=0.0, mode="train", - hash_seed=0 + seed=0 ): with trax_math.use_backend("jax"): if shape is None: shape = self._shape layer = self.LSHSelfAttention( - n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - chunk_len=chunk_len, - n_chunks_before=n_chunks_before, - n_chunks_after=n_chunks_after, - n_hashes=n_hashes, - n_buckets=n_buckets, + n_heads=num_attention_heads, + d_qk=hidden_size, + d_v=hidden_size, + causal=False, + chunk_len=query_key_chunk_len, + n_chunks_before=num_chunks_before, + n_chunks_after=num_chunks_after, + n_hashes=num_hashes, + n_buckets=num_buckets, use_reference_code=use_reference_code, attention_dropout=attention_dropout, mode=mode, + hash_seed=seed ) return layer @@ -139,12 +138,26 @@ def _get_random_input(self, shape): def _get_trax_utils(self, shape): return TraxUtils(shape) + def _create_config(self, num_attention_heads=6, hidden_size=17, num_hashes=2, num_buckets=4, query_key_chunk_len=5, num_chunks_before=1, num_chunks_after=0, seed=0): + return { + 'hidden_size': hidden_size, + 'num_hashes': num_hashes, + 'num_buckets': num_buckets, + 'query_key_chunk_len': query_key_chunk_len, + 'num_chunks_before': num_chunks_before, + 'num_chunks_after': num_chunks_after, + 'seed': seed + } + def test_lsh_hashing(self): shape = (3, 32, 8) + config_dict = self._create_config() + np_input = self._get_random_input(shape) trax_utils = self._get_trax_utils(shape) - lsh_trax_output = trax_utils.forward_layer(np_input) # noqa: F841 + trax_layer = trax_utils.get_layer(**config_dict) + lsh_trax_output = trax_utils.forward_layer(np_input, layer=trax_layer) # noqa: F841 pass From 9386450016891375f5ad6134851225c97a79e479 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Mar 2020 16:41:41 +0100 Subject: [PATCH 006/162] add same config --- src/transformers/modeling_reformer.py | 36 +++++++++++++++++---------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 87e32f3f8697..2e4adc40650a 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -29,6 +29,9 @@ from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel, prune_linear_layer +# DELETE later +import numpy + logger = logging.getLogger(__name__) @@ -173,25 +176,29 @@ def __init__(self, config): # attention_dropout=0.0, n_parallel_heads=1, # use_python_loop=False, use_reference_code=False # ): - if config.hidden_size % config.num_attention_heads != 0: + if config['hidden_size'] % config['num_attention_heads'] != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) + "heads (%d)" % (config['hidden_size'], config['num_attention_heads']) ) - self.output_attentions = config.output_attentions - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.num_attention_heads = config['num_attention_heads'] + self.hidden_size = config['hidden_size'] + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.query_key = nn.Linear(self.hidden_size, self.all_head_size) + self.value = nn.Linear(self.hidden_size, self.all_head_size) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + # add Dropout + self.python_seed = config['seed'] - self.num_hashes = config.num_hashes - self.num_buckets = config.num_buckets + self.num_hashes = config['num_hashes'] + self.num_buckets = config['num_buckets'] + self.query_key_chunk_len = config['query_key_chunk_len'] + self.num_chunks_before = config['num_chunks_before'] + self.num_chunks_after = config['num_chunks_after'] def forward( self, @@ -263,7 +270,7 @@ def forward( probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True)) out = torch.sum(out * probs, dim=0) - assert out.shape == (sequence_len, self.value[-1]) + assert out.shape == (sequence_len, self.value_vectors[-1]) out = self.dense(out) # TODO: apply broadcasted dropout @@ -288,7 +295,10 @@ def _hash_vectors(self, vectors): rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) - random_rotations = torch.randn(rotations_shape, device=vectors.device) + numpy.random.seed(self.hash_seed) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) + +# random_rotations = torch.randn(rotations_shape, device=vectors.device) # TODO: check with dimensions here! rotated_vectors = torch.einsum('td,dhb->htb', vectors, random_rotations) From 3fb1182156257152ca3ad5215528d20e5ba817d2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 09:29:58 +0200 Subject: [PATCH 007/162] fix merge conflicts --- src/transformers/__init__.py | 3 +++ src/transformers/modeling_reformer.py | 14 +++++++++----- tests/test_modeling_reformer.py | 21 ++++++++++++++++----- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c869fa55d0e0..adcf7a01123f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -318,6 +318,9 @@ ElectraModel, load_tf_weights_in_electra, ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, + + from .modeling_reformer import ( + LSHSelfAttention ) # Optimization diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 2e4adc40650a..7e5624cd2412 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -25,13 +25,15 @@ from torch.nn import CrossEntropyLoss from .activations import gelu, gelu_new, swish -from .configuration_reformer import ReformerConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel, prune_linear_layer # DELETE later import numpy +# ADD later +#from .configuration_reformer import ReformerConfig + logger = logging.getLogger(__name__) @@ -176,14 +178,16 @@ def __init__(self, config): # attention_dropout=0.0, n_parallel_heads=1, # use_python_loop=False, use_reference_code=False # ): - if config['hidden_size'] % config['num_attention_heads'] != 0: + + self.num_attention_heads = config['num_attention_heads'] + self.hidden_size = config['hidden_size'] * self.num_attention_heads + + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config['hidden_size'], config['num_attention_heads']) ) - self.num_attention_heads = config['num_attention_heads'] - self.hidden_size = config['hidden_size'] self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -727,7 +731,7 @@ class ReformerPreTrainedModel(PreTrainedModel): a simple interface for downloading and loading pretrained models. """ - config_class = ReformerConfig +# TODO: config_class = ReformerConfig pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_reformer base_model_prefix = "reformer" diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 6df019682857..8c1f0a3884a1 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -21,7 +21,9 @@ from trax import math as trax_math from trax.shapes import ShapeDtype as trax_ShapeDtype import jax +#from trax.layers.research.efficient_attention_v2 import LSHSelfAttention as TraxLSHSelfAttention from trax.layers.research.efficient_attention_v2 import LSHSelfAttention +#from transformers import LSHSelfAttention from transformers import is_torch_available # noqa: F401 @@ -52,7 +54,6 @@ class TraxUtils(object): """ def __init__(self, shape=(3, 32, 8)): - self.LSHSelfAttention = LSHSelfAttention self._shape = shape def convert_to_jax_array(self, np_array): @@ -65,7 +66,7 @@ def get_input_signature(self, shape=None): input_signature = trax_ShapeDtype(shape) return input_signature - def get_lsh_self_attention_layer( + def get_layer( self, shape=None, num_attention_heads=None, @@ -84,7 +85,7 @@ def get_lsh_self_attention_layer( with trax_math.use_backend("jax"): if shape is None: shape = self._shape - layer = self.LSHSelfAttention( + layer = LSHSelfAttention( n_heads=num_attention_heads, d_qk=hidden_size, d_v=hidden_size, @@ -113,7 +114,7 @@ def forward_layer( input_data = self.convert_to_jax_array(np_input_data) if layer is None: - layer = self.get_lsh_self_attention_layer() + layer = self.get_layer() if input_signature is None: input_signature = self.get_input_signature() @@ -140,6 +141,7 @@ def _get_trax_utils(self, shape): def _create_config(self, num_attention_heads=6, hidden_size=17, num_hashes=2, num_buckets=4, query_key_chunk_len=5, num_chunks_before=1, num_chunks_after=0, seed=0): return { + 'num_attention_heads': num_attention_heads, 'hidden_size': hidden_size, 'num_hashes': num_hashes, 'num_buckets': num_buckets, @@ -155,9 +157,18 @@ def test_lsh_hashing(self): config_dict = self._create_config() np_input = self._get_random_input(shape) + trax_utils = self._get_trax_utils(shape) trax_layer = trax_utils.get_layer(**config_dict) - lsh_trax_output = trax_utils.forward_layer(np_input, layer=trax_layer) # noqa: F841 + lsh_trax_output, weights, state = trax_utils.forward_layer(np_input, layer=trax_layer) # noqa: F841 + + import ipdb + ipdb.set_trace() + +# hf_layer = LSHSelfAttention(config_dict) + + import ipdb + ipdb.set_trace() pass From ebb9d3f66ea44c922aba55f374c127a8f5e452f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 31 Mar 2020 16:11:00 +0200 Subject: [PATCH 008/162] make layer init work --- src/transformers/modeling_reformer.py | 11 +++-- tests/test_modeling_reformer.py | 66 ++++++++++++++++----------- 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 7e5624cd2412..1b4f9065ae2a 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -179,6 +179,7 @@ def __init__(self, config): # use_python_loop=False, use_reference_code=False # ): + self.hidden_states_input_size = config['input_size'] self.num_attention_heads = config['num_attention_heads'] self.hidden_size = config['hidden_size'] * self.num_attention_heads @@ -191,9 +192,9 @@ def __init__(self, config): self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_key = nn.Linear(self.hidden_size, self.all_head_size) - self.value = nn.Linear(self.hidden_size, self.all_head_size) - self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size) + self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size) + self.dense = nn.Linear(self.all_head_size, self.hidden_size) # add Dropout self.python_seed = config['seed'] @@ -210,6 +211,10 @@ def forward( attention_mask=None, head_mask=None, ): + + import ipdb + ipdb.set_trace() + mixed_query_key_vectors = self.query_key(hidden_states) mixed_value_vectors = self.value(hidden_states) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 8c1f0a3884a1..63dfe6e53d05 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1,5 +1,4 @@ -# coding=utf-8 -# Copyright 2020 Huggingface +# coding=utf-8 # Copyright 2020 Huggingface # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,9 +20,10 @@ from trax import math as trax_math from trax.shapes import ShapeDtype as trax_ShapeDtype import jax -#from trax.layers.research.efficient_attention_v2 import LSHSelfAttention as TraxLSHSelfAttention -from trax.layers.research.efficient_attention_v2 import LSHSelfAttention -#from transformers import LSHSelfAttention +from trax.layers.research.efficient_attention_v2 import ( + LSHSelfAttention as TraxLSHSelfAttention, +) +from transformers import LSHSelfAttention from transformers import is_torch_available # noqa: F401 @@ -79,13 +79,14 @@ def get_layer( use_reference_code=True, attention_dropout=0.0, mode="train", - seed=0 + seed=0, + **kwargs ): with trax_math.use_backend("jax"): if shape is None: shape = self._shape - layer = LSHSelfAttention( + layer = TraxLSHSelfAttention( n_heads=num_attention_heads, d_qk=hidden_size, d_v=hidden_size, @@ -98,7 +99,7 @@ def get_layer( use_reference_code=use_reference_code, attention_dropout=attention_dropout, mode=mode, - hash_seed=seed + hash_seed=seed, ) return layer @@ -128,31 +129,45 @@ def forward_layer( input_data, weights=weights, state=state, rng=random_number_generator ) - return output + return output, weights, state @require_torch class ReformerIntegrationTests(unittest.TestCase): + + hidden_size = 32 + seq_len = 7 + def _get_random_input(self, shape): return np.random.rand(*shape) def _get_trax_utils(self, shape): return TraxUtils(shape) - def _create_config(self, num_attention_heads=6, hidden_size=17, num_hashes=2, num_buckets=4, query_key_chunk_len=5, num_chunks_before=1, num_chunks_after=0, seed=0): + def _create_config( + self, + input_size=8, + num_attention_heads=5, + num_hashes=2, + num_buckets=4, + num_chunks_before=1, + num_chunks_after=0, + seed=0, + ): return { - 'num_attention_heads': num_attention_heads, - 'hidden_size': hidden_size, - 'num_hashes': num_hashes, - 'num_buckets': num_buckets, - 'query_key_chunk_len': query_key_chunk_len, - 'num_chunks_before': num_chunks_before, - 'num_chunks_after': num_chunks_after, - 'seed': seed + "input_size": self.hidden_size, + "num_attention_heads": num_attention_heads, + "hidden_size": self.hidden_size, + "num_hashes": num_hashes, + "num_buckets": num_buckets, + "query_key_chunk_len": self.seq_len, + "num_chunks_before": num_chunks_before, + "num_chunks_after": num_chunks_after, + "seed": seed, } def test_lsh_hashing(self): - shape = (3, 32, 8) + shape = (3, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim config_dict = self._create_config() @@ -161,14 +176,13 @@ def test_lsh_hashing(self): trax_utils = self._get_trax_utils(shape) trax_layer = trax_utils.get_layer(**config_dict) - lsh_trax_output, weights, state = trax_utils.forward_layer(np_input, layer=trax_layer) # noqa: F841 - - import ipdb - ipdb.set_trace() + lsh_trax_output, weights, state = trax_utils.forward_layer( + np_input, layer=trax_layer + ) # noqa: F841 -# hf_layer = LSHSelfAttention(config_dict) + torch_input = torch.tensor(np_input, dtype=torch.float) - import ipdb - ipdb.set_trace() + hf_layer = LSHSelfAttention(config_dict) + lsh_hf_output = hf_layer(torch_input) pass From c910e03e61662383f9caee426264a676820a3302 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Apr 2020 11:56:03 +0200 Subject: [PATCH 009/162] implemented hash_vectors function for lsh attention --- src/transformers/modeling_reformer.py | 38 ++++++++++++++++++--------- tests/test_modeling_reformer.py | 33 +++++++++++++++++++---- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1b4f9065ae2a..ad9934ca2ac8 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -181,7 +181,7 @@ def __init__(self, config): self.hidden_states_input_size = config['input_size'] self.num_attention_heads = config['num_attention_heads'] - self.hidden_size = config['hidden_size'] * self.num_attention_heads + self.hidden_size = config['hf_hidden_size'] if self.hidden_size % self.num_attention_heads != 0: raise ValueError( @@ -192,12 +192,12 @@ def __init__(self, config): self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size) - self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size) - self.dense = nn.Linear(self.all_head_size, self.hidden_size) + self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) + self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) # add Dropout - self.python_seed = config['seed'] + self.hash_seed = config['seed'] self.num_hashes = config['num_hashes'] self.num_buckets = config['num_buckets'] @@ -212,9 +212,6 @@ def forward( head_mask=None, ): - import ipdb - ipdb.set_trace() - mixed_query_key_vectors = self.query_key(hidden_states) mixed_value_vectors = self.value(hidden_states) @@ -223,6 +220,8 @@ def forward( buckets = self._hash_vectors(query_key_vectors) + return buckets + sequence_len = hidden_states.shape[1] # TODO: check shapes here @@ -295,7 +294,7 @@ def _hash_vectors(self, vectors): rotation_size = self.num_buckets num_buckets = self.num_buckets else: - # Factorize the hash if self.n_buckets is a list or tuple + # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for num_bucket in self.num_buckets: assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) @@ -304,12 +303,19 @@ def _hash_vectors(self, vectors): rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + # TODO: delete later when integration tests are ok numpy.random.seed(self.hash_seed) random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) + # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 # random_rotations = torch.randn(rotations_shape, device=vectors.device) - # TODO: check with dimensions here! - rotated_vectors = torch.einsum('td,dhb->htb', vectors, random_rotations) + + # rotated_vectors has dim: + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + # TODO: IMPORTANT: At the moment we use the same random rotation over all batches + # and heads -> is that bad? It seems like in original reformer a different random + # rotation is used over heads and batches + rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -327,10 +333,16 @@ def _hash_vectors(self, vectors): buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) cur_product *= num_bucket - # buckets is now (self.n_hashes, seqlen). Next we add offsets so that - # bucket numbers from different hashing rounds don't overlap. + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. offsets = torch.arange(self.num_hashes, device=vectors.device) offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + + batch_size = vectors.shape[0] + assert vectors.shape[1] == self.num_attention_heads + + # repeat same values for Batch_size and Num_Attn_Heads + offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) buckets = torch.reshape(buckets + offsets, (-1,)) return buckets diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 63dfe6e53d05..23b528148a23 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -79,7 +79,8 @@ def get_layer( use_reference_code=True, attention_dropout=0.0, mode="train", - seed=0, + seed=None, + path_to_save_weights=None, **kwargs ): @@ -100,6 +101,7 @@ def get_layer( attention_dropout=attention_dropout, mode=mode, hash_seed=seed, + path_to_save_weights=path_to_save_weights ) return layer @@ -147,27 +149,30 @@ def _get_trax_utils(self, shape): def _create_config( self, input_size=8, - num_attention_heads=5, + num_attention_heads=1, num_hashes=2, num_buckets=4, num_chunks_before=1, num_chunks_after=0, seed=0, + path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights" ): return { "input_size": self.hidden_size, "num_attention_heads": num_attention_heads, "hidden_size": self.hidden_size, + "hf_hidden_size": num_attention_heads * self.hidden_size, "num_hashes": num_hashes, "num_buckets": num_buckets, "query_key_chunk_len": self.seq_len, "num_chunks_before": num_chunks_before, "num_chunks_after": num_chunks_after, "seed": seed, + "path_to_save_weights": path_to_save_weights } def test_lsh_hashing(self): - shape = (3, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim + shape = (1, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim config_dict = self._create_config() @@ -176,13 +181,31 @@ def test_lsh_hashing(self): trax_utils = self._get_trax_utils(shape) trax_layer = trax_utils.get_layer(**config_dict) - lsh_trax_output, weights, state = trax_utils.forward_layer( + output, weights, state = trax_utils.forward_layer( np_input, layer=trax_layer ) # noqa: F841 torch_input = torch.tensor(np_input, dtype=torch.float) hf_layer = LSHSelfAttention(config_dict) - lsh_hf_output = hf_layer(torch_input) + + # set torch weights for 1-to-1 comparison + with torch.no_grad(): + + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + + hf_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, self.hidden_size)) + hf_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, self.hidden_size)) + +# trax_query_out = np.load(config_dict['path_to_save_weights'] + '_1_query_out.npy') +# trax_query_weight = np.load(config_dict['path_to_save_weights'] + '_1_query_weight.npy') +# trax_query_input = np.load(config_dict['path_to_save_weights'] + '_1_query_input.npy') + + # check that buckets are equal + buckets = hf_layer(torch_input) + + import ipdb + ipdb.set_trace() pass From b956933fd9fb93917edea8e61a8486c4ffba3776 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Apr 2020 20:41:16 +0200 Subject: [PATCH 010/162] continue reformer translation --- src/transformers/modeling_reformer.py | 72 ++++++--------------------- tests/test_modeling_reformer.py | 10 ++-- 2 files changed, 18 insertions(+), 64 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ad9934ca2ac8..3668c41a5272 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -220,22 +220,25 @@ def forward( buckets = self._hash_vectors(query_key_vectors) - return buckets + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] - sequence_len = hidden_states.shape[1] + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - # TODO: check shapes here - assert int(buckets.shape[0]) == self.num_hashes * sequence_len + # TODO: what is ticker? Is ticker something like indices?? Ask authors + ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) - # TODO: what is ticker? - ticker = torch.arange(self.num_hashes * sequence_len, device=buckets.device) - buckets_and_t = sequence_len * buckets + (ticker % sequence_len) + buckets_and_t = sequence_length * buckets + (ticker % sequence_length) # Hash-based sort sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) - sorted_ticker = (sorted_ticker % sequence_len) + import ipdb + ipdb.set_trace() + + sorted_ticker = (sorted_ticker % sequence_length) sorted_query_key_vectors = torch.gather(query_key_vectors, sorted_ticker, dim=0) sorted_value_vectors = torch.gather(value_vectors, sorted_ticker, dim=0) @@ -273,12 +276,12 @@ def forward( logits = torch.gather(dots_logsumexp, undo_sort_indices, dim=0) if self.num_hashes > 1: - out = torch.reshape(out, (self.num_hashes, sequence_len, out.shape[-1])) - logits = torch.reshape(logits, (self.num_hashes, sequence_len, 1)) + out = torch.reshape(out, (self.num_hashes, sequence_length, out.shape[-1])) + logits = torch.reshape(logits, (self.num_hashes, sequence_length, 1)) probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True)) out = torch.sum(out * probs, dim=0) - assert out.shape == (sequence_len, self.value_vectors[-1]) + assert out.shape == (sequence_length, self.value_vectors[-1]) out = self.dense(out) # TODO: apply broadcasted dropout @@ -343,7 +346,7 @@ def _hash_vectors(self, vectors): # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - buckets = torch.reshape(buckets + offsets, (-1,)) + buckets = torch.reshape(buckets + offsets, (batch_size, self.num_attention_heads, -1)) return buckets @@ -380,51 +383,6 @@ def transpose_for_scores(self, x): return x.permute(0, 2, 1, 3) -class LSHAttention(nn.Module): - def __init__(self, bucket_size: int, nb_hashs: int = 1): - super().__init__() - - if (bucket_size % 2) != 0: - raise ValueError("bucket_size should be multiple of 2") - - if nb_hashs < 1: - raise ValueError("nb_hashs should be >= 1") - - self.bucket_size = bucket_size - self.nb_hashes = nb_hashs - - def forward(self, x: torch.Tensor) -> torch.Tensor: - nb_buckets = max(1, x.size(1) // self.bucket_size) - - # Random projection matrices - if self.nb_hashes == 1: - h = torch.randn(1, x.size(-1), nb_buckets, device=x.device) - - # Hash - rotated = torch.bmm(x, h) - rotated = torch.cat((rotated, -rotated), dim=-1) - bucket_range = torch.arange(rotated.shape[-1], device=x.device).expand_as(rotated) - - # Extract bucket - rotated_sorted, idx = rotated.sort(dim=-1) - rotated_buckets = bucket_range.gather(-1, idx)[:, -self.nb_hashes:] - - h, *_ = rotated_buckets.shape - buckets = torch.reshape(rotated_buckets.permute((*_, h)), (-1,)) - else: - h = torch.randn(1, x.size(-1), self.nb_hashes, nb_buckets, device=x.device) - offsets = torch.arange(self.nb_hashes, device=x.device) - offsets = (offsets * nb_buckets).view(1, 1, -1) - - # Hash - rotated = torch.matmul(x, h.flatten(2)).view(x.size()[:2] + (self.nb_hashes, nb_buckets)) - rotated = torch.cat((rotated, -rotated), dim=-1) - buckets = torch.argmax(rotated, dim=-1) - buckets = (buckets + offsets).view(x.size(0), -1) - - return buckets - - class ReformerSelfAttention(nn.Module): def __init__(self, config): raise NotImplementedError('This has be changed -> insert LSHAttention') diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 23b528148a23..7a2405e63f39 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -149,7 +149,7 @@ def _get_trax_utils(self, shape): def _create_config( self, input_size=8, - num_attention_heads=1, + num_attention_heads=2, num_hashes=2, num_buckets=4, num_chunks_before=1, @@ -198,14 +198,10 @@ def test_lsh_hashing(self): hf_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, self.hidden_size)) hf_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, self.hidden_size)) -# trax_query_out = np.load(config_dict['path_to_save_weights'] + '_1_query_out.npy') -# trax_query_weight = np.load(config_dict['path_to_save_weights'] + '_1_query_weight.npy') -# trax_query_input = np.load(config_dict['path_to_save_weights'] + '_1_query_input.npy') + trax_buckets_1 = np.load(config_dict['path_to_save_weights'] + '_1_buckets.npy') + trax_buckets_2 = np.load(config_dict['path_to_save_weights'] + '_2_buckets.npy') - # check that buckets are equal buckets = hf_layer(torch_input) - import ipdb - ipdb.set_trace() pass From a449c2efe7b69bc3f792dfb14102e750d33f440f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 12:05:29 +0200 Subject: [PATCH 011/162] hf LSHSelfAttentionLayer gives same output as trax layer --- src/transformers/modeling_reformer.py | 57 +++++++++++++++++---------- tests/test_modeling_reformer.py | 19 +++++---- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3668c41a5272..6b1c0bd2ebde 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -194,7 +194,7 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) - self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) + self.dense = nn.Linear(self.all_head_size, self.hidden_states_input_size, bias=False) # add Dropout self.hash_seed = config['seed'] @@ -235,25 +235,33 @@ def forward( sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) - import ipdb - ipdb.set_trace() + assert query_key_vectors.shape[-1] == self.attention_head_size + assert value_vectors.shape[-1] == self.attention_head_size sorted_ticker = (sorted_ticker % sequence_length) - sorted_query_key_vectors = torch.gather(query_key_vectors, sorted_ticker, dim=0) - sorted_value_vectors = torch.gather(value_vectors, sorted_ticker, dim=0) + expanded_sorted_ticker = sorted_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + sorted_query_key_vectors = torch.gather(query_key_vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_sorted_ticker) + sorted_value_vectors = torch.gather(value_vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_sorted_ticker) query_info = sorted_ticker - sorted_query_key_vectors = torch.reshape(sorted_query_key_vectors, (-1, self.query_key_chunk_len, sorted_query_key_vectors.shape[-1])) - query_info = torch.reshape(query_info, (-1, self.query_key_chunk_len)) + sorted_query_key_vectors = torch.reshape(sorted_query_key_vectors, (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len, self.attention_head_size)) + query_info = torch.reshape(query_info, (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len)) key_vectors = sorted_query_key_vectors key_value_chunk_len = self.query_key_chunk_len key_value_info = query_info - sorted_value_vectors = torch.reshape(sorted_value_vectors, (-1, key_value_chunk_len, value_vectors.shape[-1])) + sorted_value_vectors = torch.reshape(sorted_value_vectors, (batch_size, self.num_attention_heads, -1, key_value_chunk_len, value_vectors.shape[-1])) key_vectors = self._length_normalize(key_vectors) + key_vectors = key_vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32)) + + # Optionally include adjacent chunks. + if self.query_key_chunk_len is not None or key_value_chunk_len is not None: + assert self.query_key_chunk_len is not None and key_value_chunk_len is not None + else: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0 key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) sorted_value_vectors = self._look_adjacent(sorted_value_vectors, self.num_chunks_before, self.num_chunks_after) @@ -268,22 +276,25 @@ def forward( # TODO: add dropout here + # attend values sorted_out = torch.matmul(dots, sorted_value_vectors) - sorted_out = torch.reshape(sorted_out, (-1, sorted_out.shape[-1])) - dots_logsumexp = torch.reshape(dots_logsumexp, (-1,)) + sorted_out = torch.reshape(sorted_out, (batch_size, self.num_attention_heads, -1, self.attention_head_size)) + dots_logsumexp = torch.reshape(dots_logsumexp, (batch_size, self.num_attention_heads, -1)) - out = torch.gather(sorted_out, undo_sort_indices, dim=0) - logits = torch.gather(dots_logsumexp, undo_sort_indices, dim=0) + expanded_undo_sort_indices = undo_sort_indices.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out = torch.gather(sorted_out, 2, expanded_undo_sort_indices) + logits = torch.gather(dots_logsumexp, 2, undo_sort_indices) if self.num_hashes > 1: - out = torch.reshape(out, (self.num_hashes, sequence_length, out.shape[-1])) - logits = torch.reshape(logits, (self.num_hashes, sequence_length, 1)) - probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True)) - out = torch.sum(out * probs, dim=0) + out = torch.reshape(out, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, self.attention_head_size)) + logits = torch.reshape(logits, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, 1)) + probs = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out = torch.sum(out * probs, dim=2) - assert out.shape == (sequence_length, self.value_vectors[-1]) - out = self.dense(out) + assert out.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + out = self.transpose_for_output(out) + out = self.dense(out) # TODO: apply broadcasted dropout return out @@ -354,7 +365,7 @@ def _look_adjacent(self, x, n_chunks_before, n_chunks_after): """ Used to implement attention between consecutive chunks. Args: - x: array of shape [n_chunks, chunk_len, ...] + x: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] n_chunks_before: Number of previous chunks to attend to. n_chunks_after: Number of subsequent chunks to attend to. Returns: @@ -369,8 +380,8 @@ def _look_adjacent(self, x, n_chunks_before, n_chunks_after): if i == 0: slices.append(x) else: - slices.append(torch.cat([x[i:, ...], x[:i, ...]], dim=0)) - return torch.cat(slices, dim=1) + slices.append(torch.cat([x[:, :, i:, ...], x[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) def _length_normalize(self, x, epsilon=1e-6): variance = torch.mean(x**2, -1, keepdim=True) @@ -382,6 +393,10 @@ def transpose_for_scores(self, x): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) + def transpose_for_output(self, x): + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) + class ReformerSelfAttention(nn.Module): def __init__(self, config): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 7a2405e63f39..baa5ecc12ef7 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -181,27 +181,26 @@ def test_lsh_hashing(self): trax_utils = self._get_trax_utils(shape) trax_layer = trax_utils.get_layer(**config_dict) - output, weights, state = trax_utils.forward_layer( + trax_output, trax_weights, trax_state = trax_utils.forward_layer( np_input, layer=trax_layer ) # noqa: F841 - - torch_input = torch.tensor(np_input, dtype=torch.float) + trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_layer = LSHSelfAttention(config_dict) # set torch weights for 1-to-1 comparison with torch.no_grad(): - - np_query_key = np.asarray(weights[0]) - np_value = np.asarray(weights[1]) + np_query_key = np.asarray(trax_weights[0]) + np_value = np.asarray(trax_weights[1]) + np_dense = np.asarray(trax_weights[2]) hf_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, self.hidden_size)) hf_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, self.hidden_size)) + hf_layer.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, self.hidden_size).contiguous().transpose(0, 1)) - trax_buckets_1 = np.load(config_dict['path_to_save_weights'] + '_1_buckets.npy') - trax_buckets_2 = np.load(config_dict['path_to_save_weights'] + '_2_buckets.npy') - - buckets = hf_layer(torch_input) + hf_input = torch.tensor(np_input, dtype=torch.float) + hf_output = hf_layer(hf_input) + assert torch.allclose(hf_output, trax_torch_output, atol=1e-6) pass From 35491d8ab377a6d299ffe3bd5a3eac3939d08e22 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 16:31:34 +0200 Subject: [PATCH 012/162] refactor code --- src/transformers/modeling_reformer.py | 246 ++++++++++---------------- tests/test_modeling_reformer.py | 2 +- 2 files changed, 99 insertions(+), 149 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 6b1c0bd2ebde..86d3097a54a2 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -17,7 +17,6 @@ import logging -import math import os import torch @@ -183,15 +182,15 @@ def __init__(self, config): self.num_attention_heads = config['num_attention_heads'] self.hidden_size = config['hf_hidden_size'] + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config['hidden_size'], config['num_attention_heads']) ) - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) self.dense = nn.Linear(self.all_head_size, self.hidden_states_input_size, bias=False) @@ -212,78 +211,39 @@ def forward( head_mask=None, ): - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) - - query_key_vectors = self.transpose_for_scores(mixed_query_key_vectors) - value_vectors = self.transpose_for_scores(mixed_value_vectors) - - buckets = self._hash_vectors(query_key_vectors) - + # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - - # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) - ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) - - buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) - # Hash-based sort - sorted_buckets, sorted_ticker = torch.sort(buckets_and_t, dim=-1) - _, undo_sort_indices = torch.sort(sorted_ticker, dim=-1) + query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) + value_vectors = self._transpose_for_scores(mixed_value_vectors) assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size - sorted_ticker = (sorted_ticker % sequence_length) - expanded_sorted_ticker = sorted_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - sorted_query_key_vectors = torch.gather(query_key_vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_sorted_ticker) - sorted_value_vectors = torch.gather(value_vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_sorted_ticker) - - query_info = sorted_ticker - - sorted_query_key_vectors = torch.reshape(sorted_query_key_vectors, (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len, self.attention_head_size)) - query_info = torch.reshape(query_info, (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len)) + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors) + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - key_vectors = sorted_query_key_vectors - key_value_chunk_len = self.query_key_chunk_len - key_value_info = query_info + sorted_ticker, undo_sorted_ticker = self._get_sorted_and_undo_sorted_ticker(sequence_length, buckets) - sorted_value_vectors = torch.reshape(sorted_value_vectors, (batch_size, self.num_attention_heads, -1, key_value_chunk_len, value_vectors.shape[-1])) + sorted_query_key_vectors = self._sort_by_indices(query_key_vectors, sorted_ticker) + sorted_value_vectors = self._sort_by_indices(value_vectors, sorted_ticker) - key_vectors = self._length_normalize(key_vectors) - key_vectors = key_vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32)) + sorted_query_key_vectors = self._split_by_chunk_len(sorted_query_key_vectors) + sorted_value_vectors = self._split_by_chunk_len(sorted_value_vectors) + sorted_ticker = self._split_by_chunk_len(sorted_ticker) # Optionally include adjacent chunks. - if self.query_key_chunk_len is not None or key_value_chunk_len is not None: - assert self.query_key_chunk_len is not None and key_value_chunk_len is not None - else: + if self.query_key_chunk_len is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) - sorted_value_vectors = self._look_adjacent(sorted_value_vectors, self.num_chunks_before, self.num_chunks_after) - key_value_info = self._look_adjacent(key_value_info, self.num_chunks_before, self.num_chunks_after) - - dots = torch.matmul(sorted_query_key_vectors, key_vectors.transpose(-1, -2)) - - # TODO: add masking here - - dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True) - dots = torch.exp(dots - dots_logsumexp) - - # TODO: add dropout here - - # attend values - sorted_out = torch.matmul(dots, sorted_value_vectors) - sorted_out = torch.reshape(sorted_out, (batch_size, self.num_attention_heads, -1, self.attention_head_size)) - dots_logsumexp = torch.reshape(dots_logsumexp, (batch_size, self.num_attention_heads, -1)) + key_vectors = self._len_and_dim_norm(sorted_query_key_vectors) - expanded_undo_sort_indices = undo_sort_indices.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out = torch.gather(sorted_out, 2, expanded_undo_sort_indices) - logits = torch.gather(dots_logsumexp, 2, undo_sort_indices) + out, logits = self._attend(sorted_query_key_vectors, key_vectors, sorted_value_vectors, sorted_ticker, undo_sorted_ticker) if self.num_hashes > 1: out = torch.reshape(out, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, self.attention_head_size)) @@ -293,7 +253,7 @@ def forward( assert out.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) - out = self.transpose_for_output(out) + out = self._transpose_for_output(out) out = self.dense(out) # TODO: apply broadcasted dropout @@ -361,120 +321,110 @@ def _hash_vectors(self, vectors): return buckets - def _look_adjacent(self, x, n_chunks_before, n_chunks_after): + def _get_sorted_and_undo_sorted_ticker(self, sequence_length, buckets): + batch_size = buckets.shape[0] + + # TODO: what is ticker? Is ticker something like indices?? Ask authors + ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) + + buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + + # Hash-based sort + sorted_ticker = torch.argsort(buckets_and_t, dim=-1) + undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_length) + return sorted_ticker, undo_sorted_ticker + + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker): + key_vectors = self._look_adjacent(key_vectors) + value_vectors = self._look_adjacent(value_vectors) + ticker = self._look_adjacent(ticker) + + # TODO: add masking here + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + dots = torch.exp(query_key_dots - logits) + + # TODO: add dropout here + + # attend values + out_vectors = torch.matmul(dots, value_vectors) + + # merge chunk length + logits = self._merge_by_chunk_len(logits) + out_vectors = self._merge_by_chunk_len(out_vectors) + + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_ticker) + + return out_vectors, logits + + def _look_adjacent(self, vectors): """ Used to implement attention between consecutive chunks. Args: - x: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] - n_chunks_before: Number of previous chunks to attend to. - n_chunks_after: Number of subsequent chunks to attend to. + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] Returns: array of shape [n_chunks, N * chunk_len, ...], where N = (1 + n_chunks_before + n_chunks_after). """ - if n_chunks_before == 0 and n_chunks_after == 0: - return x + if self.num_chunks_before == 0 and self.num_chunks_after == 0: + return vectors slices = [] - for i in range(-n_chunks_before, n_chunks_after + 1): + for i in range(-self.num_chunks_before, self.num_chunks_after + 1): if i == 0: - slices.append(x) + slices.append(vectors) else: - slices.append(torch.cat([x[:, :, i:, ...], x[:, :, :i, ...]], dim=2)) + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) return torch.cat(slices, dim=3) - def _length_normalize(self, x, epsilon=1e-6): + def _len_and_dim_norm(self, vectors): + vectors = self._len_norm(vectors) + vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) + return vectors + + def _len_norm(self, x, epsilon=1e-6): variance = torch.mean(x**2, -1, keepdim=True) norm_x = x / torch.sqrt(variance + epsilon) return norm_x - def transpose_for_scores(self, x): + def _sort_by_indices(self, vectors, indices): + expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + sorted_vectors = torch.gather(vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_indices) + return sorted_vectors + + def _transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def transpose_for_output(self, x): + def _transpose_for_output(self, x): x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) + def _split_by_chunk_len(self, vectors): + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len) -class ReformerSelfAttention(nn.Module): - def __init__(self, config): - raise NotImplementedError('This has be changed -> insert LSHAttention') - super().__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - self.output_attentions = config.output_attentions - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) - attention_mask = encoder_attention_mask + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in ReformerModel forward() function) - attention_scores = attention_scores + attention_mask + def _merge_by_chunk_len(self, vectors): + assert len(vectors.shape) == 5, "Input vectors should be of rank 5, but is of rank {}".format(len(vectors.shape)) - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) - return outputs + batch_size = vectors.shape[0] + new_dim_shape = (batch_size, self.num_attention_heads, -1, vectors.shape[-1]) + return torch.reshape(vectors, new_dim_shape).squeeze(-1) class ReformerSelfOutput(nn.Module): @@ -494,7 +444,7 @@ def forward(self, hidden_states, input_tensor): class ReformerAttention(nn.Module): def __init__(self, config): super().__init__() - self.self = ReformerSelfAttention(config) + self.self = LSHSelfAttention(config) self.output = ReformerSelfOutput(config) self.pruned_heads = set() diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index baa5ecc12ef7..66a0e5f6f706 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -172,7 +172,7 @@ def _create_config( } def test_lsh_hashing(self): - shape = (1, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim + shape = (3, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim config_dict = self._create_config() From c9a091963478371619765ac4e261bdd92645e427 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 16:32:58 +0200 Subject: [PATCH 013/162] refactor code --- src/transformers/modeling_reformer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 86d3097a54a2..743754ad2aae 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -228,14 +228,13 @@ def forward( buckets = self._hash_vectors(query_key_vectors) assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - sorted_ticker, undo_sorted_ticker = self._get_sorted_and_undo_sorted_ticker(sequence_length, buckets) + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - sorted_query_key_vectors = self._sort_by_indices(query_key_vectors, sorted_ticker) - sorted_value_vectors = self._sort_by_indices(value_vectors, sorted_ticker) + sorted_query_key_vectors = self._sort_by_indices(query_key_vectors, ticker) + sorted_value_vectors = self._sort_by_indices(value_vectors, ticker) sorted_query_key_vectors = self._split_by_chunk_len(sorted_query_key_vectors) sorted_value_vectors = self._split_by_chunk_len(sorted_value_vectors) - sorted_ticker = self._split_by_chunk_len(sorted_ticker) # Optionally include adjacent chunks. if self.query_key_chunk_len is None: @@ -243,7 +242,7 @@ def forward( key_vectors = self._len_and_dim_norm(sorted_query_key_vectors) - out, logits = self._attend(sorted_query_key_vectors, key_vectors, sorted_value_vectors, sorted_ticker, undo_sorted_ticker) + out, logits = self._attend(sorted_query_key_vectors, key_vectors, sorted_value_vectors, undo_ticker) if self.num_hashes > 1: out = torch.reshape(out, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, self.attention_head_size)) @@ -337,10 +336,9 @@ def _get_sorted_and_undo_sorted_ticker(self, sequence_length, buckets): sorted_ticker = (sorted_ticker % sequence_length) return sorted_ticker, undo_sorted_ticker - def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker): + def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): key_vectors = self._look_adjacent(key_vectors) value_vectors = self._look_adjacent(value_vectors) - ticker = self._look_adjacent(ticker) # TODO: add masking here From 4ef4739741d858688e69da74e1616ed1e729469b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 16:34:27 +0200 Subject: [PATCH 014/162] refactor code --- src/transformers/modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 743754ad2aae..cbe2a9291154 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -320,7 +320,7 @@ def _hash_vectors(self, vectors): return buckets - def _get_sorted_and_undo_sorted_ticker(self, sequence_length, buckets): + def _get_ticker_and_undo_ticker(self, sequence_length, buckets): batch_size = buckets.shape[0] # TODO: what is ticker? Is ticker something like indices?? Ask authors From f5800746c8e5525518c016114cde0cf673c6c68a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 17:58:02 +0200 Subject: [PATCH 015/162] refactor --- src/transformers/modeling_reformer.py | 74 ++++++++++++++------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index cbe2a9291154..1ebc71080f52 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -230,35 +230,38 @@ def forward( ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - sorted_query_key_vectors = self._sort_by_indices(query_key_vectors, ticker) - sorted_value_vectors = self._sort_by_indices(value_vectors, ticker) + query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) + value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) - sorted_query_key_vectors = self._split_by_chunk_len(sorted_query_key_vectors) - sorted_value_vectors = self._split_by_chunk_len(sorted_value_vectors) + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.query_key_chunk_len) + value_vectors = self._split_dim_by(value_vectors, -1, self.query_key_chunk_len) # Optionally include adjacent chunks. if self.query_key_chunk_len is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - key_vectors = self._len_and_dim_norm(sorted_query_key_vectors) + key_vectors = self._len_and_dim_norm(query_key_vectors) - out, logits = self._attend(sorted_query_key_vectors, key_vectors, sorted_value_vectors, undo_ticker) + out_vectors, logits = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker) if self.num_hashes > 1: - out = torch.reshape(out, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, self.attention_head_size)) - logits = torch.reshape(logits, (batch_size, self.num_attention_heads, self.num_hashes, sequence_length, 1)) - probs = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) - out = torch.sum(out * probs, dim=2) + out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) + logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) - assert out.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - out = self._transpose_for_output(out) - out = self.dense(out) + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + + out_vectors = self._transpose_for_output(out_vectors) + out = self.dense(out_vectors) # TODO: apply broadcasted dropout return out def _hash_vectors(self, vectors): + batch_size = vectors.shape[0] + # See https://arxiv.org/pdf/1509.02897.pdf # We sample a different random rotation for each round of hashing to # decrease the probability of hash misses. @@ -277,12 +280,10 @@ def _hash_vectors(self, vectors): rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) # TODO: delete later when integration tests are ok - numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) - # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 # random_rotations = torch.randn(rotations_shape, device=vectors.device) - + numpy.random.seed(self.hash_seed) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # TODO: IMPORTANT: At the moment we use the same random rotation over all batches @@ -300,10 +301,12 @@ def _hash_vectors(self, vectors): rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] cur_sum += num_bucket // 2 rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + if buckets is None: buckets = torch.argmax(rotated_vectors, dim=-1) else: buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + cur_product *= num_bucket # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). @@ -311,14 +314,11 @@ def _hash_vectors(self, vectors): offsets = torch.arange(self.num_hashes, device=vectors.device) offsets = torch.reshape(offsets * num_buckets, (-1, 1)) - batch_size = vectors.shape[0] - assert vectors.shape[1] == self.num_attention_heads - # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - buckets = torch.reshape(buckets + offsets, (batch_size, self.num_attention_heads, -1)) + offset_buckets = self._merge_by_middle_dim(buckets + offsets) - return buckets + return offset_buckets def _get_ticker_and_undo_ticker(self, sequence_length, buckets): batch_size = buckets.shape[0] @@ -353,8 +353,8 @@ def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): out_vectors = torch.matmul(dots, value_vectors) # merge chunk length - logits = self._merge_by_chunk_len(logits) - out_vectors = self._merge_by_chunk_len(out_vectors) + logits = self._merge_by_middle_dim(logits).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors) expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) @@ -392,10 +392,10 @@ def _len_norm(self, x, epsilon=1e-6): norm_x = x / torch.sqrt(variance + epsilon) return norm_x - def _sort_by_indices(self, vectors, indices): - expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - sorted_vectors = torch.gather(vectors.repeat(1, 1, self.num_hashes, 1), 2, expanded_indices) - return sorted_vectors + def _gather_by_expansion(self, vectors, idxs): + expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + vectors = vectors.repeat(1, 1, self.num_hashes, 1) + return torch.gather(vectors, 2, expanded_idxs) def _transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -406,9 +406,9 @@ def _transpose_for_output(self, x): x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) - def _split_by_chunk_len(self, vectors): + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): batch_size = vectors.shape[0] - split_dim_shape = (batch_size, self.num_attention_heads, -1, self.query_key_chunk_len) + split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) if len(vectors.shape) == 4: return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) @@ -417,12 +417,16 @@ def _split_by_chunk_len(self, vectors): else: raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - def _merge_by_chunk_len(self, vectors): - assert len(vectors.shape) == 5, "Input vectors should be of rank 5, but is of rank {}".format(len(vectors.shape)) - + def _merge_by_middle_dim(self, vectors): batch_size = vectors.shape[0] - new_dim_shape = (batch_size, self.num_attention_heads, -1, vectors.shape[-1]) - return torch.reshape(vectors, new_dim_shape).squeeze(-1) + new_dim_shape = (batch_size, self.num_attention_heads, -1) + + if len(vectors.shape) == 5: + return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) + elif len(vectors.shape) == 4: + return torch.reshape(vectors, new_dim_shape) + else: + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) class ReformerSelfOutput(nn.Module): From 4176564c785004fa69d4583a2108715c8dcf1818 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 18:56:24 +0200 Subject: [PATCH 016/162] refactor + add reformer config --- src/transformers/__init__.py | 1 + src/transformers/configuration_reformer.py | 148 +++ src/transformers/modeling_reformer.py | 46 +- tests/test_modeling_reformer.py | 120 +- w! | 1146 ++++++++++++++++++++ 5 files changed, 1348 insertions(+), 113 deletions(-) create mode 100644 src/transformers/configuration_reformer.py create mode 100644 w! diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index adcf7a01123f..6d40812d4c3a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -47,6 +47,7 @@ from .configuration_marian import MarianConfig from .configuration_mmbt import MMBTConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig +from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py new file mode 100644 index 000000000000..91eb400efba3 --- /dev/null +++ b/src/transformers/configuration_reformer.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Reformer model configuration """ + + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class ReformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.ReformerModel`. + It is used to instantiate an Reformer model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the Reformer `bert-base-uncased `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + vocab_size (:obj:`int`, optional, defaults to 30522): + Vocabulary size of the Reformer model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`. + hidden_size (:obj:`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_buckets (:obj:`int`, optional, defaults to ): + TODO (PVP) + num_hashes (:obj:`int`, optional, defaults to ): + TODO (PVP) + chunk_length (:obj:`int`, optional, defaults to ): + TODO (PVP) + num_chunks_before (:obj:`int`, optional, defaults to ): + TODO (PVP) + num_chunks_after (:obj:`int`, optional, defaults to ): + TODO (PVP) + intermediate_size (:obj:`int`, optional, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + If string, "gelu", "relu", "swish" and "gelu_new" are supported. + hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, optional, defaults to 2): + The vocabulary size of the `token_type_ids` passed into :class:`~transformers.ReformerModel`. + initializer_range (:obj:`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + + Example:: + + from transformers import ReformerModel, ReformerConfig + + # Initializing a Reformer bert-base-uncased style configuration + configuration = ReformerConfig() + + # Initializing a model from the bert-base-uncased style configuration + model = ReformerModel(configuration) + + # Accessing the model configuration + configuration = model.config + + Attributes: + pretrained_config_archive_map (Dict[str, str]): + A dictionary containing all the available pre-trained checkpoints. + """ + pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=64, + num_hidden_layers=12, + num_attention_heads=2, + num_buckets=4, + num_hashes=2, + chunk_length=7, + num_chunks_before=1, + num_chunks_after=0, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + +# TODO: not added yet +# predict_mem_len=2048 +# predict_drop_len=256 +# n_parallel_heads=1 + +# TO DELETE LATER: + self.seed = 0 + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_hashes = num_hashes + self.num_buckets = num_buckets + self.chunk_length = chunk_length + self.num_chunks_after = num_chunks_after + self.num_chunks_before = num_chunks_before + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1ebc71080f52..abd41ff6a94d 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -30,8 +30,7 @@ # DELETE later import numpy -# ADD later -#from .configuration_reformer import ReformerConfig +from .configuration_reformer import ReformerConfig logger = logging.getLogger(__name__) @@ -170,17 +169,15 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs class LSHSelfAttention(nn.Module): def __init__(self, config): super().__init__() -# def __init__(self, n_heads=2, dim_qk=64, dim_v=64, share_qk='unused', -# causal=False, chunk_len=None, n_chunks_before=1, n_chunks_after=0, -# n_hashes=1, n_buckets=256, mode='train', -# predict_mem_len=2048, predict_drop_len=256, -# attention_dropout=0.0, n_parallel_heads=1, -# use_python_loop=False, use_reference_code=False -# ): - - self.hidden_states_input_size = config['input_size'] - self.num_attention_heads = config['num_attention_heads'] - self.hidden_size = config['hf_hidden_size'] + + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.hash_seed = config.seed + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.chunk_length = config.chunk_length + self.num_chunks_before = config.num_chunks_before + self.num_chunks_after = config.num_chunks_after self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -188,21 +185,14 @@ def __init__(self, config): if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config['hidden_size'], config['num_attention_heads']) + "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) - self.query_key = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) - self.value = nn.Linear(self.hidden_states_input_size, self.all_head_size, bias=False) - self.dense = nn.Linear(self.all_head_size, self.hidden_states_input_size, bias=False) + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) # add Dropout - self.hash_seed = config['seed'] - - self.num_hashes = config['num_hashes'] - self.num_buckets = config['num_buckets'] - self.query_key_chunk_len = config['query_key_chunk_len'] - self.num_chunks_before = config['num_chunks_before'] - self.num_chunks_after = config['num_chunks_after'] def forward( self, @@ -233,11 +223,11 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.query_key_chunk_len) - value_vectors = self._split_dim_by(value_vectors, -1, self.query_key_chunk_len) + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) # Optionally include adjacent chunks. - if self.query_key_chunk_len is None: + if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 key_vectors = self._len_and_dim_norm(query_key_vectors) @@ -673,7 +663,7 @@ class ReformerPreTrainedModel(PreTrainedModel): a simple interface for downloading and loading pretrained models. """ -# TODO: config_class = ReformerConfig + config_class = ReformerConfig pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_reformer base_model_prefix = "reformer" diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 66a0e5f6f706..1e567404606f 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -23,7 +23,7 @@ from trax.layers.research.efficient_attention_v2 import ( LSHSelfAttention as TraxLSHSelfAttention, ) -from transformers import LSHSelfAttention +from transformers import LSHSelfAttention, ReformerConfig from transformers import is_torch_available # noqa: F401 @@ -53,7 +53,7 @@ class TraxUtils(object): error message " missing" """ - def __init__(self, shape=(3, 32, 8)): + def __init__(self, shape): self._shape = shape def convert_to_jax_array(self, np_array): @@ -68,39 +68,29 @@ def get_input_signature(self, shape=None): def get_layer( self, - shape=None, - num_attention_heads=None, - hidden_size=None, - query_key_chunk_len=None, - num_chunks_before=None, - num_chunks_after=None, - num_hashes=None, - num_buckets=None, + config, use_reference_code=True, - attention_dropout=0.0, mode="train", - seed=None, - path_to_save_weights=None, + path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", **kwargs ): with trax_math.use_backend("jax"): - if shape is None: - shape = self._shape + hidden_size_per_head = config.hidden_size // config.num_attention_heads layer = TraxLSHSelfAttention( - n_heads=num_attention_heads, - d_qk=hidden_size, - d_v=hidden_size, + n_heads=config.num_attention_heads, + d_qk=hidden_size_per_head, + d_v=hidden_size_per_head, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + attention_dropout=config.attention_probs_dropout_prob, + hash_seed=config.seed, causal=False, - chunk_len=query_key_chunk_len, - n_chunks_before=num_chunks_before, - n_chunks_after=num_chunks_after, - n_hashes=num_hashes, - n_buckets=num_buckets, use_reference_code=use_reference_code, - attention_dropout=attention_dropout, mode=mode, - hash_seed=seed, path_to_save_weights=path_to_save_weights ) @@ -109,16 +99,13 @@ def get_layer( def forward_layer( self, np_input_data, - layer=None, + layer, input_signature=None, random_number_generator=None, ): with trax_math.use_backend("jax"): input_data = self.convert_to_jax_array(np_input_data) - if layer is None: - layer = self.get_layer() - if input_signature is None: input_signature = self.get_input_signature() @@ -137,70 +124,33 @@ def forward_layer( @require_torch class ReformerIntegrationTests(unittest.TestCase): - hidden_size = 32 - seq_len = 7 - - def _get_random_input(self, shape): - return np.random.rand(*shape) - - def _get_trax_utils(self, shape): - return TraxUtils(shape) + def _set_weights_in_torch(self, weights, torch_layer, hidden_size_per_head): + # set torch weights for 1-to-1 comparison + with torch.no_grad(): + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + np_dense = np.asarray(weights[2]) - def _create_config( - self, - input_size=8, - num_attention_heads=2, - num_hashes=2, - num_buckets=4, - num_chunks_before=1, - num_chunks_after=0, - seed=0, - path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights" - ): - return { - "input_size": self.hidden_size, - "num_attention_heads": num_attention_heads, - "hidden_size": self.hidden_size, - "hf_hidden_size": num_attention_heads * self.hidden_size, - "num_hashes": num_hashes, - "num_buckets": num_buckets, - "query_key_chunk_len": self.seq_len, - "num_chunks_before": num_chunks_before, - "num_chunks_after": num_chunks_after, - "seed": seed, - "path_to_save_weights": path_to_save_weights - } + torch_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + torch_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + torch_layer.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, hidden_size_per_head).contiguous().transpose(0, 1)) def test_lsh_hashing(self): - shape = (3, self.seq_len, self.hidden_size) # Batch x SeqLen x ModelDim - - config_dict = self._create_config() + config = ReformerConfig() - np_input = self._get_random_input(shape) + hidden_size_per_head = config.hidden_size // config.num_attention_heads - trax_utils = self._get_trax_utils(shape) + shape = (3, 7, hidden_size_per_head) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.rand(*shape) - trax_layer = trax_utils.get_layer(**config_dict) - trax_output, trax_weights, trax_state = trax_utils.forward_layer( - np_input, layer=trax_layer - ) # noqa: F841 + trax_utils = TraxUtils(shape) + trax_layer = trax_utils.get_layer(config) + trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) trax_torch_output = torch.tensor(np.asarray(trax_output)) - hf_layer = LSHSelfAttention(config_dict) - - # set torch weights for 1-to-1 comparison - with torch.no_grad(): - np_query_key = np.asarray(trax_weights[0]) - np_value = np.asarray(trax_weights[1]) - np_dense = np.asarray(trax_weights[2]) - - hf_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, self.hidden_size)) - hf_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, self.hidden_size)) - hf_layer.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, self.hidden_size).contiguous().transpose(0, 1)) - hf_input = torch.tensor(np_input, dtype=torch.float) + hf_layer = LSHSelfAttention(config) + self._set_weights_in_torch(trax_weights, hf_layer, hidden_size_per_head) hf_output = hf_layer(hf_input) - assert torch.allclose(hf_output, trax_torch_output, atol=1e-6) - - pass + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) diff --git a/w! b/w! new file mode 100644 index 000000000000..7fc895502ff8 --- /dev/null +++ b/w! @@ -0,0 +1,1146 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model. """ + + +import logging +import os + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from .activations import gelu, gelu_new, swish +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, prune_linear_layer + +# DELETE later +import numpy + +from .configuration_reformer import ReformerConfig + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} +# TODO: fill with pretrained model weights +# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", +# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", +# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", +# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", + + +def load_tf_weights_in_reformer(model, config, tf_checkpoint_path): + return NotImplementedError('check code below to implement') + """ Load tf checkpoints in a pytorch model. + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} + + +ReformerLayerNorm = torch.nn.LayerNorm + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class LSHSelfAttention(nn.Module): + def __init__(self, config): + super().__init__(config) +# def __init__(self, n_heads=2, dim_qk=64, dim_v=64, share_qk='unused', +# causal=False, chunk_len=None, n_chunks_before=1, n_chunks_after=0, +# n_hashes=1, n_buckets=256, mode='train', +# predict_mem_len=2048, predict_drop_len=256, +# attention_dropout=0.0, n_parallel_heads=1, +# use_python_loop=False, use_reference_code=False +# ): + + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) + + # add Dropout + self.hash_seed = config.seed + + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.chunk_length = config.chunk_length + self.num_chunks_before = config.num_chunks_before + self.num_chunks_after = config.num_chunks_after + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + ): + + # get SeqLen and BatchSize + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) + value_vectors = self._transpose_for_scores(mixed_value_vectors) + + assert query_key_vectors.shape[-1] == self.attention_head_size + assert value_vectors.shape[-1] == self.attention_head_size + + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors) + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length + + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) + + query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) + value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) + + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) + + # Optionally include adjacent chunks. + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + + key_vectors = self._len_and_dim_norm(query_key_vectors) + + out_vectors, logits = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker) + + if self.num_hashes > 1: + out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) + logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + + out_vectors = self._transpose_for_output(out_vectors) + out = self.dense(out_vectors) + # TODO: apply broadcasted dropout + + return out + + def _hash_vectors(self, vectors): + batch_size = vectors.shape[0] + + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.num_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for num_bucket in self.num_buckets: + assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) + rotation_size += num_bucket + num_buckets *= num_bucket + + rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + + # TODO: delete later when integration tests are ok + # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 +# random_rotations = torch.randn(rotations_shape, device=vectors.device) + numpy.random.seed(self.hash_seed) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) + # rotated_vectors has dim: + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + # TODO: IMPORTANT: At the moment we use the same random rotation over all batches + # and heads -> is that bad? It seems like in original reformer a different random + # rotation is used over heads and batches + rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for num_bucket in self.num_buckets: + rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] + cur_sum += num_bucket // 2 + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + + if buckets is None: + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + + cur_product *= num_bucket + + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(self.num_hashes, device=vectors.device) + offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + + # repeat same values for Batch_size and Num_Attn_Heads + offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) + offset_buckets = self._merge_by_middle_dim(buckets + offsets) + + return offset_buckets + + def _get_ticker_and_undo_ticker(self, sequence_length, buckets): + batch_size = buckets.shape[0] + + # TODO: what is ticker? Is ticker something like indices?? Ask authors + ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) + + buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + + # Hash-based sort + sorted_ticker = torch.argsort(buckets_and_t, dim=-1) + undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_length) + return sorted_ticker, undo_sorted_ticker + + def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): + key_vectors = self._look_adjacent(key_vectors) + value_vectors = self._look_adjacent(value_vectors) + + # TODO: add masking here + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + dots = torch.exp(query_key_dots - logits) + + # TODO: add dropout here + + # attend values + out_vectors = torch.matmul(dots, value_vectors) + + # merge chunk length + logits = self._merge_by_middle_dim(logits).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors) + + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_ticker) + + return out_vectors, logits + + def _look_adjacent(self, vectors): + """ Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if self.num_chunks_before == 0 and self.num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-self.num_chunks_before, self.num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _len_and_dim_norm(self, vectors): + vectors = self._len_norm(vectors) + vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) + return vectors + + def _len_norm(self, x, epsilon=1e-6): + variance = torch.mean(x**2, -1, keepdim=True) + norm_x = x / torch.sqrt(variance + epsilon) + return norm_x + + def _gather_by_expansion(self, vectors, idxs): + expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + vectors = vectors.repeat(1, 1, self.num_hashes, 1) + return torch.gather(vectors, 2, expanded_idxs) + + def _transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def _transpose_for_output(self, x): + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) + + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + + def _merge_by_middle_dim(self, vectors): + batch_size = vectors.shape[0] + new_dim_shape = (batch_size, self.num_attention_heads, -1) + + if len(vectors.shape) == 5: + return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) + elif len(vectors.shape) == 4: + return torch.reshape(vectors, new_dim_shape) + else: + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ReformerAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LSHSelfAttention(config) + self.output = ReformerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_outputs = self.self( + hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ReformerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ReformerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ReformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = ReformerAttention(config) + self.is_decoder = config.is_decoder + if self.is_decoder: + self.crossattention = ReformerAttention(config) + self.intermediate = ReformerIntermediate(config) + self.output = ReformerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask + ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class ReformerPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ReformerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class ReformerLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = ReformerPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class ReformerOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class ReformerPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ReformerLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class ReformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = ReformerConfig + pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + load_tf_weights = load_tf_weights_in_reformer + base_model_prefix = "reformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, ReformerLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +REFORMER_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +REFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.ReformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. +""" + + +@add_start_docstrings( + "The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.", + REFORMER_START_DOCSTRING, +) +class ReformerModel(ReformerPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + :obj:`is_decoder` argument of the configuration set to :obj:`True`; an + :obj:`encoder_hidden_states` is expected as an input to the forward pass. + + .. _`Attention is all you need`: + https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + self.pooler = ReformerPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during pre-training. + + This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import ReformerModel, ReformerTokenizer + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerModel.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + causal_mask = causal_mask.to( + attention_mask.dtype + ) # causal and attention masks must have same type with pytorch version < 1.3 + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + elif encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( + encoder_hidden_shape, encoder_attention_mask.shape + ) + ) + + encoder_extended_attention_mask = encoder_extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +@add_start_docstrings( + """Reformer Model with two heads on top as done during the pre-training: a `masked language modeling` head and + a `next sentence prediction (classification)` head. """, + REFORMER_START_DOCSTRING, +) +class ReformerForPreTraining(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.cls = ReformerPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + masked_lm_labels=None, + next_sentence_label=None, + ): + r""" + masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False + continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + + Examples:: + + from transformers import ReformerTokenizer, ReformerForPreTraining + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerForPreTraining.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + prediction_scores, seq_relationship_scores = outputs[:2] + + """ + + outputs = self.reformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + outputs = (prediction_scores, seq_relationship_score,) + outputs[ + 2: + ] # add hidden states and attention if they are here + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + outputs = (total_loss,) + outputs + + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Reformer Model with a `next sentence prediction (classification)` head on top. """, REFORMER_START_DOCSTRING, +) +class ReformerForNextSentencePrediction(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.cls = ReformerOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + ): + r""" + next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): + Next sequence prediction (classification) loss. + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import ReformerTokenizer, ReformerForNextSentencePrediction + import torch + + tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') + model = ReformerForNextSentencePrediction.from_pretrained('reformer-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + seq_relationship_scores = outputs[0] + + """ + + outputs = self.reformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + pooled_output = outputs[1] + + seq_relationship_score = self.cls(pooled_output) + + outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + outputs = (next_sentence_loss,) + outputs + + return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) From 8e02fe73c145ee5d690576f72c4e5300e8fa9f8d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 2 Apr 2020 21:17:58 +0200 Subject: [PATCH 017/162] delete bogus file --- w! | 1146 ------------------------------------------------------------ 1 file changed, 1146 deletions(-) delete mode 100644 w! diff --git a/w! b/w! deleted file mode 100644 index 7fc895502ff8..000000000000 --- a/w! +++ /dev/null @@ -1,1146 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch REFORMER model. """ - - -import logging -import os - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -from .activations import gelu, gelu_new, swish -from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, prune_linear_layer - -# DELETE later -import numpy - -from .configuration_reformer import ReformerConfig - - -logger = logging.getLogger(__name__) - -REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} -# TODO: fill with pretrained model weights -# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", -# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", -# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", -# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", - - -def load_tf_weights_in_reformer(model, config, tf_checkpoint_path): - return NotImplementedError('check code below to implement') - """ Load tf checkpoints in a pytorch model. - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info("Skipping {}".format("/".join(name))) - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - -def mish(x): - return x * torch.tanh(nn.functional.softplus(x)) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} - - -ReformerLayerNorm = torch.nn.LayerNorm - - -class ReformerEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class LSHSelfAttention(nn.Module): - def __init__(self, config): - super().__init__(config) -# def __init__(self, n_heads=2, dim_qk=64, dim_v=64, share_qk='unused', -# causal=False, chunk_len=None, n_chunks_before=1, n_chunks_after=0, -# n_hashes=1, n_buckets=256, mode='train', -# predict_mem_len=2048, predict_drop_len=256, -# attention_dropout=0.0, n_parallel_heads=1, -# use_python_loop=False, use_reference_code=False -# ): - - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - - self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) - - # add Dropout - self.hash_seed = config.seed - - self.num_hashes = config.num_hashes - self.num_buckets = config.num_buckets - self.chunk_length = config.chunk_length - self.num_chunks_before = config.num_chunks_before - self.num_chunks_after = config.num_chunks_after - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - ): - - # get SeqLen and BatchSize - sequence_length = hidden_states.shape[1] - batch_size = hidden_states.shape[0] - - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) - - query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) - value_vectors = self._transpose_for_scores(mixed_value_vectors) - - assert query_key_vectors.shape[-1] == self.attention_head_size - assert value_vectors.shape[-1] == self.attention_head_size - - # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors) - assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - - query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) - value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) - - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) - - # Optionally include adjacent chunks. - if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - - key_vectors = self._len_and_dim_norm(query_key_vectors) - - out_vectors, logits = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker) - - if self.num_hashes > 1: - out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) - logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) - - probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) - out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) - - out_vectors = self._transpose_for_output(out_vectors) - out = self.dense(out_vectors) - # TODO: apply broadcasted dropout - - return out - - def _hash_vectors(self, vectors): - batch_size = vectors.shape[0] - - # See https://arxiv.org/pdf/1509.02897.pdf - # We sample a different random rotation for each round of hashing to - # decrease the probability of hash misses. - if isinstance(self.num_buckets, int): - assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) - rotation_size = self.num_buckets - num_buckets = self.num_buckets - else: - # Factorize the hash if self.num_buckets is a list or tuple - rotation_size, num_buckets = 0, 1 - for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) - rotation_size += num_bucket - num_buckets *= num_bucket - - rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) - - # TODO: delete later when integration tests are ok - # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 -# random_rotations = torch.randn(rotations_shape, device=vectors.device) - numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) - # rotated_vectors has dim: - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 - # TODO: IMPORTANT: At the moment we use the same random rotation over all batches - # and heads -> is that bad? It seems like in original reformer a different random - # rotation is used over heads and batches - rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) - - if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - # Get the buckets for them and combine. - buckets, cur_sum, cur_product = None, 0, 1 - for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] - cur_sum += num_bucket // 2 - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - - if buckets is None: - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) - - cur_product *= num_bucket - - # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). - # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. - offsets = torch.arange(self.num_hashes, device=vectors.device) - offsets = torch.reshape(offsets * num_buckets, (-1, 1)) - - # repeat same values for Batch_size and Num_Attn_Heads - offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets) - - return offset_buckets - - def _get_ticker_and_undo_ticker(self, sequence_length, buckets): - batch_size = buckets.shape[0] - - # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) - ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) - - buckets_and_t = sequence_length * buckets + (ticker % sequence_length) - - # Hash-based sort - sorted_ticker = torch.argsort(buckets_and_t, dim=-1) - undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) - - sorted_ticker = (sorted_ticker % sequence_length) - return sorted_ticker, undo_sorted_ticker - - def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): - key_vectors = self._look_adjacent(key_vectors) - value_vectors = self._look_adjacent(value_vectors) - - # TODO: add masking here - - # get logits and dots - query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) - dots = torch.exp(query_key_dots - logits) - - # TODO: add dropout here - - # attend values - out_vectors = torch.matmul(dots, value_vectors) - - # merge chunk length - logits = self._merge_by_middle_dim(logits).squeeze(-1) - out_vectors = self._merge_by_middle_dim(out_vectors) - - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_ticker) - - return out_vectors, logits - - def _look_adjacent(self, vectors): - """ Used to implement attention between consecutive chunks. - - Args: - vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if self.num_chunks_before == 0 and self.num_chunks_after == 0: - return vectors - - slices = [] - for i in range(-self.num_chunks_before, self.num_chunks_after + 1): - if i == 0: - slices.append(vectors) - else: - slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) - return torch.cat(slices, dim=3) - - def _len_and_dim_norm(self, vectors): - vectors = self._len_norm(vectors) - vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) - return vectors - - def _len_norm(self, x, epsilon=1e-6): - variance = torch.mean(x**2, -1, keepdim=True) - norm_x = x / torch.sqrt(variance + epsilon) - return norm_x - - def _gather_by_expansion(self, vectors, idxs): - expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - vectors = vectors.repeat(1, 1, self.num_hashes, 1) - return torch.gather(vectors, 2, expanded_idxs) - - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def _transpose_for_output(self, x): - x = x.permute(0, 2, 1, 3) - return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) - - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): - batch_size = vectors.shape[0] - split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) - - if len(vectors.shape) == 4: - return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) - elif len(vectors.shape) == 3: - return torch.reshape(vectors, split_dim_shape) - else: - raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - - def _merge_by_middle_dim(self, vectors): - batch_size = vectors.shape[0] - new_dim_shape = (batch_size, self.num_attention_heads, -1) - - if len(vectors.shape) == 5: - return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) - elif len(vectors.shape) == 4: - return torch.reshape(vectors, new_dim_shape) - else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) - - -class ReformerSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class ReformerAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.self = LSHSelfAttention(config) - self.output = ReformerSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - self_outputs = self.self( - hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -class ReformerIntermediate(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class ReformerOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class ReformerLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = ReformerAttention(config) - self.is_decoder = config.is_decoder - if self.is_decoder: - self.crossattention = ReformerAttention(config) - self.intermediate = ReformerIntermediate(config) - self.output = ReformerOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - if self.is_decoder and encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights - - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + outputs - return outputs - - -class ReformerEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - all_hidden_states = () - all_attentions = () - for i, layer_module in enumerate(self.layer): - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask - ) - hidden_states = layer_outputs[0] - - if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = (hidden_states,) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - -class ReformerPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class ReformerPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class ReformerLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = ReformerPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class ReformerOnlyNSPHead(nn.Module): - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class ReformerPreTrainingHeads(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = ReformerLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class ReformerPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - config_class = ReformerConfig - pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP - load_tf_weights = load_tf_weights_in_reformer - base_model_prefix = "reformer" - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, ReformerLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -REFORMER_START_DOCSTRING = r""" - This model is a PyTorch `torch.nn.Module `_ sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. -""" - -REFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`transformers.ReformerTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.encode_plus` for details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on padding token indices. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - - `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Segment token indices to indicate first and second portions of the inputs. - Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` - corresponds to a `sentence B` token - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range ``[0, config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): - Mask to nullify selected heads of the self-attention modules. - Mask values selected in ``[0, 1]``: - :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - if the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask - is used in the cross-attention if the model is configured as a decoder. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. -""" - - -@add_start_docstrings( - "The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.", - REFORMER_START_DOCSTRING, -) -class ReformerModel(ReformerPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - :obj:`is_decoder` argument of the configuration set to :obj:`True`; an - :obj:`encoder_hidden_states` is expected as an input to the forward pass. - - .. _`Attention is all you need`: - https://arxiv.org/abs/1706.03762 - - """ - - def __init__(self, config): - super().__init__(config) - self.config = config - - self.embeddings = ReformerEmbeddings(config) - self.encoder = ReformerEncoder(config) - self.pooler = ReformerPooler(config) - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - r""" - Return: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) - further processed by a Linear layer and a Tanh activation function. The Linear - layer weights are trained from the next sentence prediction (classification) - objective during pre-training. - - This output is usually *not* a good summary - of the semantic content of the input, you're often better with averaging or pooling - the sequence of hidden-states for the whole input sequence. - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - from transformers import ReformerModel, ReformerTokenizer - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerModel.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple - - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - causal_mask = causal_mask.to( - attention_mask.dtype - ) # causal and attention masks must have same type with pytorch version < 1.3 - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - elif encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( - encoder_hidden_shape, encoder_attention_mask.shape - ) - ) - - encoder_extended_attention_mask = encoder_extended_attention_mask.to( - dtype=next(self.parameters()).dtype - ) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers - - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) - - outputs = (sequence_output, pooled_output,) + encoder_outputs[ - 1: - ] # add hidden_states and attentions if they are here - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) - - -@add_start_docstrings( - """Reformer Model with two heads on top as done during the pre-training: a `masked language modeling` head and - a `next sentence prediction (classification)` head. """, - REFORMER_START_DOCSTRING, -) -class ReformerForPreTraining(ReformerPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.reformer = ReformerModel(config) - self.cls = ReformerPreTrainingHeads(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - masked_lm_labels=None, - next_sentence_label=None, - ): - r""" - masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): - Labels for computing the masked language modeling loss. - Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels - in ``[0, ..., config.vocab_size]`` - next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) - Indices should be in ``[0, 1]``. - ``0`` indicates sequence B is a continuation of sequence A, - ``1`` indicates sequence B is a random sequence. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: - Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. - prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False - continuation before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - - Examples:: - - from transformers import ReformerTokenizer, ReformerForPreTraining - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerForPreTraining.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - prediction_scores, seq_relationship_scores = outputs[:2] - - """ - - outputs = self.reformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - outputs = (prediction_scores, seq_relationship_score,) + outputs[ - 2: - ] # add hidden states and attention if they are here - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - outputs = (total_loss,) + outputs - - return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) - - -@add_start_docstrings( - """Reformer Model with a `next sentence prediction (classification)` head on top. """, REFORMER_START_DOCSTRING, -) -class ReformerForNextSentencePrediction(ReformerPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.reformer = ReformerModel(config) - self.cls = ReformerOnlyNSPHead(config) - - self.init_weights() - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - next_sentence_label=None, - ): - r""" - next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) - Indices should be in ``[0, 1]``. - ``0`` indicates sequence B is a continuation of sequence A, - ``1`` indicates sequence B is a random sequence. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): - Next sequence prediction (classification) loss. - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - from transformers import ReformerTokenizer, ReformerForNextSentencePrediction - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerForNextSentencePrediction.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - seq_relationship_scores = outputs[0] - - """ - - outputs = self.reformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - ) - - pooled_output = outputs[1] - - seq_relationship_score = self.cls(pooled_output) - - outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - outputs = (next_sentence_loss,) + outputs - - return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) From 2af8377add98d7b4f8f084cb184ce03fbcbb6044 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Apr 2020 10:41:56 +0200 Subject: [PATCH 018/162] split reformer attention layer into two layers --- src/transformers/__init__.py | 2 +- src/transformers/configuration_reformer.py | 2 +- src/transformers/modeling_reformer.py | 71 ++++++++-------------- tests/test_modeling_reformer.py | 15 +++-- 4 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6d40812d4c3a..2984056eb4b1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -321,7 +321,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, from .modeling_reformer import ( - LSHSelfAttention + ReformerAttention ) # Optimization diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 91eb400efba3..e4ab9e967e28 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -110,7 +110,7 @@ def __init__( num_chunks_after=0, intermediate_size=3072, hidden_act="gelu", - hidden_dropout_prob=0.1, + hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index abd41ff6a94d..a3b09c18396b 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -25,7 +25,7 @@ from .activations import gelu, gelu_new, swish from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel, prune_linear_layer +from .modeling_utils import PreTrainedModel # DELETE later import numpy @@ -178,6 +178,7 @@ def __init__(self, config): self.chunk_length = config.chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after + self.output_attentions = config.output_attentions self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -190,9 +191,8 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dense = nn.Linear(self.all_head_size, self.hidden_size, bias=False) - # add Dropout + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def forward( self, @@ -232,7 +232,7 @@ def forward( key_vectors = self._len_and_dim_norm(query_key_vectors) - out_vectors, logits = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker) + out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker, attention_mask, head_mask) if self.num_hashes > 1: out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) @@ -244,10 +244,8 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) out_vectors = self._transpose_for_output(out_vectors) - out = self.dense(out_vectors) - # TODO: apply broadcasted dropout - - return out + outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + return outputs def _hash_vectors(self, vectors): batch_size = vectors.shape[0] @@ -326,18 +324,27 @@ def _get_ticker_and_undo_ticker(self, sequence_length, buckets): sorted_ticker = (sorted_ticker % sequence_length) return sorted_ticker, undo_sorted_ticker - def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): + def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker, attention_mask, head_mask): key_vectors = self._look_adjacent(key_vectors) value_vectors = self._look_adjacent(value_vectors) - # TODO: add masking here - # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + if attention_mask is not None: + # Apply the attention mask. + # Attention mask is precomputed for all layers in forward() function + query_key_dots = query_key_dots + attention_mask + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) dots = torch.exp(query_key_dots - logits) - # TODO: add dropout here + # dropout + dots = self.dropout(dots) + + # Mask heads if we want to + if head_mask is not None: + dots = dots * head_mask # attend values out_vectors = torch.matmul(dots, value_vectors) @@ -350,7 +357,7 @@ def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker): out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) logits = torch.gather(logits, 2, undo_ticker) - return out_vectors, logits + return out_vectors, logits, dots def _look_adjacent(self, vectors): """ Used to implement attention between consecutive chunks. @@ -422,57 +429,31 @@ def _merge_by_middle_dim(self, vectors): class ReformerSelfOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): + + # TODO (PVP): Add residual here hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class ReformerAttention(nn.Module): def __init__(self, config): super().__init__() - self.self = LSHSelfAttention(config) + self.self_attention = LSHSelfAttention(config) self.output = ReformerSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, ): - self_outputs = self.self( - hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask + self_outputs = self.self_attention( + hidden_states, attention_mask, head_mask ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 1e567404606f..fbe40d2f5b2b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -23,7 +23,7 @@ from trax.layers.research.efficient_attention_v2 import ( LSHSelfAttention as TraxLSHSelfAttention, ) -from transformers import LSHSelfAttention, ReformerConfig +from transformers import ReformerAttention, ReformerConfig from transformers import is_torch_available # noqa: F401 @@ -87,6 +87,7 @@ def get_layer( n_hashes=config.num_hashes, n_buckets=config.num_buckets, attention_dropout=config.attention_probs_dropout_prob, + output_dropout=config.hidden_dropout_prob, hash_seed=config.seed, causal=False, use_reference_code=use_reference_code, @@ -131,9 +132,11 @@ def _set_weights_in_torch(self, weights, torch_layer, hidden_size_per_head): np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) - torch_layer.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) - torch_layer.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) - torch_layer.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, hidden_size_per_head).contiguous().transpose(0, 1)) + torch_layer.self_attention.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + + torch_layer.self_attention.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + + torch_layer.output.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, hidden_size_per_head).contiguous().transpose(0, 1)) def test_lsh_hashing(self): config = ReformerConfig() @@ -149,8 +152,8 @@ def test_lsh_hashing(self): trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) - hf_layer = LSHSelfAttention(config) + hf_layer = ReformerAttention(config) self._set_weights_in_torch(trax_weights, hf_layer, hidden_size_per_head) - hf_output = hf_layer(hf_input) + hf_output = hf_layer(hf_input)[0] self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) From 6fe94783cb6efc36dfd8f404ed8f170b47a71f49 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 09:30:59 +0200 Subject: [PATCH 019/162] fix merge conflicts --- src/transformers/activations.py | 4 + src/transformers/configuration_reformer.py | 7 +- tests/test_modeling_reformer.py | 86 +++++++++++++++++++++- 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 60e8e8d05e38..9fa175cfffb6 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -34,6 +34,10 @@ def gelu_new(x): else: gelu = F.gelu + +gelu = _gelu_python + + ACT2FN = { "relu": F.relu, "swish": swish, diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index e4ab9e967e28..ad2d5ee74a97 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -126,8 +126,13 @@ def __init__( # predict_drop_len=256 # n_parallel_heads=1 -# TO DELETE LATER: +# TO CHANGE LATER: self.seed = 0 + self.num_attention_chunks = 1 + self.ff_chunk_size = 0 + self.d_model = 32 + self.d_ff = 64 +# self.ff_activation = # GELU self.vocab_size = vocab_size self.hidden_size = hidden_size diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index fbe40d2f5b2b..8b2d5f7dc241 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -23,6 +23,9 @@ from trax.layers.research.efficient_attention_v2 import ( LSHSelfAttention as TraxLSHSelfAttention, ) +from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock +from trax import layers as tl + from transformers import ReformerAttention, ReformerConfig @@ -121,6 +124,73 @@ def forward_layer( return output, weights, state + def get_block( + self, + config, + use_reference_code=True, + share_qk=True, + ff_use_sru=0, + mode="train", + causal=False, + path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", + ): + + with trax_math.use_backend("jax"): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + list_of_layers = TraxLSHAttentionBlock( + d_model=config.d_model, + d_ff=config.d_ff, + d_attention_key=hidden_size_per_head, + d_attention_value=hidden_size_per_head, + n_heads=config.num_attention_heads, + n_attention_chunks=config.num_attention_chunks, + attention_type=tl.LSHSelfAttention, + dropout=config.hidden_dropout_prob, + share_qk=share_qk, + ff_activation=tl.Gelu, + ff_use_sru=ff_use_sru, + ff_chunk_size=config.ff_chunk_size, + mode=mode, + causal=causal, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + use_reference_code=use_reference_code, + hash_seed=config.seed, + path_to_save_weights=path_to_save_weights + ) + layer = tl.Serial(tl.ReversibleSerial([list_of_layers])) + + return layer + + def forward_block( + self, + np_input_data, + block, + input_signature=None, + random_number_generator=None, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + input_data = (input_data,) * 2 + + if input_signature is None: + input_signature = self.get_input_signature() + input_signature = (input_signature, input_signature) + + weights, state = block.init(input_signature) + + if random_number_generator is None: + random_number_generator = block.new_rngs(1)[0] + + output = block( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + @require_torch class ReformerIntegrationTests(unittest.TestCase): @@ -138,9 +208,8 @@ def _set_weights_in_torch(self, weights, torch_layer, hidden_size_per_head): torch_layer.output.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, hidden_size_per_head).contiguous().transpose(0, 1)) - def test_lsh_hashing(self): + def test_lsh_hashing_layer(self): config = ReformerConfig() - hidden_size_per_head = config.hidden_size // config.num_attention_heads shape = (3, 7, hidden_size_per_head) # Batch x SeqLen x ModelDimPerHead @@ -157,3 +226,16 @@ def test_lsh_hashing(self): hf_output = hf_layer(hf_input)[0] self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) + + def test_lsh_block(self): + config = ReformerConfig() + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + shape = (3, 7, hidden_size_per_head) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.rand(*shape) + + trax_utils = TraxUtils(shape) + trax_block = trax_utils.get_block(config) + trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) + + pass From 9e6e1af765fd2babbd45484529af87ad4c5ac6fb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Apr 2020 18:34:19 +0200 Subject: [PATCH 020/162] save intermediate step --- src/transformers/__init__.py | 3 +- src/transformers/configuration_reformer.py | 8 +-- src/transformers/modeling_reformer.py | 34 +++++----- tests/test_modeling_reformer.py | 76 ++++++++++++++++++---- 4 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2984056eb4b1..b16a760816de 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -321,7 +321,8 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, from .modeling_reformer import ( - ReformerAttention + ReformerAttention, + ReformerLayer ) # Optimization diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ad2d5ee74a97..ce6fa26825ad 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -100,7 +100,7 @@ class ReformerConfig(PretrainedConfig): def __init__( self, vocab_size=30522, - hidden_size=64, + hidden_size=32, num_hidden_layers=12, num_attention_heads=2, num_buckets=4, @@ -108,7 +108,7 @@ def __init__( chunk_length=7, num_chunks_before=1, num_chunks_after=0, - intermediate_size=3072, + intermediate_size=64, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, @@ -130,8 +130,8 @@ def __init__( self.seed = 0 self.num_attention_chunks = 1 self.ff_chunk_size = 0 - self.d_model = 32 - self.d_ff = 64 + self.d_model = hidden_size + self.d_ff = intermediate_size # self.ff_activation = # GELU self.vocab_size = vocab_size diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index a3b09c18396b..9cc109a858c0 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -433,16 +433,18 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): - - # TODO (PVP): Add residual here hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - return hidden_states + output = (hidden_states + input_tensor) + # Uncomment here if testing only LSHSelfAttentionLayer + output = hidden_states # Uncomment for layer only test + return output class ReformerAttention(nn.Module): def __init__(self, config): super().__init__() + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.self_attention = LSHSelfAttention(config) self.output = ReformerSelfOutput(config) @@ -452,8 +454,11 @@ def forward( attention_mask=None, head_mask=None, ): + norm_hidden_states = self.layer_norm(hidden_states) + # Uncomment here if testing only LSHSelfAttentionLayer + norm_hidden_states = hidden_states self_outputs = self.self_attention( - hidden_states, attention_mask, head_mask + norm_hidden_states, attention_mask, head_mask ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -464,6 +469,7 @@ class ReformerIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: @@ -471,6 +477,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -479,23 +486,20 @@ class ReformerOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states + output = (hidden_states + input_tensor) + return output class ReformerLayer(nn.Module): def __init__(self, config): super().__init__() self.attention = ReformerAttention(config) - self.is_decoder = config.is_decoder - if self.is_decoder: - self.crossattention = ReformerAttention(config) + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.intermediate = ReformerIntermediate(config) self.output = ReformerOutput(config) @@ -504,20 +508,12 @@ def forward( hidden_states, attention_mask=None, head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, ): self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - if self.is_decoder and encoder_hidden_states is not None: - cross_attention_outputs = self.crossattention( - attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights - + norm_outputs = self.layer_norm(outputs) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) outputs = (layer_output,) + outputs diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 8b2d5f7dc241..c8199df02037 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -26,7 +26,7 @@ from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock from trax import layers as tl -from transformers import ReformerAttention, ReformerConfig +from transformers import ReformerAttention, ReformerLayer, ReformerConfig from transformers import is_torch_available # noqa: F401 @@ -195,24 +195,69 @@ def forward_block( @require_torch class ReformerIntegrationTests(unittest.TestCase): - def _set_weights_in_torch(self, weights, torch_layer, hidden_size_per_head): + def _set_param(self, torch_layer, weight, bias=None): + import ipdb + ipdb.set_trace() + + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) + torch_layer.weight = torch.nn.Parameter(weight) + if bias is not None: + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) + torch_layer.bias = torch.nn.Parameter(bias) + + def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison with torch.no_grad(): np_query_key = np.asarray(weights[0]) np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) - torch_layer.self_attention.query_key.weight = torch.nn.Parameter(torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.self_attention.value.weight, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.output.dense.weight, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) - torch_layer.self_attention.value.weight = torch.nn.Parameter(torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size_per_head)) + def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): + weights = weights[0] - torch_layer.output.dense.weight = torch.nn.Parameter(torch.tensor(np_dense).view(-1, hidden_size_per_head).contiguous().transpose(0, 1)) + # layernorm 1 + layer_norm_1 = weights[0][0][0] + layer_norm_1_weight = np.asarray(layer_norm_1[0]) + layer_norm_1_bias = np.asarray(layer_norm_1[1]) + with torch.no_grad(): + torch_layer.attention.layer_norm.weight = torch.nn.Parameter(torch.tensor(layer_norm_1_weight)) + torch_layer.attention.layer_norm.bias = torch.nn.Parameter(torch.tensor(layer_norm_1_bias)) + # lsh weights + output + lsh_weights = weights[0][1] + self._set_layer_weights_in_torch(lsh_weights, torch_layer.attention, hidden_size) + + # intermediate weighs + intermediate_weights = weights[2][0][2][2] + + # layernorm 2 + layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) + layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) + with torch.no_grad(): + torch_layer.layer_norm.weight = torch.nn.Parameter(torch.tensor(layer_norm_2_weight)) + torch_layer.layer_norm.bias = torch.nn.Parameter(torch.tensor(layer_norm_2_bias)) - def test_lsh_hashing_layer(self): - config = ReformerConfig() - hidden_size_per_head = config.hidden_size // config.num_attention_heads + # intermediate dense + inter_dense_weight = np.asarray(intermediate_weights[1][0]) + inter_dense_bias = np.asarray(intermediate_weights[1][1]) + with torch.no_grad(): + torch_layer.intermediate.dense.weight = torch.nn.Parameter(torch.tensor(inter_dense_weight).transpose(0, 1).contiguous()) + torch_layer.intermediate.dense.bias = torch.nn.Parameter(torch.tensor(inter_dense_bias)) + # intermediate out + out_dense_weight = np.asarray(intermediate_weights[4][0]) + out_dense_bias = np.asarray(intermediate_weights[4][1]) + with torch.no_grad(): + torch_layer.output.dense.weight = torch.nn.Parameter(torch.tensor(out_dense_weight).transpose(0, 1).contiguous()) + torch_layer.output.dense.bias = torch.nn.Parameter(torch.tensor(out_dense_bias)) - shape = (3, 7, hidden_size_per_head) # Batch x SeqLen x ModelDimPerHead + def test_lsh_layer(self): + # Remove residual connection in ReformerSelfOutput to test this layer only + # Remove layer norm in ReformerAttention to test this layer only + config = ReformerConfig() + shape = (3, 7, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) @@ -222,20 +267,27 @@ def test_lsh_hashing_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) - self._set_weights_in_torch(trax_weights, hf_layer, hidden_size_per_head) + self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) hf_output = hf_layer(hf_input)[0] self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) def test_lsh_block(self): config = ReformerConfig() - hidden_size_per_head = config.hidden_size // config.num_attention_heads - shape = (3, 7, hidden_size_per_head) # Batch x SeqLen x ModelDimPerHead + shape = (3, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) trax_block = trax_utils.get_block(config) trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) + trax_torch_output = torch.tensor(np.asarray(trax_output)) + + hf_input = torch.tensor(np_input, dtype=torch.float) + hf_block = ReformerLayer(config) + self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) + + import ipdb + ipdb.set_trace() pass From 18550742df9b3d005730811db70f1dfa016c3765 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Apr 2020 19:44:32 +0200 Subject: [PATCH 021/162] save intermediate step --- tests/test_modeling_reformer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c8199df02037..e06e51c017b1 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -198,16 +198,15 @@ class ReformerIntegrationTests(unittest.TestCase): def _set_param(self, torch_layer, weight, bias=None): import ipdb ipdb.set_trace() - - assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) - torch_layer.weight = torch.nn.Parameter(weight) - if bias is not None: - assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) - torch_layer.bias = torch.nn.Parameter(bias) + with torch.no_grad(): + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) + torch_layer.weight = torch.nn.Parameter(weight) + if bias is not None: + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) + torch_layer.bias = torch.nn.Parameter(bias) def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison - with torch.no_grad(): np_query_key = np.asarray(weights[0]) np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) From 1a4e61a9b80b6e768f3588c6dbcba559c1ab6dc6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Apr 2020 19:54:58 +0200 Subject: [PATCH 022/162] make test work --- tests/test_modeling_reformer.py | 38 ++++++++++++++------------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index e06e51c017b1..71cecc633155 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -196,8 +196,6 @@ def forward_block( class ReformerIntegrationTests(unittest.TestCase): def _set_param(self, torch_layer, weight, bias=None): - import ipdb - ipdb.set_trace() with torch.no_grad(): assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) torch_layer.weight = torch.nn.Parameter(weight) @@ -207,13 +205,13 @@ def _set_param(self, torch_layer, weight, bias=None): def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison - np_query_key = np.asarray(weights[0]) - np_value = np.asarray(weights[1]) - np_dense = np.asarray(weights[2]) + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + np_dense = np.asarray(weights[2]) - self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.self_attention.value.weight, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.output.dense.weight, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): weights = weights[0] @@ -222,9 +220,8 @@ def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): layer_norm_1 = weights[0][0][0] layer_norm_1_weight = np.asarray(layer_norm_1[0]) layer_norm_1_bias = np.asarray(layer_norm_1[1]) - with torch.no_grad(): - torch_layer.attention.layer_norm.weight = torch.nn.Parameter(torch.tensor(layer_norm_1_weight)) - torch_layer.attention.layer_norm.bias = torch.nn.Parameter(torch.tensor(layer_norm_1_bias)) + self._set_param(torch_layer.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) + # lsh weights + output lsh_weights = weights[0][1] self._set_layer_weights_in_torch(lsh_weights, torch_layer.attention, hidden_size) @@ -235,22 +232,17 @@ def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - with torch.no_grad(): - torch_layer.layer_norm.weight = torch.nn.Parameter(torch.tensor(layer_norm_2_weight)) - torch_layer.layer_norm.bias = torch.nn.Parameter(torch.tensor(layer_norm_2_bias)) + self._set_param(torch_layer.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - with torch.no_grad(): - torch_layer.intermediate.dense.weight = torch.nn.Parameter(torch.tensor(inter_dense_weight).transpose(0, 1).contiguous()) - torch_layer.intermediate.dense.bias = torch.nn.Parameter(torch.tensor(inter_dense_bias)) + self._set_param(torch_layer.intermediate.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) + # intermediate out out_dense_weight = np.asarray(intermediate_weights[4][0]) out_dense_bias = np.asarray(intermediate_weights[4][1]) - with torch.no_grad(): - torch_layer.output.dense.weight = torch.nn.Parameter(torch.tensor(out_dense_weight).transpose(0, 1).contiguous()) - torch_layer.output.dense.bias = torch.nn.Parameter(torch.tensor(out_dense_bias)) + self._set_param(torch_layer.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) def test_lsh_layer(self): # Remove residual connection in ReformerSelfOutput to test this layer only @@ -262,13 +254,13 @@ def test_lsh_layer(self): trax_utils = TraxUtils(shape) trax_layer = trax_utils.get_layer(config) trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) - trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) hf_output = hf_layer(hf_input)[0] + trax_torch_output = torch.tensor(np.asarray(trax_output)) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) def test_lsh_block(self): @@ -280,12 +272,14 @@ def test_lsh_block(self): trax_utils = TraxUtils(shape) trax_block = trax_utils.get_block(config) trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) - trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) hf_block = ReformerLayer(config) self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) + hf_output = hf_block(hf_input)[0] + # check which trax output is correct + trax_torch_output = torch.tensor(np.asarray(trax_output)) import ipdb ipdb.set_trace() From da6bfe4ae4f2f000972f49ed924d8f0c765fd5a3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Apr 2020 19:55:21 +0200 Subject: [PATCH 023/162] add complete reformer block layer --- src/transformers/modeling_reformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 9cc109a858c0..ea293c2305fd 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -437,7 +437,7 @@ def forward(self, hidden_states, input_tensor): hidden_states = self.dropout(hidden_states) output = (hidden_states + input_tensor) # Uncomment here if testing only LSHSelfAttentionLayer - output = hidden_states # Uncomment for layer only test +# output = hidden_states # Uncomment for layer only test return output @@ -456,7 +456,7 @@ def forward( ): norm_hidden_states = self.layer_norm(hidden_states) # Uncomment here if testing only LSHSelfAttentionLayer - norm_hidden_states = hidden_states +# norm_hidden_states = hidden_states self_outputs = self.self_attention( norm_hidden_states, attention_mask, head_mask ) @@ -513,8 +513,8 @@ def forward( attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - norm_outputs = self.layer_norm(outputs) - intermediate_output = self.intermediate(attention_output) + norm_attention_output = self.layer_norm(attention_output) + intermediate_output = self.intermediate(norm_attention_output) layer_output = self.output(intermediate_output, attention_output) outputs = (layer_output,) + outputs return outputs From 2825d241e8a09408796e5785d26bcdfe54a0a0c0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 4 Apr 2020 21:11:10 +0200 Subject: [PATCH 024/162] finish reformer layer --- src/transformers/modeling_reformer.py | 766 ++------------------------ tests/test_modeling_reformer.py | 81 +-- 2 files changed, 77 insertions(+), 770 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ea293c2305fd..382ba6010725 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -15,23 +15,16 @@ # limitations under the License. """PyTorch REFORMER model. """ - import logging -import os import torch from torch import nn -from torch.nn import CrossEntropyLoss from .activations import gelu, gelu_new, swish -from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel # DELETE later import numpy -from .configuration_reformer import ReformerConfig - logger = logging.getLogger(__name__) @@ -43,129 +36,14 @@ # "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", -def load_tf_weights_in_reformer(model, config, tf_checkpoint_path): - return NotImplementedError('check code below to implement') - """ Load tf checkpoints in a pytorch model. - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info("Skipping {}".format("/".join(name))) - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - def mish(x): return x * torch.tanh(nn.functional.softplus(x)) ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} - - ReformerLayerNorm = torch.nn.LayerNorm -class ReformerEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - class LSHSelfAttention(nn.Module): def __init__(self, config): super().__init__() @@ -435,9 +313,10 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) + # residual connection output = (hidden_states + input_tensor) # Uncomment here if testing only LSHSelfAttentionLayer -# output = hidden_states # Uncomment for layer only test +# output = hidden_states return output @@ -450,39 +329,45 @@ def __init__(self, config): def forward( self, + prev_attention_output, hidden_states, attention_mask=None, head_mask=None, ): norm_hidden_states = self.layer_norm(hidden_states) + # TODO: Remove comments later # Uncomment here if testing only LSHSelfAttentionLayer # norm_hidden_states = hidden_states - self_outputs = self.self_attention( + + self_attention_outputs = self.self_attention( norm_hidden_states, attention_mask, head_mask ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + attention_output = self.output(self_attention_outputs[0], prev_attention_output) + outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them return outputs -class ReformerIntermediate(nn.Module): +class ReformerFeedForwardDense(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] + self.act_fn = ACT2FN[config.hidden_act] else: - self.intermediate_act_fn = config.hidden_act + self.act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.act_fn(hidden_states) return hidden_states -class ReformerOutput(nn.Module): +class ReformerFeedForwardOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -495,621 +380,40 @@ def forward(self, hidden_states, input_tensor): return output -class ReformerLayer(nn.Module): +class ReformerFeedForward(nn.Module): def __init__(self, config): super().__init__() - self.attention = ReformerAttention(config) self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.intermediate = ReformerIntermediate(config) - self.output = ReformerOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - ): - self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + self.dense = ReformerFeedForwardDense(config) + self.output = ReformerFeedForwardOutput(config) + def forward(self, attention_output, prev_attention_output): norm_attention_output = self.layer_norm(attention_output) - intermediate_output = self.intermediate(norm_attention_output) - layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + outputs - return outputs + dense_output = self.dense(norm_attention_output) + output = self.output(dense_output, prev_attention_output) + return output -class ReformerEncoder(nn.Module): +class ReformerLayer(nn.Module): def __init__(self, config): super().__init__() - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + self.attention = ReformerAttention(config) + self.feed_forward = ReformerFeedForward(config) def forward( self, + prev_attention_output, hidden_states, attention_mask=None, head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - all_hidden_states = () - all_attentions = () - for i, layer_module in enumerate(self.layer): - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask - ) - hidden_states = layer_outputs[0] - - if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add last layer - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = (hidden_states,) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - -class ReformerPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class ReformerPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class ReformerLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = ReformerPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class ReformerOnlyNSPHead(nn.Module): - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class ReformerPreTrainingHeads(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = ReformerLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class ReformerPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - config_class = ReformerConfig - pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP - load_tf_weights = load_tf_weights_in_reformer - base_model_prefix = "reformer" - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, ReformerLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -REFORMER_START_DOCSTRING = r""" - This model is a PyTorch `torch.nn.Module `_ sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. -""" - -REFORMER_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`transformers.ReformerTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.encode_plus` for details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on padding token indices. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - - `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Segment token indices to indicate first and second portions of the inputs. - Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` - corresponds to a `sentence B` token - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range ``[0, config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): - Mask to nullify selected heads of the self-attention modules. - Mask values selected in ``[0, 1]``: - :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - if the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask - is used in the cross-attention if the model is configured as a decoder. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. -""" - - -@add_start_docstrings( - "The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.", - REFORMER_START_DOCSTRING, -) -class ReformerModel(ReformerPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - :obj:`is_decoder` argument of the configuration set to :obj:`True`; an - :obj:`encoder_hidden_states` is expected as an input to the forward pass. - - .. _`Attention is all you need`: - https://arxiv.org/abs/1706.03762 - - """ - - def __init__(self, config): - super().__init__(config) - self.config = config - - self.embeddings = ReformerEmbeddings(config) - self.encoder = ReformerEncoder(config) - self.pooler = ReformerPooler(config) - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, ): - r""" - Return: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) - further processed by a Linear layer and a Tanh activation function. The Linear - layer weights are trained from the next sentence prediction (classification) - objective during pre-training. - - This output is usually *not* a good summary - of the semantic content of the input, you're often better with averaging or pooling - the sequence of hidden-states for the whole input sequence. - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + # Y_1 = attention_output; Y_2 = output + attention_outputs = self.attention(prev_attention_output, hidden_states, attention_mask, head_mask) + attention_output = attention_outputs[0] + output = self.feed_forward(attention_output, prev_attention_output) - Examples:: - - from transformers import ReformerModel, ReformerTokenizer - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerModel.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple - - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - causal_mask = causal_mask.to( - attention_mask.dtype - ) # causal and attention masks must have same type with pytorch version < 1.3 - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - elif encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format( - encoder_hidden_shape, encoder_attention_mask.shape - ) - ) - - encoder_extended_attention_mask = encoder_extended_attention_mask.to( - dtype=next(self.parameters()).dtype - ) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers - - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) - - outputs = (sequence_output, pooled_output,) + encoder_outputs[ - 1: - ] # add hidden_states and attentions if they are here - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) - - -@add_start_docstrings( - """Reformer Model with two heads on top as done during the pre-training: a `masked language modeling` head and - a `next sentence prediction (classification)` head. """, - REFORMER_START_DOCSTRING, -) -class ReformerForPreTraining(ReformerPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.reformer = ReformerModel(config) - self.cls = ReformerPreTrainingHeads(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - masked_lm_labels=None, - next_sentence_label=None, - ): - r""" - masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): - Labels for computing the masked language modeling loss. - Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels - in ``[0, ..., config.vocab_size]`` - next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) - Indices should be in ``[0, 1]``. - ``0`` indicates sequence B is a continuation of sequence A, - ``1`` indicates sequence B is a random sequence. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: - Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. - prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False - continuation before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - - Examples:: - - from transformers import ReformerTokenizer, ReformerForPreTraining - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerForPreTraining.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - prediction_scores, seq_relationship_scores = outputs[:2] - - """ - - outputs = self.reformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - outputs = (prediction_scores, seq_relationship_score,) + outputs[ - 2: - ] # add hidden states and attention if they are here - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - outputs = (total_loss,) + outputs - - return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) - - -@add_start_docstrings( - """Reformer Model with a `next sentence prediction (classification)` head on top. """, REFORMER_START_DOCSTRING, -) -class ReformerForNextSentencePrediction(ReformerPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.reformer = ReformerModel(config) - self.cls = ReformerOnlyNSPHead(config) - - self.init_weights() - - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - next_sentence_label=None, - ): - r""" - next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) - Indices should be in ``[0, 1]``. - ``0`` indicates sequence B is a continuation of sequence A, - ``1`` indicates sequence B is a random sequence. - - Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): - Next sequence prediction (classification) loss. - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - from transformers import ReformerTokenizer, ReformerForNextSentencePrediction - import torch - - tokenizer = ReformerTokenizer.from_pretrained('reformer-base-uncased') - model = ReformerForNextSentencePrediction.from_pretrained('reformer-base-uncased') - - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - - seq_relationship_scores = outputs[0] - - """ - - outputs = self.reformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - ) - - pooled_output = outputs[1] - - seq_relationship_score = self.cls(pooled_output) - - outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - outputs = (next_sentence_loss,) + outputs - - return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) + outputs = (attention_output, output) + attention_outputs[1:] + return outputs diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 71cecc633155..61cb65c67b0a 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -37,6 +37,8 @@ import torch # noqa: F401 # from transformers.modeling_reformer import () +PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" + class TraxUtils(object): """ class that will help for testing in the beginning @@ -73,7 +75,7 @@ def get_layer( self, config, use_reference_code=True, - mode="train", + mode="eval", path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", **kwargs ): @@ -130,38 +132,39 @@ def get_block( use_reference_code=True, share_qk=True, ff_use_sru=0, - mode="train", + mode="eval", causal=False, - path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", + path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): with trax_math.use_backend("jax"): - hidden_size_per_head = config.hidden_size // config.num_attention_heads - list_of_layers = TraxLSHAttentionBlock( - d_model=config.d_model, - d_ff=config.d_ff, - d_attention_key=hidden_size_per_head, - d_attention_value=hidden_size_per_head, - n_heads=config.num_attention_heads, - n_attention_chunks=config.num_attention_chunks, - attention_type=tl.LSHSelfAttention, - dropout=config.hidden_dropout_prob, - share_qk=share_qk, - ff_activation=tl.Gelu, - ff_use_sru=ff_use_sru, - ff_chunk_size=config.ff_chunk_size, - mode=mode, - causal=causal, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - use_reference_code=use_reference_code, - hash_seed=config.seed, - path_to_save_weights=path_to_save_weights - ) - layer = tl.Serial(tl.ReversibleSerial([list_of_layers])) + with jax.disable_jit(): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + list_of_layers = TraxLSHAttentionBlock( + d_model=config.d_model, + d_ff=config.d_ff, + d_attention_key=hidden_size_per_head, + d_attention_value=hidden_size_per_head, + n_heads=config.num_attention_heads, + n_attention_chunks=config.num_attention_chunks, + attention_type=tl.LSHSelfAttention, + dropout=config.hidden_dropout_prob, + share_qk=share_qk, + ff_activation=tl.Gelu, + ff_use_sru=ff_use_sru, + ff_chunk_size=config.ff_chunk_size, + mode=mode, + causal=causal, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + use_reference_code=use_reference_code, + hash_seed=config.seed, + path_to_save_weights=path_to_save_weights + ) + layer = tl.Serial(tl.ReversibleSerial([list_of_layers])) return layer @@ -232,23 +235,23 @@ def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - self._set_param(torch_layer.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) + self._set_param(torch_layer.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - self._set_param(torch_layer.intermediate.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) + self._set_param(torch_layer.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) # intermediate out out_dense_weight = np.asarray(intermediate_weights[4][0]) out_dense_bias = np.asarray(intermediate_weights[4][1]) - self._set_param(torch_layer.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) + self._set_param(torch_layer.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) def test_lsh_layer(self): # Remove residual connection in ReformerSelfOutput to test this layer only # Remove layer norm in ReformerAttention to test this layer only config = ReformerConfig() - shape = (3, 7, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (1, 7, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) @@ -258,7 +261,7 @@ def test_lsh_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) - hf_output = hf_layer(hf_input)[0] + hf_output = hf_layer(hf_input, hf_input)[0] trax_torch_output = torch.tensor(np.asarray(trax_output)) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) @@ -272,15 +275,15 @@ def test_lsh_block(self): trax_utils = TraxUtils(shape) trax_block = trax_utils.get_block(config) trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) + trax_torch_output_1 = torch.tensor(np.asarray(trax_output[0])) + trax_torch_output_2 = torch.tensor(np.asarray(trax_output[1])) hf_input = torch.tensor(np_input, dtype=torch.float) hf_block = ReformerLayer(config) self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) - hf_output = hf_block(hf_input)[0] + hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] - # check which trax output is correct - trax_torch_output = torch.tensor(np.asarray(trax_output)) - import ipdb - ipdb.set_trace() + self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) + self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) pass From 45e6635b666b9f87123cabd722a6d1960e434658 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 5 Apr 2020 11:59:23 +0200 Subject: [PATCH 025/162] implement causal and self mask --- src/transformers/configuration_reformer.py | 1 + src/transformers/modeling_reformer.py | 44 ++++++++++++++++------ tests/test_modeling_reformer.py | 7 ++-- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ce6fa26825ad..ec781269f958 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -132,6 +132,7 @@ def __init__( self.ff_chunk_size = 0 self.d_model = hidden_size self.d_ff = intermediate_size + self.is_decoder = True # self.ff_activation = # GELU self.vocab_size = vocab_size diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 382ba6010725..166c7071ca24 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -57,6 +57,7 @@ def __init__(self, config): self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after self.output_attentions = config.output_attentions + self.is_decoder = config.is_decoder self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -75,7 +76,6 @@ def __init__(self, config): def forward( self, hidden_states, - attention_mask=None, head_mask=None, ): @@ -101,8 +101,11 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) + # q_info = ticker + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) + ticker = self._split_dim_by(ticker, -1, self.chunk_length) # Optionally include adjacent chunks. if self.chunk_length is None: @@ -110,7 +113,7 @@ def forward( key_vectors = self._len_and_dim_norm(query_key_vectors) - out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, undo_ticker, attention_mask, head_mask) + out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) if self.num_hashes > 1: out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) @@ -202,17 +205,38 @@ def _get_ticker_and_undo_ticker(self, sequence_length, buckets): sorted_ticker = (sorted_ticker % sequence_length) return sorted_ticker, undo_sorted_ticker - def _attend(self, query_vectors, key_vectors, value_vectors, undo_ticker, attention_mask, head_mask): + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): key_vectors = self._look_adjacent(key_vectors) value_vectors = self._look_adjacent(value_vectors) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - if attention_mask is not None: - # Apply the attention mask. - # Attention mask is precomputed for all layers in forward() function - query_key_dots = query_key_dots + attention_mask + # TODO(PVP): better naming + query_info = ticker + key_value_info = self._look_adjacent(ticker) + + # Causal mask + if self.is_decoder: + causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long() + query_key_dots = query_key_dots - causal_mask * 1e9 + + # Self mask + self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long() + query_key_dots = query_key_dots - self_mask * 1e5 + + # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. + + # Note: Self mask is used because Q and K projection weights are shared. + # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): + + # " While attention to the future is not allowed, typical implementations of the + # Transformer do allow a position to attend to itself. + # Such behavior is undesirable in a shared-QK formulation because the dot-product + # of a query vector with itself will almost always be greater than the dot product of a + # query vector with a vector at another position. We therefore modify the masking + # to forbid a token from attending to itself, except in situations + # where a token has no other valid attention targets (e.g. the first token in a sequence) " logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) dots = torch.exp(query_key_dots - logits) @@ -331,7 +355,6 @@ def forward( self, prev_attention_output, hidden_states, - attention_mask=None, head_mask=None, ): norm_hidden_states = self.layer_norm(hidden_states) @@ -340,7 +363,7 @@ def forward( # norm_hidden_states = hidden_states self_attention_outputs = self.self_attention( - norm_hidden_states, attention_mask, head_mask + norm_hidden_states, head_mask ) # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) @@ -404,14 +427,13 @@ def forward( self, prev_attention_output, hidden_states, - attention_mask=None, head_mask=None, ): # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # X_1 = prev_attention_output; X_2 = hidden_states # Y_1 = attention_output; Y_2 = output - attention_outputs = self.attention(prev_attention_output, hidden_states, attention_mask, head_mask) + attention_outputs = self.attention(prev_attention_output, hidden_states, head_mask) attention_output = attention_outputs[0] output = self.feed_forward(attention_output, prev_attention_output) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 61cb65c67b0a..1576942e7b13 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -94,7 +94,7 @@ def get_layer( attention_dropout=config.attention_probs_dropout_prob, output_dropout=config.hidden_dropout_prob, hash_seed=config.seed, - causal=False, + causal=config.is_decoder, use_reference_code=use_reference_code, mode=mode, path_to_save_weights=path_to_save_weights @@ -133,7 +133,6 @@ def get_block( share_qk=True, ff_use_sru=0, mode="eval", - causal=False, path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): @@ -154,7 +153,7 @@ def get_block( ff_use_sru=ff_use_sru, ff_chunk_size=config.ff_chunk_size, mode=mode, - causal=causal, + causal=config.is_decoder, chunk_len=config.chunk_length, n_chunks_before=config.num_chunks_before, n_chunks_after=config.num_chunks_after, @@ -264,7 +263,7 @@ def test_lsh_layer(self): hf_output = hf_layer(hf_input, hf_input)[0] trax_torch_output = torch.tensor(np.asarray(trax_output)) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-6)) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_lsh_block(self): config = ReformerConfig() From b5ed5d4f9767f8331b520980c4c184b639737c57 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 5 Apr 2020 15:27:08 +0200 Subject: [PATCH 026/162] clean reformer test and refactor code --- src/transformers/modeling_reformer.py | 10 ++-------- tests/test_modeling_reformer.py | 6 ++++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 166c7071ca24..1b3c91ce6f51 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -98,8 +98,8 @@ def forward( ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker) - value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker) + query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) + value_vectors = self._gather_by_expansion(value_vectors, ticker) # q_info = ticker @@ -339,8 +339,6 @@ def forward(self, hidden_states, input_tensor): hidden_states = self.dropout(hidden_states) # residual connection output = (hidden_states + input_tensor) - # Uncomment here if testing only LSHSelfAttentionLayer -# output = hidden_states return output @@ -358,10 +356,6 @@ def forward( head_mask=None, ): norm_hidden_states = self.layer_norm(hidden_states) - # TODO: Remove comments later - # Uncomment here if testing only LSHSelfAttentionLayer -# norm_hidden_states = hidden_states - self_attention_outputs = self.self_attention( norm_hidden_states, head_mask ) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 1576942e7b13..0913d471c3b9 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -250,7 +250,7 @@ def test_lsh_layer(self): # Remove residual connection in ReformerSelfOutput to test this layer only # Remove layer norm in ReformerAttention to test this layer only config = ReformerConfig() - shape = (1, 7, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (2, 7, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) @@ -260,7 +260,9 @@ def test_lsh_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) - hf_output = hf_layer(hf_input, hf_input)[0] + + hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) trax_torch_output = torch.tensor(np.asarray(trax_output)) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) From ddb2f098f48337bb1c25840c65b4c954d9398ad5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 13:32:46 +0200 Subject: [PATCH 027/162] update init --- src/transformers/__init__.py | 2 ++ src/transformers/configuration_reformer.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b16a760816de..a238cd3cad60 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -191,6 +191,7 @@ BertForQuestionAnswering, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BertLayer ) from .modeling_openai import ( OpenAIGPTPreTrainedModel, @@ -319,6 +320,7 @@ ElectraModel, load_tf_weights_in_electra, ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, + ) from .modeling_reformer import ( ReformerAttention, diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ec781269f958..17b6e1a6294d 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -105,7 +105,7 @@ def __init__( num_attention_heads=2, num_buckets=4, num_hashes=2, - chunk_length=7, + chunk_length=64, num_chunks_before=1, num_chunks_after=0, intermediate_size=64, From cbb5ab9f4084966bde8670e70509790a4456db85 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 14:15:26 +0200 Subject: [PATCH 028/162] fix device for GPU --- src/transformers/modeling_reformer.py | 6 +- w! | 435 ++++++++++++++++++++++++++ 2 files changed, 438 insertions(+), 3 deletions(-) create mode 100644 w! diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1b3c91ce6f51..99632b61501c 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -152,7 +152,7 @@ def _hash_vectors(self, vectors): # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 # random_rotations = torch.randn(rotations_shape, device=vectors.device) numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # TODO: IMPORTANT: At the moment we use the same random rotation over all batches @@ -218,11 +218,11 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # Causal mask if self.is_decoder: - causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long() + causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) query_key_dots = query_key_dots - causal_mask * 1e9 # Self mask - self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long() + self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) query_key_dots = query_key_dots - self_mask * 1e5 # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. diff --git a/w! b/w! new file mode 100644 index 000000000000..99632b61501c --- /dev/null +++ b/w! @@ -0,0 +1,435 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model. """ + +import logging + +import torch +from torch import nn + +from .activations import gelu, gelu_new, swish + +# DELETE later +import numpy + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} +# TODO: fill with pretrained model weights +# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", +# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", +# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", +# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} +ReformerLayerNorm = torch.nn.LayerNorm + + +class LSHSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.hash_seed = config.seed + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.chunk_length = config.chunk_length + self.num_chunks_before = config.num_chunks_before + self.num_chunks_after = config.num_chunks_after + self.output_attentions = config.output_attentions + self.is_decoder = config.is_decoder + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states, + head_mask=None, + ): + + # get SeqLen and BatchSize + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) + value_vectors = self._transpose_for_scores(mixed_value_vectors) + + assert query_key_vectors.shape[-1] == self.attention_head_size + assert value_vectors.shape[-1] == self.attention_head_size + + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors) + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length + + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) + + query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) + value_vectors = self._gather_by_expansion(value_vectors, ticker) + + # q_info = ticker + + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) + ticker = self._split_dim_by(ticker, -1, self.chunk_length) + + # Optionally include adjacent chunks. + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + + key_vectors = self._len_and_dim_norm(query_key_vectors) + + out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) + + if self.num_hashes > 1: + out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) + logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + + out_vectors = self._transpose_for_output(out_vectors) + outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + return outputs + + def _hash_vectors(self, vectors): + batch_size = vectors.shape[0] + + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.num_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for num_bucket in self.num_buckets: + assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) + rotation_size += num_bucket + num_buckets *= num_bucket + + rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + + # TODO: delete later when integration tests are ok + # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 +# random_rotations = torch.randn(rotations_shape, device=vectors.device) + numpy.random.seed(self.hash_seed) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) + # rotated_vectors has dim: + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + # TODO: IMPORTANT: At the moment we use the same random rotation over all batches + # and heads -> is that bad? It seems like in original reformer a different random + # rotation is used over heads and batches + rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for num_bucket in self.num_buckets: + rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] + cur_sum += num_bucket // 2 + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + + if buckets is None: + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + + cur_product *= num_bucket + + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(self.num_hashes, device=vectors.device) + offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + + # repeat same values for Batch_size and Num_Attn_Heads + offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) + offset_buckets = self._merge_by_middle_dim(buckets + offsets) + + return offset_buckets + + def _get_ticker_and_undo_ticker(self, sequence_length, buckets): + batch_size = buckets.shape[0] + + # TODO: what is ticker? Is ticker something like indices?? Ask authors + ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) + + buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + + # Hash-based sort + sorted_ticker = torch.argsort(buckets_and_t, dim=-1) + undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_length) + return sorted_ticker, undo_sorted_ticker + + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): + key_vectors = self._look_adjacent(key_vectors) + value_vectors = self._look_adjacent(value_vectors) + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # TODO(PVP): better naming + query_info = ticker + key_value_info = self._look_adjacent(ticker) + + # Causal mask + if self.is_decoder: + causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - causal_mask * 1e9 + + # Self mask + self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - self_mask * 1e5 + + # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. + + # Note: Self mask is used because Q and K projection weights are shared. + # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): + + # " While attention to the future is not allowed, typical implementations of the + # Transformer do allow a position to attend to itself. + # Such behavior is undesirable in a shared-QK formulation because the dot-product + # of a query vector with itself will almost always be greater than the dot product of a + # query vector with a vector at another position. We therefore modify the masking + # to forbid a token from attending to itself, except in situations + # where a token has no other valid attention targets (e.g. the first token in a sequence) " + + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + dots = torch.exp(query_key_dots - logits) + + # dropout + dots = self.dropout(dots) + + # Mask heads if we want to + if head_mask is not None: + dots = dots * head_mask + + # attend values + out_vectors = torch.matmul(dots, value_vectors) + + # merge chunk length + logits = self._merge_by_middle_dim(logits).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors) + + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_ticker) + + return out_vectors, logits, dots + + def _look_adjacent(self, vectors): + """ Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if self.num_chunks_before == 0 and self.num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-self.num_chunks_before, self.num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _len_and_dim_norm(self, vectors): + vectors = self._len_norm(vectors) + vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) + return vectors + + def _len_norm(self, x, epsilon=1e-6): + variance = torch.mean(x**2, -1, keepdim=True) + norm_x = x / torch.sqrt(variance + epsilon) + return norm_x + + def _gather_by_expansion(self, vectors, idxs): + expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + vectors = vectors.repeat(1, 1, self.num_hashes, 1) + return torch.gather(vectors, 2, expanded_idxs) + + def _transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def _transpose_for_output(self, x): + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) + + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + + def _merge_by_middle_dim(self, vectors): + batch_size = vectors.shape[0] + new_dim_shape = (batch_size, self.num_attention_heads, -1) + + if len(vectors.shape) == 5: + return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) + elif len(vectors.shape) == 4: + return torch.reshape(vectors, new_dim_shape) + else: + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # residual connection + output = (hidden_states + input_tensor) + return output + + +class ReformerAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.self_attention = LSHSelfAttention(config) + self.output = ReformerSelfOutput(config) + + def forward( + self, + prev_attention_output, + hidden_states, + head_mask=None, + ): + norm_hidden_states = self.layer_norm(hidden_states) + self_attention_outputs = self.self_attention( + norm_hidden_states, head_mask + ) + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + attention_output = self.output(self_attention_outputs[0], prev_attention_output) + outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them + return outputs + + +class ReformerFeedForwardDense(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.act_fn(hidden_states) + return hidden_states + + +class ReformerFeedForwardOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + output = (hidden_states + input_tensor) + return output + + +class ReformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = ReformerFeedForwardDense(config) + self.output = ReformerFeedForwardOutput(config) + + def forward(self, attention_output, prev_attention_output): + norm_attention_output = self.layer_norm(attention_output) + dense_output = self.dense(norm_attention_output) + output = self.output(dense_output, prev_attention_output) + return output + + +class ReformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = ReformerAttention(config) + self.feed_forward = ReformerFeedForward(config) + + def forward( + self, + prev_attention_output, + hidden_states, + head_mask=None, + ): + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + # Y_1 = attention_output; Y_2 = output + attention_outputs = self.attention(prev_attention_output, hidden_states, head_mask) + attention_output = attention_outputs[0] + output = self.feed_forward(attention_output, prev_attention_output) + + outputs = (attention_output, output) + attention_outputs[1:] + return outputs From f17fd5b3093fb72aade12f7215b58eba6949043d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 14:17:50 +0200 Subject: [PATCH 029/162] fix chunk length init for tests --- src/transformers/configuration_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 17b6e1a6294d..ec781269f958 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -105,7 +105,7 @@ def __init__( num_attention_heads=2, num_buckets=4, num_hashes=2, - chunk_length=64, + chunk_length=7, num_chunks_before=1, num_chunks_after=0, intermediate_size=64, From eca8cce9749e985c9dc755a0f64cdfb42f60ccf7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 20:22:32 +0200 Subject: [PATCH 030/162] include morgans optimization --- src/transformers/configuration_reformer.py | 4 +- src/transformers/modeling_reformer.py | 18 +- tests/test_modeling_reformer.py | 3 + w! | 435 --------------------- 4 files changed, 14 insertions(+), 446 deletions(-) delete mode 100644 w! diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ec781269f958..776f6ad88ba4 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -110,8 +110,8 @@ def __init__( num_chunks_after=0, intermediate_size=64, hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.3, + attention_probs_dropout_prob=0.3, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 99632b61501c..b76624ed4a6b 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -71,7 +71,7 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.dropout = config.attention_probs_dropout_prob def forward( self, @@ -242,7 +242,7 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker dots = torch.exp(query_key_dots - logits) # dropout - dots = self.dropout(dots) + dots = nn.functional.dropout(dots, self.dropout, self.training) # Mask heads if we want to if head_mask is not None: @@ -299,7 +299,7 @@ def _gather_by_expansion(self, vectors, idxs): def _transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + return x.transpose(2, 1) def _transpose_for_output(self, x): x = x.permute(0, 2, 1, 3) @@ -332,11 +332,11 @@ class ReformerSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = config.hidden_dropout_prob def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) # residual connection output = (hidden_states + input_tensor) return output @@ -371,7 +371,7 @@ class ReformerFeedForwardDense(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = config.hidden_dropout_prob if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: @@ -379,7 +379,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) hidden_states = self.act_fn(hidden_states) return hidden_states @@ -388,11 +388,11 @@ class ReformerFeedForwardOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = config.hidden_dropout_prob def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) output = (hidden_states + input_tensor) return output diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 0913d471c3b9..7715413cfb0f 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -260,6 +260,7 @@ def test_lsh_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) + hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) @@ -282,6 +283,8 @@ def test_lsh_block(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_block = ReformerLayer(config) self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) + hf_block.eval() + hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) diff --git a/w! b/w! deleted file mode 100644 index 99632b61501c..000000000000 --- a/w! +++ /dev/null @@ -1,435 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch REFORMER model. """ - -import logging - -import torch -from torch import nn - -from .activations import gelu, gelu_new, swish - -# DELETE later -import numpy - - -logger = logging.getLogger(__name__) - -REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} -# TODO: fill with pretrained model weights -# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", -# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", -# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", -# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", - - -def mish(x): - return x * torch.tanh(nn.functional.softplus(x)) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} -ReformerLayerNorm = torch.nn.LayerNorm - - -class LSHSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.hash_seed = config.seed - self.num_hashes = config.num_hashes - self.num_buckets = config.num_buckets - self.chunk_length = config.chunk_length - self.num_chunks_before = config.num_chunks_before - self.num_chunks_after = config.num_chunks_after - self.output_attentions = config.output_attentions - self.is_decoder = config.is_decoder - - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - - self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def forward( - self, - hidden_states, - head_mask=None, - ): - - # get SeqLen and BatchSize - sequence_length = hidden_states.shape[1] - batch_size = hidden_states.shape[0] - - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) - - query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) - value_vectors = self._transpose_for_scores(mixed_value_vectors) - - assert query_key_vectors.shape[-1] == self.attention_head_size - assert value_vectors.shape[-1] == self.attention_head_size - - # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors) - assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) - value_vectors = self._gather_by_expansion(value_vectors, ticker) - - # q_info = ticker - - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) - ticker = self._split_dim_by(ticker, -1, self.chunk_length) - - # Optionally include adjacent chunks. - if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - - key_vectors = self._len_and_dim_norm(query_key_vectors) - - out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) - - if self.num_hashes > 1: - out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) - logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) - - probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) - out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) - - out_vectors = self._transpose_for_output(out_vectors) - outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) - return outputs - - def _hash_vectors(self, vectors): - batch_size = vectors.shape[0] - - # See https://arxiv.org/pdf/1509.02897.pdf - # We sample a different random rotation for each round of hashing to - # decrease the probability of hash misses. - if isinstance(self.num_buckets, int): - assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) - rotation_size = self.num_buckets - num_buckets = self.num_buckets - else: - # Factorize the hash if self.num_buckets is a list or tuple - rotation_size, num_buckets = 0, 1 - for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) - rotation_size += num_bucket - num_buckets *= num_bucket - - rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) - - # TODO: delete later when integration tests are ok - # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 -# random_rotations = torch.randn(rotations_shape, device=vectors.device) - numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) - # rotated_vectors has dim: - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 - # TODO: IMPORTANT: At the moment we use the same random rotation over all batches - # and heads -> is that bad? It seems like in original reformer a different random - # rotation is used over heads and batches - rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) - - if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - # Get the buckets for them and combine. - buckets, cur_sum, cur_product = None, 0, 1 - for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] - cur_sum += num_bucket // 2 - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - - if buckets is None: - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) - - cur_product *= num_bucket - - # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). - # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. - offsets = torch.arange(self.num_hashes, device=vectors.device) - offsets = torch.reshape(offsets * num_buckets, (-1, 1)) - - # repeat same values for Batch_size and Num_Attn_Heads - offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets) - - return offset_buckets - - def _get_ticker_and_undo_ticker(self, sequence_length, buckets): - batch_size = buckets.shape[0] - - # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) - ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) - - buckets_and_t = sequence_length * buckets + (ticker % sequence_length) - - # Hash-based sort - sorted_ticker = torch.argsort(buckets_and_t, dim=-1) - undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) - - sorted_ticker = (sorted_ticker % sequence_length) - return sorted_ticker, undo_sorted_ticker - - def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): - key_vectors = self._look_adjacent(key_vectors) - value_vectors = self._look_adjacent(value_vectors) - - # get logits and dots - query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - - # TODO(PVP): better naming - query_info = ticker - key_value_info = self._look_adjacent(ticker) - - # Causal mask - if self.is_decoder: - causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - causal_mask * 1e9 - - # Self mask - self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - self_mask * 1e5 - - # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. - - # Note: Self mask is used because Q and K projection weights are shared. - # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): - - # " While attention to the future is not allowed, typical implementations of the - # Transformer do allow a position to attend to itself. - # Such behavior is undesirable in a shared-QK formulation because the dot-product - # of a query vector with itself will almost always be greater than the dot product of a - # query vector with a vector at another position. We therefore modify the masking - # to forbid a token from attending to itself, except in situations - # where a token has no other valid attention targets (e.g. the first token in a sequence) " - - logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) - dots = torch.exp(query_key_dots - logits) - - # dropout - dots = self.dropout(dots) - - # Mask heads if we want to - if head_mask is not None: - dots = dots * head_mask - - # attend values - out_vectors = torch.matmul(dots, value_vectors) - - # merge chunk length - logits = self._merge_by_middle_dim(logits).squeeze(-1) - out_vectors = self._merge_by_middle_dim(out_vectors) - - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_ticker) - - return out_vectors, logits, dots - - def _look_adjacent(self, vectors): - """ Used to implement attention between consecutive chunks. - - Args: - vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if self.num_chunks_before == 0 and self.num_chunks_after == 0: - return vectors - - slices = [] - for i in range(-self.num_chunks_before, self.num_chunks_after + 1): - if i == 0: - slices.append(vectors) - else: - slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) - return torch.cat(slices, dim=3) - - def _len_and_dim_norm(self, vectors): - vectors = self._len_norm(vectors) - vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) - return vectors - - def _len_norm(self, x, epsilon=1e-6): - variance = torch.mean(x**2, -1, keepdim=True) - norm_x = x / torch.sqrt(variance + epsilon) - return norm_x - - def _gather_by_expansion(self, vectors, idxs): - expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - vectors = vectors.repeat(1, 1, self.num_hashes, 1) - return torch.gather(vectors, 2, expanded_idxs) - - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def _transpose_for_output(self, x): - x = x.permute(0, 2, 1, 3) - return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) - - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): - batch_size = vectors.shape[0] - split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) - - if len(vectors.shape) == 4: - return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) - elif len(vectors.shape) == 3: - return torch.reshape(vectors, split_dim_shape) - else: - raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - - def _merge_by_middle_dim(self, vectors): - batch_size = vectors.shape[0] - new_dim_shape = (batch_size, self.num_attention_heads, -1) - - if len(vectors.shape) == 5: - return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) - elif len(vectors.shape) == 4: - return torch.reshape(vectors, new_dim_shape) - else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) - - -class ReformerSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - # residual connection - output = (hidden_states + input_tensor) - return output - - -class ReformerAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attention = LSHSelfAttention(config) - self.output = ReformerSelfOutput(config) - - def forward( - self, - prev_attention_output, - hidden_states, - head_mask=None, - ): - norm_hidden_states = self.layer_norm(hidden_states) - self_attention_outputs = self.self_attention( - norm_hidden_states, head_mask - ) - - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - attention_output = self.output(self_attention_outputs[0], prev_attention_output) - outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them - return outputs - - -class ReformerFeedForwardDense(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - if isinstance(config.hidden_act, str): - self.act_fn = ACT2FN[config.hidden_act] - else: - self.act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.act_fn(hidden_states) - return hidden_states - - -class ReformerFeedForwardOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - output = (hidden_states + input_tensor) - return output - - -class ReformerFeedForward(nn.Module): - def __init__(self, config): - super().__init__() - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dense = ReformerFeedForwardDense(config) - self.output = ReformerFeedForwardOutput(config) - - def forward(self, attention_output, prev_attention_output): - norm_attention_output = self.layer_norm(attention_output) - dense_output = self.dense(norm_attention_output) - output = self.output(dense_output, prev_attention_output) - return output - - -class ReformerLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = ReformerAttention(config) - self.feed_forward = ReformerFeedForward(config) - - def forward( - self, - prev_attention_output, - hidden_states, - head_mask=None, - ): - - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - # Y_1 = attention_output; Y_2 = output - attention_outputs = self.attention(prev_attention_output, hidden_states, head_mask) - attention_output = attention_outputs[0] - output = self.feed_forward(attention_output, prev_attention_output) - - outputs = (attention_output, output) + attention_outputs[1:] - return outputs From db2ebb1bbe97c85566e10e303739a13e985801ba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 21:50:58 +0200 Subject: [PATCH 031/162] improve memory a bit --- src/transformers/modeling_reformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index b76624ed4a6b..2763935f67bc 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -218,12 +218,14 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # Causal mask if self.is_decoder: - causal_mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - causal_mask * 1e9 + # TODO: This line can be improved in terms of memory + mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - mask * 1e9 # Self mask - self_mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - self_mask * 1e5 + # TODO: This line can be improved in terms of memory + mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - mask * 1e5 # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. From 04aa067947fb65dd0fea44d2b8e113eb80b04130 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 6 Apr 2020 22:02:09 +0200 Subject: [PATCH 032/162] improve comment --- src/transformers/modeling_reformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 2763935f67bc..0442947c6049 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -218,12 +218,12 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # Causal mask if self.is_decoder: - # TODO: This line can be improved in terms of memory + # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) query_key_dots = query_key_dots - mask * 1e9 # Self mask - # TODO: This line can be improved in terms of memory + # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) query_key_dots = query_key_dots - mask * 1e5 From 4aec75e25b550e2ab5fac859b965fc93593e6dcf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Apr 2020 10:14:38 +0200 Subject: [PATCH 033/162] factorize num_buckets --- src/transformers/configuration_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 776f6ad88ba4..cb888351eb02 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -103,7 +103,7 @@ def __init__( hidden_size=32, num_hidden_layers=12, num_attention_heads=2, - num_buckets=4, + num_buckets=[2, 2], num_hashes=2, chunk_length=7, num_chunks_before=1, From d030e39d4f2fadb26f0e7532632ddcf62359aee2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Apr 2020 12:56:38 +0200 Subject: [PATCH 034/162] better testing parameters --- tests/test_modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 7715413cfb0f..237f4877c3b4 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -271,7 +271,7 @@ def test_lsh_layer(self): def test_lsh_block(self): config = ReformerConfig() - shape = (3, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead + shape = (2, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) From d31808992ac6f4c82fd2b82bf5fe8a25fa29d95b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 9 Apr 2020 16:54:55 +0200 Subject: [PATCH 035/162] make whole model work --- src/transformers/@! | 382 +++++++++++ src/transformers/__init__.py | 4 +- src/transformers/configuration_reformer.py | 17 +- src/transformers/modeling_reformer.py | 314 ++++++++- src/transformers/w! | 744 +++++++++++++++++++++ 5 files changed, 1450 insertions(+), 11 deletions(-) create mode 100644 src/transformers/@! create mode 100644 src/transformers/w! diff --git a/src/transformers/@! b/src/transformers/@! new file mode 100644 index 000000000000..fcb6965f5cfa --- /dev/null +++ b/src/transformers/@! @@ -0,0 +1,382 @@ +# coding=utf-8 # Copyright 2020 Huggingface +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +import numpy as np + +# trax imports - to be deleted later +from trax import math as trax_math +from trax.shapes import ShapeDtype as trax_ShapeDtype +import jax +from trax.layers.research.efficient_attention_v2 import ( + LSHSelfAttention as TraxLSHSelfAttention, +) +from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock +from trax.models.reformer.reformer import ReformerLM as TraxReformer +from trax import layers as tl + +from transformers import ReformerAttention, ReformerLayer, ReformerConfig + + +from transformers import is_torch_available # noqa: F401 +from .utils import require_torch, torch_device # noqa: F401 + + +if is_torch_available(): + import torch # noqa: F401 +# from transformers.modeling_reformer import () + +PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" + + +class TraxUtils(object): + """ class that will help for testing in the beginning + should be deleted step-by-step + + README (HOW-TO-INSTALL TRAX): + 1) git clone https://github.com/patrickvonplaten/trax.git + + - I had to do one tiny change to make the imports work, + see: https://github.com/patrickvonplaten/trax/commit/6c23e88afe7f1c57b0c38eeaa4d450e5f912590c) + 2) link your PYTHON_PATH to ~/trax/trax + 3) pip install all the missing packages HINT: the package gin is installed + + - HINT: the package gin is installed with pip install gin-config==0.1.4 + and not pip install gin. + - The other packages can just be installed with pip install form + error message " missing" + """ + + def __init__(self, shape): + self._shape = shape + + def convert_to_jax_array(self, np_array): + return jax.numpy.asarray(np_array) + + def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): + with trax_math.use_backend("jax"): + if shape is None: + shape = self._shape + input_signature = trax_ShapeDtype(shape, dtype) + return input_signature + + def get_layer( + self, + config, + use_reference_code=True, + mode="eval", + path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", + **kwargs + ): + + with trax_math.use_backend("jax"): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + layer = TraxLSHSelfAttention( + n_heads=config.num_attention_heads, + d_qk=hidden_size_per_head, + d_v=hidden_size_per_head, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + attention_dropout=config.attention_probs_dropout_prob, + output_dropout=config.hidden_dropout_prob, + hash_seed=config.seed, + causal=config.is_decoder, + use_reference_code=use_reference_code, + mode=mode, + path_to_save_weights=path_to_save_weights + ) + + return layer + + def forward_layer( + self, + np_input_data, + layer, + input_signature=None, + random_number_generator=None, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + + if input_signature is None: + input_signature = self.get_input_signature() + + weights, state = layer.init(input_signature) + + if random_number_generator is None: + random_number_generator = layer.new_rngs(1)[0] + + output = layer( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + + def get_block( + self, + config, + use_reference_code=True, + share_qk=True, + ff_use_sru=0, + mode="eval", + path_to_save_weights=PATH_TO_SAVE_WEIGHTS, + ): + + with trax_math.use_backend("jax"): + with jax.disable_jit(): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + list_of_layers = TraxLSHAttentionBlock( + d_model=config.d_model, + d_ff=config.d_ff, + d_attention_key=hidden_size_per_head, + d_attention_value=hidden_size_per_head, + n_heads=config.num_attention_heads, + n_attention_chunks=config.num_attention_chunks, + attention_type=tl.LSHSelfAttention, + dropout=config.hidden_dropout_prob, + share_qk=share_qk, + ff_activation=tl.Gelu, + ff_use_sru=ff_use_sru, + ff_chunk_size=config.ff_chunk_size, + mode=mode, + causal=config.is_decoder, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + use_reference_code=use_reference_code, + hash_seed=config.seed, + path_to_save_weights=path_to_save_weights + ) + block = tl.Serial(tl.ReversibleSerial([list_of_layers])) + + return block + + def forward_block( + self, + np_input_data, + block, + input_signature=None, + random_number_generator=None, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + input_data = (input_data,) * 2 + + if input_signature is None: + input_signature = self.get_input_signature() + input_signature = (input_signature, input_signature) + + weights, state = block.init(input_signature) + + if random_number_generator is None: + random_number_generator = block.new_rngs(1)[0] + + output = block( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + + def get_model( + self, + config, + use_reference_code=True, + share_qk=True, + ff_use_sru=0, + axial_pos_shape=(), + d_axial_pos_embs=None, + mode="eval", + path_to_save_weights=PATH_TO_SAVE_WEIGHTS, + ): + with trax_math.use_backend("jax"): + with jax.disable_jit(): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + model = TraxReformer( + vocab_size=config.vocab_size, + d_model=config.d_model, + d_ff=config.d_ff, + d_attention_key=hidden_size_per_head, + d_attention_value=hidden_size_per_head, + n_layers=config.num_hidden_layers, + n_heads=config.num_attention_heads, + dropout=config.hidden_dropout_prob, + max_len=config.max_position_embeddings, + n_chunks=config.num_chunks, + n_attention_chunks=config.num_attention_chunks, + attention_type=tl.LSHSelfAttention, + share_qk=share_qk, + axial_pos_shape=axial_pos_shape, + d_axial_pos_embs=d_axial_pos_embs, + ff_activation=tl.Gelu, + ff_use_sru=ff_use_sru, + ff_chunk_size=config.ff_chunk_size, + mode=mode, + causal=config.is_decoder, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + use_reference_code=use_reference_code, + hash_seed=config.seed, + path_to_save_weights=path_to_save_weights + ) + + return model + + def forward_model( + self, + np_input_data, + model, + input_signature=None, + random_number_generator=None, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + input_data = (input_data,) * 2 + + if input_signature is None: + input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) + input_signature = (input_signature, input_signature) + + weights, state = model.init(input_signature) + + if random_number_generator is None: + random_number_generator = model.new_rngs(1)[0] + + output = model( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + + +@require_torch +class ReformerIntegrationTests(unittest.TestCase): + + def _set_param(self, torch_layer, weight, bias=None): + with torch.no_grad(): + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) + torch_layer.weight = torch.nn.Parameter(weight) + if bias is not None: + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) + torch_layer.bias = torch.nn.Parameter(bias) + + def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + np_dense = np.asarray(weights[2]) + + self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + + def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): + weights = weights[0] + + # layernorm 1 + layer_norm_1 = weights[0][0][0] + layer_norm_1_weight = np.asarray(layer_norm_1[0]) + layer_norm_1_bias = np.asarray(layer_norm_1[1]) + self._set_param(torch_layer.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) + + # lsh weights + output + lsh_weights = weights[0][1] + self._set_layer_weights_in_torch(lsh_weights, torch_layer.attention, hidden_size) + + # intermediate weighs + intermediate_weights = weights[2][0][2][2] + + # layernorm 2 + layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) + layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) + self._set_param(torch_layer.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) + + # intermediate dense + inter_dense_weight = np.asarray(intermediate_weights[1][0]) + inter_dense_bias = np.asarray(intermediate_weights[1][1]) + self._set_param(torch_layer.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) + + # intermediate out + out_dense_weight = np.asarray(intermediate_weights[4][0]) + out_dense_bias = np.asarray(intermediate_weights[4][1]) + self._set_param(torch_layer.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) + + def test_lsh_layer(self): + # Remove residual connection in ReformerSelfOutput to test this layer only + # Remove layer norm in ReformerAttention to test this layer only + config = ReformerConfig() + shape = (2, 7, config.hidden_size) # Batch x SeqLen x hiddenSize + np_input = np.random.rand(*shape) + + trax_utils = TraxUtils(shape) + trax_layer = trax_utils.get_layer(config) + trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) + + hf_input = torch.tensor(np_input, dtype=torch.float) + hf_layer = ReformerAttention(config) + self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) + hf_layer.eval() + + hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) + + trax_torch_output = torch.tensor(np.asarray(trax_output)) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) + + def test_lsh_block(self): + config = ReformerConfig() + + shape = (2, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.rand(*shape) + + trax_utils = TraxUtils(shape) + trax_block = trax_utils.get_block(config) + trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) + trax_torch_output_1 = torch.tensor(np.asarray(trax_output[0])) + trax_torch_output_2 = torch.tensor(np.asarray(trax_output[1])) + + hf_input = torch.tensor(np_input, dtype=torch.float) + hf_block = ReformerLayer(config) + self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) + hf_block.eval() + + hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] + + self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) + self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) + + def test_reformer_model(self): + config = ReformerConfig() + + shape = (2, 7) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.randint(0, config.vocab_size, size=shape) + np_zeros = np.zeros((shape[0], 1), dtype=np.int32) + import ipdb + ipdb.set_trace() + torch_np_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) + + trax_utils = TraxUtils(shape) + trax_model = trax_utils.get_model(config) + trax_output, trax_weights, trax_state = trax_utils.forward_model(np_input, model=trax_model) + + import ipdb + ipdb.set_trace() diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a238cd3cad60..556cb5c05575 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -324,7 +324,9 @@ from .modeling_reformer import ( ReformerAttention, - ReformerLayer + ReformerLayer, + ReformerModel, + ReformerModelWithLMHead ) # Optimization diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index cb888351eb02..34f96b88e137 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -99,24 +99,25 @@ class ReformerConfig(PretrainedConfig): def __init__( self, - vocab_size=30522, - hidden_size=32, - num_hidden_layers=12, + vocab_size=10, + hidden_size=16, + num_hidden_layers=3, num_attention_heads=2, - num_buckets=[2, 2], + num_buckets=2, num_hashes=2, chunk_length=7, num_chunks_before=1, num_chunks_after=0, - intermediate_size=64, + intermediate_size=32, hidden_act="gelu", hidden_dropout_prob=0.3, attention_probs_dropout_prob=0.3, - max_position_embeddings=512, + max_position_embeddings=20, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, + sinusoidal_pos_embds=True, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -129,10 +130,11 @@ def __init__( # TO CHANGE LATER: self.seed = 0 self.num_attention_chunks = 1 + self.num_chunks = 0 self.ff_chunk_size = 0 self.d_model = hidden_size self.d_ff = intermediate_size - self.is_decoder = True + self.is_decoder = False # self.ff_activation = # GELU self.vocab_size = vocab_size @@ -152,3 +154,4 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps + self.sinusoidal_pos_embds = sinusoidal_pos_embds diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 0442947c6049..8568bca722c2 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -18,9 +18,13 @@ import logging import torch +import numpy as np from torch import nn +from torch.nn import CrossEntropyLoss from .activations import gelu, gelu_new, swish +from .configuration_reformer import ReformerConfig +from .modeling_utils import PreTrainedModel # DELETE later import numpy @@ -35,6 +39,8 @@ # "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", # "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", +PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" + def mish(x): return x * torch.tanh(nn.functional.softplus(x)) @@ -44,6 +50,56 @@ def mish(x): ReformerLayerNorm = torch.nn.LayerNorm +def create_sinusoidal_embeddings(n_pos, dim, out): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + out.requires_grad = False + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + # TODO(PVP): check sinusoidal position embeddings -> not 100% the same as done in trax + if config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.position_embeddings.weight + ) + + self.dropout = config.hidden_dropout_prob + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + + # TODO(PVP): In https://github.com/google/trax/blob/d947ebc184b94bc48d7e48cc63156a047f00992e/trax/models/reformer/reformer.py#L775 there are two dropouts right after each other as + # it is implemented here. Check if this is ok. + + embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) + embeddings = embeddings + position_embeddings + embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) + return embeddings + + class LSHSelfAttention(nn.Module): def __init__(self, config): super().__init__() @@ -353,8 +409,8 @@ def __init__(self, config): def forward( self, - prev_attention_output, hidden_states, + prev_attention_output, head_mask=None, ): norm_hidden_states = self.layer_norm(hidden_states) @@ -429,9 +485,261 @@ def forward( # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # X_1 = prev_attention_output; X_2 = hidden_states # Y_1 = attention_output; Y_2 = output - attention_outputs = self.attention(prev_attention_output, hidden_states, head_mask) + attention_outputs = self.attention(hidden_states, prev_attention_output, head_mask) attention_output = attention_outputs[0] - output = self.feed_forward(attention_output, prev_attention_output) + output = self.feed_forward(attention_output, hidden_states) outputs = (attention_output, output) + attention_outputs[1:] + return outputs + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.hidden_dropout_prob + + def forward( + self, + hidden_states, + head_mask=None, + ): + all_hidden_states = () + all_attentions = () + attention_output = hidden_states + + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (attention_output, hidden_states) + + layer_outputs = layer_module(attention_output, hidden_states, head_mask[i]) + attention_output, hidden_states = layer_outputs[:2] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (attention_output, hidden_states) + + # Concatenate 2 RevNet outputs + hidden_states = torch.cat([attention_output, hidden_states], dim=-1) + + # Apply layer norm to concatenated hidden states + hidden_states = self.layer_norm(hidden_states) + + # Apply dropout + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class ReformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = ReformerConfig + pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "reformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, ReformerLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class ReformerModel(ReformerPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + :obj:`is_decoder` argument of the configuration set to :obj:`True`; an + :obj:`encoder_hidden_states` is expected as an input to the forward pass. + + .. _`Attention is all you need`: + https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() # noqa: F841 + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] # noqa: F841 + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + ) + sequence_output = encoder_outputs[0] + + # add hidden_states and attentions if they are here + outputs = (sequence_output,) + encoder_outputs[1:] + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +class ReformerOnlyLMHead(nn.Module): + def __init__(self, config): + super().__init__() + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class ReformerModelWithLMHead(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def tie_weights(self): + # TODO(PVP): output and input embeddings are + # apparently not tied so skip this step + pass + + def forward( + self, + input_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + lm_labels=None, + ): + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + sequence_output = reformer_outputs[0] + lm_logits = self.lm_head(sequence_output) + outputs = (lm_logits,) + reformer_outputs[1:] + + if lm_labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) diff --git a/src/transformers/w! b/src/transformers/w! new file mode 100644 index 000000000000..1ed9099b4b44 --- /dev/null +++ b/src/transformers/w! @@ -0,0 +1,744 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model. """ + +import logging + +import torch +import numpy as np +from torch import nn + +from .activations import gelu, gelu_new, swish +from .configuration_reformer import ReformerConfig +from .modeling_utils import PreTrainedModel + +# DELETE later +import numpy + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} +# TODO: fill with pretrained model weights +# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", +# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", +# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", +# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", + +PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} +ReformerLayerNorm = torch.nn.LayerNorm + + +def create_sinusoidal_embeddings(n_pos, dim, out): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + out.requires_grad = False + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + if config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.position_embeddings.weight + ) + + self.dropout = config.hidden_dropout_prob + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + + # TODO: In https://github.com/google/trax/blob/d947ebc184b94bc48d7e48cc63156a047f00992e/trax/models/reformer/reformer.py#L775 there are two dropouts right after each other as + # it is implemented here. Check if this is ok. + + embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) + embeddings = embeddings + position_embeddings + embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) + return embeddings + + +class LSHSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.hash_seed = config.seed + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.chunk_length = config.chunk_length + self.num_chunks_before = config.num_chunks_before + self.num_chunks_after = config.num_chunks_after + self.output_attentions = config.output_attentions + self.is_decoder = config.is_decoder + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + self.dropout = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states, + head_mask=None, + ): + + # get SeqLen and BatchSize + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + mixed_query_key_vectors = self.query_key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) + value_vectors = self._transpose_for_scores(mixed_value_vectors) + + assert query_key_vectors.shape[-1] == self.attention_head_size + assert value_vectors.shape[-1] == self.attention_head_size + + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors) + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length + + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) + + query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) + value_vectors = self._gather_by_expansion(value_vectors, ticker) + + # q_info = ticker + + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) + ticker = self._split_dim_by(ticker, -1, self.chunk_length) + + # Optionally include adjacent chunks. + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + + key_vectors = self._len_and_dim_norm(query_key_vectors) + + out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) + + if self.num_hashes > 1: + out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) + logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + + out_vectors = self._transpose_for_output(out_vectors) + outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + return outputs + + def _hash_vectors(self, vectors): + batch_size = vectors.shape[0] + + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.num_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for num_bucket in self.num_buckets: + assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) + rotation_size += num_bucket + num_buckets *= num_bucket + + rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + + # TODO: delete later when integration tests are ok + # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 +# random_rotations = torch.randn(rotations_shape, device=vectors.device) + numpy.random.seed(self.hash_seed) + random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) + # rotated_vectors has dim: + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + # TODO: IMPORTANT: At the moment we use the same random rotation over all batches + # and heads -> is that bad? It seems like in original reformer a different random + # rotation is used over heads and batches + rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for num_bucket in self.num_buckets: + rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] + cur_sum += num_bucket // 2 + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + + if buckets is None: + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + + cur_product *= num_bucket + + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(self.num_hashes, device=vectors.device) + offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + + # repeat same values for Batch_size and Num_Attn_Heads + offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) + offset_buckets = self._merge_by_middle_dim(buckets + offsets) + + return offset_buckets + + def _get_ticker_and_undo_ticker(self, sequence_length, buckets): + batch_size = buckets.shape[0] + + # TODO: what is ticker? Is ticker something like indices?? Ask authors + ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) + + buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + + # Hash-based sort + sorted_ticker = torch.argsort(buckets_and_t, dim=-1) + undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + + sorted_ticker = (sorted_ticker % sequence_length) + return sorted_ticker, undo_sorted_ticker + + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): + key_vectors = self._look_adjacent(key_vectors) + value_vectors = self._look_adjacent(value_vectors) + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # TODO(PVP): better naming + query_info = ticker + key_value_info = self._look_adjacent(ticker) + + # Causal mask + if self.is_decoder: + # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called + mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - mask * 1e9 + + # Self mask + # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called + mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) + query_key_dots = query_key_dots - mask * 1e5 + + # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. + + # Note: Self mask is used because Q and K projection weights are shared. + # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): + + # " While attention to the future is not allowed, typical implementations of the + # Transformer do allow a position to attend to itself. + # Such behavior is undesirable in a shared-QK formulation because the dot-product + # of a query vector with itself will almost always be greater than the dot product of a + # query vector with a vector at another position. We therefore modify the masking + # to forbid a token from attending to itself, except in situations + # where a token has no other valid attention targets (e.g. the first token in a sequence) " + + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + dots = torch.exp(query_key_dots - logits) + + # dropout + dots = nn.functional.dropout(dots, self.dropout, self.training) + + # Mask heads if we want to + if head_mask is not None: + dots = dots * head_mask + + # attend values + out_vectors = torch.matmul(dots, value_vectors) + + # merge chunk length + logits = self._merge_by_middle_dim(logits).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors) + + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_ticker) + + return out_vectors, logits, dots + + def _look_adjacent(self, vectors): + """ Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if self.num_chunks_before == 0 and self.num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-self.num_chunks_before, self.num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _len_and_dim_norm(self, vectors): + vectors = self._len_norm(vectors) + vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) + return vectors + + def _len_norm(self, x, epsilon=1e-6): + variance = torch.mean(x**2, -1, keepdim=True) + norm_x = x / torch.sqrt(variance + epsilon) + return norm_x + + def _gather_by_expansion(self, vectors, idxs): + expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + vectors = vectors.repeat(1, 1, self.num_hashes, 1) + return torch.gather(vectors, 2, expanded_idxs) + + def _transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.transpose(2, 1) + + def _transpose_for_output(self, x): + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) + + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + + def _merge_by_middle_dim(self, vectors): + batch_size = vectors.shape[0] + new_dim_shape = (batch_size, self.num_attention_heads, -1) + + if len(vectors.shape) == 5: + return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) + elif len(vectors.shape) == 4: + return torch.reshape(vectors, new_dim_shape) + else: + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.dropout = config.hidden_dropout_prob + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + # residual connection + output = (hidden_states + input_tensor) + return output + + +class ReformerAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.self_attention = LSHSelfAttention(config) + self.output = ReformerSelfOutput(config) + + def forward( + self, + hidden_states, + prev_attention_output, + head_mask=None, + ): + norm_hidden_states = self.layer_norm(hidden_states) + self_attention_outputs = self.self_attention( + norm_hidden_states, head_mask + ) + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + attention_output = self.output(self_attention_outputs[0], prev_attention_output) + outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them + return outputs + + +class ReformerFeedForwardDense(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = config.hidden_dropout_prob + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = self.act_fn(hidden_states) + return hidden_states + + +class ReformerFeedForwardOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = config.hidden_dropout_prob + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + output = (hidden_states + input_tensor) + return output + + +class ReformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = ReformerFeedForwardDense(config) + self.output = ReformerFeedForwardOutput(config) + + def forward(self, attention_output, prev_attention_output): + norm_attention_output = self.layer_norm(attention_output) + dense_output = self.dense(norm_attention_output) + output = self.output(dense_output, prev_attention_output) + return output + + +class ReformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = ReformerAttention(config) + self.feed_forward = ReformerFeedForward(config) + + def forward( + self, + prev_attention_output, + hidden_states, + head_mask=None, + ): + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # X_1 = prev_attention_output; X_2 = hidden_states + # Y_1 = attention_output; Y_2 = output + attention_outputs = self.attention(hidden_states, prev_attention_output, head_mask) + attention_output = attention_outputs[0] + output = self.feed_forward(attention_output, hidden_states) + + outputs = (attention_output, output) + attention_outputs[1:] + + return outputs + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.hidden_dropout_prob + + def forward( + self, + hidden_states, + head_mask=None, + ): + all_hidden_states = () + all_attentions = () + attention_output = hidden_states + + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (attention_output, hidden_states) + + layer_outputs = layer_module(attention_output, hidden_states, head_mask[i]) + attention_output, hidden_states = layer_outputs[:2] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (attention_output, hidden_states) + + trax_2_layer_output_1 = np.load(PATH_TO_SAVE_WEIGHTS + '/encoder_output_1.npy') + trax_2_layer_output_2 = np.load(PATH_TO_SAVE_WEIGHTS + '/encoder_output_2.npy') + + import ipdb + ipdb.set_trace() + + # Concatenate 2 RevNet outputs + hidden_states = torch.cat([attention_output, hidden_states], dim=-1) + + # Apply layer norm to concatenated hidden states + hidden_states = self.layer_norm(hidden_states) + + # Apply dropout + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class ReformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = ReformerConfig + pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "reformer" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, ReformerLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class ReformerModel(ReformerPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + :obj:`is_decoder` argument of the configuration set to :obj:`True`; an + :obj:`encoder_hidden_states` is expected as an input to the forward pass. + + .. _`Attention is all you need`: + https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() # noqa: F841 + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] # noqa: F841 + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + ) + sequence_output = encoder_outputs[0] + + # add hidden_states and attentions if they are here + outputs = (sequence_output,) + encoder_outputs[1:] + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +class ReformerWithLMHead(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.predictions.decoder + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + lm_labels=None, + ): + outputs = self.reformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + # Although this may seem awkward, BertForMaskedLM supports two scenarios: + # 1. If a tensor that contains the indices of masked labels is provided, + # the cross-entropy is the MLM cross-entropy that measures the likelihood + # of predictions for masked words. + # 2. If `lm_labels` is provided we are in a causal scenario where we + # try to predict the next token for each input in the decoder. + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + if lm_labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + prediction_scores = prediction_scores[:, :-1, :].contiguous() + lm_labels = lm_labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) + outputs = (ltr_lm_loss,) + outputs + + return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions) From 4f0b114d03a942dea6b07bf873b1b69bd414dca1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 9 Apr 2020 16:55:24 +0200 Subject: [PATCH 036/162] make lm model work --- src/transformers/@! | 382 ---------------- src/transformers/w! | 744 -------------------------------- tests/test_modeling_reformer.py | 155 ++++++- 3 files changed, 141 insertions(+), 1140 deletions(-) delete mode 100644 src/transformers/@! delete mode 100644 src/transformers/w! diff --git a/src/transformers/@! b/src/transformers/@! deleted file mode 100644 index fcb6965f5cfa..000000000000 --- a/src/transformers/@! +++ /dev/null @@ -1,382 +0,0 @@ -# coding=utf-8 # Copyright 2020 Huggingface -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import unittest -import numpy as np - -# trax imports - to be deleted later -from trax import math as trax_math -from trax.shapes import ShapeDtype as trax_ShapeDtype -import jax -from trax.layers.research.efficient_attention_v2 import ( - LSHSelfAttention as TraxLSHSelfAttention, -) -from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock -from trax.models.reformer.reformer import ReformerLM as TraxReformer -from trax import layers as tl - -from transformers import ReformerAttention, ReformerLayer, ReformerConfig - - -from transformers import is_torch_available # noqa: F401 -from .utils import require_torch, torch_device # noqa: F401 - - -if is_torch_available(): - import torch # noqa: F401 -# from transformers.modeling_reformer import () - -PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" - - -class TraxUtils(object): - """ class that will help for testing in the beginning - should be deleted step-by-step - - README (HOW-TO-INSTALL TRAX): - 1) git clone https://github.com/patrickvonplaten/trax.git - - - I had to do one tiny change to make the imports work, - see: https://github.com/patrickvonplaten/trax/commit/6c23e88afe7f1c57b0c38eeaa4d450e5f912590c) - 2) link your PYTHON_PATH to ~/trax/trax - 3) pip install all the missing packages HINT: the package gin is installed - - - HINT: the package gin is installed with pip install gin-config==0.1.4 - and not pip install gin. - - The other packages can just be installed with pip install form - error message " missing" - """ - - def __init__(self, shape): - self._shape = shape - - def convert_to_jax_array(self, np_array): - return jax.numpy.asarray(np_array) - - def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): - with trax_math.use_backend("jax"): - if shape is None: - shape = self._shape - input_signature = trax_ShapeDtype(shape, dtype) - return input_signature - - def get_layer( - self, - config, - use_reference_code=True, - mode="eval", - path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", - **kwargs - ): - - with trax_math.use_backend("jax"): - hidden_size_per_head = config.hidden_size // config.num_attention_heads - layer = TraxLSHSelfAttention( - n_heads=config.num_attention_heads, - d_qk=hidden_size_per_head, - d_v=hidden_size_per_head, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - attention_dropout=config.attention_probs_dropout_prob, - output_dropout=config.hidden_dropout_prob, - hash_seed=config.seed, - causal=config.is_decoder, - use_reference_code=use_reference_code, - mode=mode, - path_to_save_weights=path_to_save_weights - ) - - return layer - - def forward_layer( - self, - np_input_data, - layer, - input_signature=None, - random_number_generator=None, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - - if input_signature is None: - input_signature = self.get_input_signature() - - weights, state = layer.init(input_signature) - - if random_number_generator is None: - random_number_generator = layer.new_rngs(1)[0] - - output = layer( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - def get_block( - self, - config, - use_reference_code=True, - share_qk=True, - ff_use_sru=0, - mode="eval", - path_to_save_weights=PATH_TO_SAVE_WEIGHTS, - ): - - with trax_math.use_backend("jax"): - with jax.disable_jit(): - hidden_size_per_head = config.hidden_size // config.num_attention_heads - list_of_layers = TraxLSHAttentionBlock( - d_model=config.d_model, - d_ff=config.d_ff, - d_attention_key=hidden_size_per_head, - d_attention_value=hidden_size_per_head, - n_heads=config.num_attention_heads, - n_attention_chunks=config.num_attention_chunks, - attention_type=tl.LSHSelfAttention, - dropout=config.hidden_dropout_prob, - share_qk=share_qk, - ff_activation=tl.Gelu, - ff_use_sru=ff_use_sru, - ff_chunk_size=config.ff_chunk_size, - mode=mode, - causal=config.is_decoder, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - use_reference_code=use_reference_code, - hash_seed=config.seed, - path_to_save_weights=path_to_save_weights - ) - block = tl.Serial(tl.ReversibleSerial([list_of_layers])) - - return block - - def forward_block( - self, - np_input_data, - block, - input_signature=None, - random_number_generator=None, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - input_data = (input_data,) * 2 - - if input_signature is None: - input_signature = self.get_input_signature() - input_signature = (input_signature, input_signature) - - weights, state = block.init(input_signature) - - if random_number_generator is None: - random_number_generator = block.new_rngs(1)[0] - - output = block( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - def get_model( - self, - config, - use_reference_code=True, - share_qk=True, - ff_use_sru=0, - axial_pos_shape=(), - d_axial_pos_embs=None, - mode="eval", - path_to_save_weights=PATH_TO_SAVE_WEIGHTS, - ): - with trax_math.use_backend("jax"): - with jax.disable_jit(): - hidden_size_per_head = config.hidden_size // config.num_attention_heads - model = TraxReformer( - vocab_size=config.vocab_size, - d_model=config.d_model, - d_ff=config.d_ff, - d_attention_key=hidden_size_per_head, - d_attention_value=hidden_size_per_head, - n_layers=config.num_hidden_layers, - n_heads=config.num_attention_heads, - dropout=config.hidden_dropout_prob, - max_len=config.max_position_embeddings, - n_chunks=config.num_chunks, - n_attention_chunks=config.num_attention_chunks, - attention_type=tl.LSHSelfAttention, - share_qk=share_qk, - axial_pos_shape=axial_pos_shape, - d_axial_pos_embs=d_axial_pos_embs, - ff_activation=tl.Gelu, - ff_use_sru=ff_use_sru, - ff_chunk_size=config.ff_chunk_size, - mode=mode, - causal=config.is_decoder, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - use_reference_code=use_reference_code, - hash_seed=config.seed, - path_to_save_weights=path_to_save_weights - ) - - return model - - def forward_model( - self, - np_input_data, - model, - input_signature=None, - random_number_generator=None, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - input_data = (input_data,) * 2 - - if input_signature is None: - input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) - input_signature = (input_signature, input_signature) - - weights, state = model.init(input_signature) - - if random_number_generator is None: - random_number_generator = model.new_rngs(1)[0] - - output = model( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - -@require_torch -class ReformerIntegrationTests(unittest.TestCase): - - def _set_param(self, torch_layer, weight, bias=None): - with torch.no_grad(): - assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) - torch_layer.weight = torch.nn.Parameter(weight) - if bias is not None: - assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) - torch_layer.bias = torch.nn.Parameter(bias) - - def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): - # set torch weights for 1-to-1 comparison - np_query_key = np.asarray(weights[0]) - np_value = np.asarray(weights[1]) - np_dense = np.asarray(weights[2]) - - self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) - - def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): - weights = weights[0] - - # layernorm 1 - layer_norm_1 = weights[0][0][0] - layer_norm_1_weight = np.asarray(layer_norm_1[0]) - layer_norm_1_bias = np.asarray(layer_norm_1[1]) - self._set_param(torch_layer.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) - - # lsh weights + output - lsh_weights = weights[0][1] - self._set_layer_weights_in_torch(lsh_weights, torch_layer.attention, hidden_size) - - # intermediate weighs - intermediate_weights = weights[2][0][2][2] - - # layernorm 2 - layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) - layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - self._set_param(torch_layer.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) - - # intermediate dense - inter_dense_weight = np.asarray(intermediate_weights[1][0]) - inter_dense_bias = np.asarray(intermediate_weights[1][1]) - self._set_param(torch_layer.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) - - # intermediate out - out_dense_weight = np.asarray(intermediate_weights[4][0]) - out_dense_bias = np.asarray(intermediate_weights[4][1]) - self._set_param(torch_layer.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) - - def test_lsh_layer(self): - # Remove residual connection in ReformerSelfOutput to test this layer only - # Remove layer norm in ReformerAttention to test this layer only - config = ReformerConfig() - shape = (2, 7, config.hidden_size) # Batch x SeqLen x hiddenSize - np_input = np.random.rand(*shape) - - trax_utils = TraxUtils(shape) - trax_layer = trax_utils.get_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) - - hf_input = torch.tensor(np_input, dtype=torch.float) - hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) - hf_layer.eval() - - hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] - hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - - trax_torch_output = torch.tensor(np.asarray(trax_output)) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - - def test_lsh_block(self): - config = ReformerConfig() - - shape = (2, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead - np_input = np.random.rand(*shape) - - trax_utils = TraxUtils(shape) - trax_block = trax_utils.get_block(config) - trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) - trax_torch_output_1 = torch.tensor(np.asarray(trax_output[0])) - trax_torch_output_2 = torch.tensor(np.asarray(trax_output[1])) - - hf_input = torch.tensor(np_input, dtype=torch.float) - hf_block = ReformerLayer(config) - self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) - hf_block.eval() - - hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] - - self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) - self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) - - def test_reformer_model(self): - config = ReformerConfig() - - shape = (2, 7) # Batch x SeqLen x ModelDimPerHead - np_input = np.random.randint(0, config.vocab_size, size=shape) - np_zeros = np.zeros((shape[0], 1), dtype=np.int32) - import ipdb - ipdb.set_trace() - torch_np_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) - - trax_utils = TraxUtils(shape) - trax_model = trax_utils.get_model(config) - trax_output, trax_weights, trax_state = trax_utils.forward_model(np_input, model=trax_model) - - import ipdb - ipdb.set_trace() diff --git a/src/transformers/w! b/src/transformers/w! deleted file mode 100644 index 1ed9099b4b44..000000000000 --- a/src/transformers/w! +++ /dev/null @@ -1,744 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch REFORMER model. """ - -import logging - -import torch -import numpy as np -from torch import nn - -from .activations import gelu, gelu_new, swish -from .configuration_reformer import ReformerConfig -from .modeling_utils import PreTrainedModel - -# DELETE later -import numpy - - -logger = logging.getLogger(__name__) - -REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} -# TODO: fill with pretrained model weights -# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", -# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", -# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", -# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", - -PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" - - -def mish(x): - return x * torch.tanh(nn.functional.softplus(x)) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} -ReformerLayerNorm = torch.nn.LayerNorm - - -def create_sinusoidal_embeddings(n_pos, dim, out): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) - out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - out.requires_grad = False - - -class ReformerEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - if config.sinusoidal_pos_embds: - create_sinusoidal_embeddings( - n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.position_embeddings.weight - ) - - self.dropout = config.hidden_dropout_prob - - def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - position_embeddings = self.position_embeddings(position_ids) - - # TODO: In https://github.com/google/trax/blob/d947ebc184b94bc48d7e48cc63156a047f00992e/trax/models/reformer/reformer.py#L775 there are two dropouts right after each other as - # it is implemented here. Check if this is ok. - - embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) - embeddings = embeddings + position_embeddings - embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) - return embeddings - - -class LSHSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.hash_seed = config.seed - self.num_hashes = config.num_hashes - self.num_buckets = config.num_buckets - self.chunk_length = config.chunk_length - self.num_chunks_before = config.num_chunks_before - self.num_chunks_after = config.num_chunks_after - self.output_attentions = config.output_attentions - self.is_decoder = config.is_decoder - - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - - self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - - self.dropout = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states, - head_mask=None, - ): - - # get SeqLen and BatchSize - sequence_length = hidden_states.shape[1] - batch_size = hidden_states.shape[0] - - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) - - query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) - value_vectors = self._transpose_for_scores(mixed_value_vectors) - - assert query_key_vectors.shape[-1] == self.attention_head_size - assert value_vectors.shape[-1] == self.attention_head_size - - # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors) - assert int(buckets.shape[-1]) == self.num_hashes * sequence_length - - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) - - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) - value_vectors = self._gather_by_expansion(value_vectors, ticker) - - # q_info = ticker - - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) - ticker = self._split_dim_by(ticker, -1, self.chunk_length) - - # Optionally include adjacent chunks. - if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - - key_vectors = self._len_and_dim_norm(query_key_vectors) - - out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) - - if self.num_hashes > 1: - out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) - logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) - - probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) - out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) - - out_vectors = self._transpose_for_output(out_vectors) - outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) - return outputs - - def _hash_vectors(self, vectors): - batch_size = vectors.shape[0] - - # See https://arxiv.org/pdf/1509.02897.pdf - # We sample a different random rotation for each round of hashing to - # decrease the probability of hash misses. - if isinstance(self.num_buckets, int): - assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) - rotation_size = self.num_buckets - num_buckets = self.num_buckets - else: - # Factorize the hash if self.num_buckets is a list or tuple - rotation_size, num_buckets = 0, 1 - for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) - rotation_size += num_bucket - num_buckets *= num_bucket - - rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) - - # TODO: delete later when integration tests are ok - # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 -# random_rotations = torch.randn(rotations_shape, device=vectors.device) - numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) - # rotated_vectors has dim: - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 - # TODO: IMPORTANT: At the moment we use the same random rotation over all batches - # and heads -> is that bad? It seems like in original reformer a different random - # rotation is used over heads and batches - rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) - - if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - # Get the buckets for them and combine. - buckets, cur_sum, cur_product = None, 0, 1 - for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] - cur_sum += num_bucket // 2 - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) - - if buckets is None: - buckets = torch.argmax(rotated_vectors, dim=-1) - else: - buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) - - cur_product *= num_bucket - - # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). - # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. - offsets = torch.arange(self.num_hashes, device=vectors.device) - offsets = torch.reshape(offsets * num_buckets, (-1, 1)) - - # repeat same values for Batch_size and Num_Attn_Heads - offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets) - - return offset_buckets - - def _get_ticker_and_undo_ticker(self, sequence_length, buckets): - batch_size = buckets.shape[0] - - # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) - ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) - - buckets_and_t = sequence_length * buckets + (ticker % sequence_length) - - # Hash-based sort - sorted_ticker = torch.argsort(buckets_and_t, dim=-1) - undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) - - sorted_ticker = (sorted_ticker % sequence_length) - return sorted_ticker, undo_sorted_ticker - - def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): - key_vectors = self._look_adjacent(key_vectors) - value_vectors = self._look_adjacent(value_vectors) - - # get logits and dots - query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - - # TODO(PVP): better naming - query_info = ticker - key_value_info = self._look_adjacent(ticker) - - # Causal mask - if self.is_decoder: - # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - mask * 1e9 - - # Self mask - # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - mask * 1e5 - - # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. - - # Note: Self mask is used because Q and K projection weights are shared. - # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): - - # " While attention to the future is not allowed, typical implementations of the - # Transformer do allow a position to attend to itself. - # Such behavior is undesirable in a shared-QK formulation because the dot-product - # of a query vector with itself will almost always be greater than the dot product of a - # query vector with a vector at another position. We therefore modify the masking - # to forbid a token from attending to itself, except in situations - # where a token has no other valid attention targets (e.g. the first token in a sequence) " - - logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) - dots = torch.exp(query_key_dots - logits) - - # dropout - dots = nn.functional.dropout(dots, self.dropout, self.training) - - # Mask heads if we want to - if head_mask is not None: - dots = dots * head_mask - - # attend values - out_vectors = torch.matmul(dots, value_vectors) - - # merge chunk length - logits = self._merge_by_middle_dim(logits).squeeze(-1) - out_vectors = self._merge_by_middle_dim(out_vectors) - - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_ticker) - - return out_vectors, logits, dots - - def _look_adjacent(self, vectors): - """ Used to implement attention between consecutive chunks. - - Args: - vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if self.num_chunks_before == 0 and self.num_chunks_after == 0: - return vectors - - slices = [] - for i in range(-self.num_chunks_before, self.num_chunks_after + 1): - if i == 0: - slices.append(vectors) - else: - slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) - return torch.cat(slices, dim=3) - - def _len_and_dim_norm(self, vectors): - vectors = self._len_norm(vectors) - vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) - return vectors - - def _len_norm(self, x, epsilon=1e-6): - variance = torch.mean(x**2, -1, keepdim=True) - norm_x = x / torch.sqrt(variance + epsilon) - return norm_x - - def _gather_by_expansion(self, vectors, idxs): - expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - vectors = vectors.repeat(1, 1, self.num_hashes, 1) - return torch.gather(vectors, 2, expanded_idxs) - - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.transpose(2, 1) - - def _transpose_for_output(self, x): - x = x.permute(0, 2, 1, 3) - return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) - - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): - batch_size = vectors.shape[0] - split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) - - if len(vectors.shape) == 4: - return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) - elif len(vectors.shape) == 3: - return torch.reshape(vectors, split_dim_shape) - else: - raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - - def _merge_by_middle_dim(self, vectors): - batch_size = vectors.shape[0] - new_dim_shape = (batch_size, self.num_attention_heads, -1) - - if len(vectors.shape) == 5: - return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) - elif len(vectors.shape) == 4: - return torch.reshape(vectors, new_dim_shape) - else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) - - -class ReformerSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.dropout = config.hidden_dropout_prob - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - # residual connection - output = (hidden_states + input_tensor) - return output - - -class ReformerAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attention = LSHSelfAttention(config) - self.output = ReformerSelfOutput(config) - - def forward( - self, - hidden_states, - prev_attention_output, - head_mask=None, - ): - norm_hidden_states = self.layer_norm(hidden_states) - self_attention_outputs = self.self_attention( - norm_hidden_states, head_mask - ) - - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - attention_output = self.output(self_attention_outputs[0], prev_attention_output) - outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them - return outputs - - -class ReformerFeedForwardDense(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.dropout = config.hidden_dropout_prob - if isinstance(config.hidden_act, str): - self.act_fn = ACT2FN[config.hidden_act] - else: - self.act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - hidden_states = self.act_fn(hidden_states) - return hidden_states - - -class ReformerFeedForwardOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = config.hidden_dropout_prob - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - output = (hidden_states + input_tensor) - return output - - -class ReformerFeedForward(nn.Module): - def __init__(self, config): - super().__init__() - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dense = ReformerFeedForwardDense(config) - self.output = ReformerFeedForwardOutput(config) - - def forward(self, attention_output, prev_attention_output): - norm_attention_output = self.layer_norm(attention_output) - dense_output = self.dense(norm_attention_output) - output = self.output(dense_output, prev_attention_output) - return output - - -class ReformerLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = ReformerAttention(config) - self.feed_forward = ReformerFeedForward(config) - - def forward( - self, - prev_attention_output, - hidden_states, - head_mask=None, - ): - - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - # Y_1 = attention_output; Y_2 = output - attention_outputs = self.attention(hidden_states, prev_attention_output, head_mask) - attention_output = attention_outputs[0] - output = self.feed_forward(attention_output, hidden_states) - - outputs = (attention_output, output) + attention_outputs[1:] - - return outputs - - -class ReformerEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) - # Reformer is using Rev Nets, thus last layer outputs are concatenated and - # Layer Norm is done over 2 * hidden_size - self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) - self.dropout = config.hidden_dropout_prob - - def forward( - self, - hidden_states, - head_mask=None, - ): - all_hidden_states = () - all_attentions = () - attention_output = hidden_states - - for i, layer_module in enumerate(self.layer): - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (attention_output, hidden_states) - - layer_outputs = layer_module(attention_output, hidden_states, head_mask[i]) - attention_output, hidden_states = layer_outputs[:2] - - if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) - - # Add last layer - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (attention_output, hidden_states) - - trax_2_layer_output_1 = np.load(PATH_TO_SAVE_WEIGHTS + '/encoder_output_1.npy') - trax_2_layer_output_2 = np.load(PATH_TO_SAVE_WEIGHTS + '/encoder_output_2.npy') - - import ipdb - ipdb.set_trace() - - # Concatenate 2 RevNet outputs - hidden_states = torch.cat([attention_output, hidden_states], dim=-1) - - # Apply layer norm to concatenated hidden states - hidden_states = self.layer_norm(hidden_states) - - # Apply dropout - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - - outputs = (hidden_states,) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - -class ReformerPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - config_class = ReformerConfig - pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP - base_model_prefix = "reformer" - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, ReformerLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -class ReformerModel(ReformerPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - :obj:`is_decoder` argument of the configuration set to :obj:`True`; an - :obj:`encoder_hidden_states` is expected as an input to the forward pass. - - .. _`Attention is all you need`: - https://arxiv.org/abs/1706.03762 - - """ - - def __init__(self, config): - super().__init__(config) - self.config = config - - self.embeddings = ReformerEmbeddings(config) - self.encoder = ReformerEncoder(config) - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - ): - r""" - Return: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: - last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape - :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - - Examples:: - - """ - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() # noqa: F841 - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] # noqa: F841 - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_mask = [None] * self.config.num_hidden_layers - - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds - ) - - encoder_outputs = self.encoder( - embedding_output, - head_mask=head_mask, - ) - sequence_output = encoder_outputs[0] - - # add hidden_states and attentions if they are here - outputs = (sequence_output,) + encoder_outputs[1:] - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) - - -class ReformerWithLMHead(ReformerPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.reformer = ReformerModel(config) - self.lm_head = ReformerOnlyLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.lm_head.predictions.decoder - - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) - def forward( - self, - input_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - lm_labels=None, - ): - outputs = self.reformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here - - # Although this may seem awkward, BertForMaskedLM supports two scenarios: - # 1. If a tensor that contains the indices of masked labels is provided, - # the cross-entropy is the MLM cross-entropy that measures the likelihood - # of predictions for masked words. - # 2. If `lm_labels` is provided we are in a causal scenario where we - # try to predict the next token for each input in the decoder. - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - outputs = (masked_lm_loss,) + outputs - - if lm_labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - prediction_scores = prediction_scores[:, :-1, :].contiguous() - lm_labels = lm_labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) - outputs = (ltr_lm_loss,) + outputs - - return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 237f4877c3b4..204c840832a4 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -24,9 +24,10 @@ LSHSelfAttention as TraxLSHSelfAttention, ) from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock +from trax.models.reformer.reformer import ReformerLM as TraxReformer from trax import layers as tl -from transformers import ReformerAttention, ReformerLayer, ReformerConfig +from transformers import ReformerAttention, ReformerLayer, ReformerConfig, ReformerModelWithLMHead from transformers import is_torch_available # noqa: F401 @@ -64,11 +65,11 @@ def __init__(self, shape): def convert_to_jax_array(self, np_array): return jax.numpy.asarray(np_array) - def get_input_signature(self, shape=None): + def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): with trax_math.use_backend("jax"): if shape is None: shape = self._shape - input_signature = trax_ShapeDtype(shape) + input_signature = trax_ShapeDtype(shape, dtype) return input_signature def get_layer( @@ -163,9 +164,9 @@ def get_block( hash_seed=config.seed, path_to_save_weights=path_to_save_weights ) - layer = tl.Serial(tl.ReversibleSerial([list_of_layers])) + block = tl.Serial(tl.ReversibleSerial([list_of_layers])) - return layer + return block def forward_block( self, @@ -193,6 +194,79 @@ def forward_block( return output, weights, state + def get_model( + self, + config, + use_reference_code=True, + share_qk=True, + ff_use_sru=0, + axial_pos_shape=(), + d_axial_pos_embs=None, + mode="eval", + path_to_save_weights=PATH_TO_SAVE_WEIGHTS, + ): + with trax_math.use_backend("jax"): + with jax.disable_jit(): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + model = TraxReformer( + vocab_size=config.vocab_size, + d_model=config.d_model, + d_ff=config.d_ff, + d_attention_key=hidden_size_per_head, + d_attention_value=hidden_size_per_head, + n_layers=config.num_hidden_layers, + n_heads=config.num_attention_heads, + dropout=config.hidden_dropout_prob, + max_len=config.max_position_embeddings, + n_chunks=config.num_chunks, + n_attention_chunks=config.num_attention_chunks, + attention_type=tl.LSHSelfAttention, + share_qk=share_qk, + axial_pos_shape=axial_pos_shape, + d_axial_pos_embs=d_axial_pos_embs, + ff_activation=tl.Gelu, + ff_use_sru=ff_use_sru, + ff_chunk_size=config.ff_chunk_size, + mode=mode, + causal=config.is_decoder, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + n_hashes=config.num_hashes, + n_buckets=config.num_buckets, + use_reference_code=use_reference_code, + hash_seed=config.seed, + path_to_save_weights=path_to_save_weights + ) + + return model + + def forward_model( + self, + np_input_data, + model, + input_signature=None, + random_number_generator=None, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + input_data = (input_data,) * 2 + + if input_signature is None: + input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) + input_signature = (input_signature, input_signature) + + weights, state = model.init(input_signature) + + if random_number_generator is None: + random_number_generator = model.new_rngs(1)[0] + + output = model( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + @require_torch class ReformerIntegrationTests(unittest.TestCase): @@ -215,18 +289,17 @@ def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) - def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): - weights = weights[0] + def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): # layernorm 1 layer_norm_1 = weights[0][0][0] layer_norm_1_weight = np.asarray(layer_norm_1[0]) layer_norm_1_bias = np.asarray(layer_norm_1[1]) - self._set_param(torch_layer.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) + self._set_param(torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) # lsh weights + output lsh_weights = weights[0][1] - self._set_layer_weights_in_torch(lsh_weights, torch_layer.attention, hidden_size) + self._set_layer_weights_in_torch(lsh_weights, torch_block.attention, hidden_size) # intermediate weighs intermediate_weights = weights[2][0][2][2] @@ -234,17 +307,51 @@ def _set_block_weights_in_torch(self, weights, torch_layer, hidden_size): # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - self._set_param(torch_layer.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) + self._set_param(torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - self._set_param(torch_layer.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) + self._set_param(torch_block.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) # intermediate out out_dense_weight = np.asarray(intermediate_weights[4][0]) out_dense_bias = np.asarray(intermediate_weights[4][1]) - self._set_param(torch_layer.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) + self._set_param(torch_block.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) + + def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): + + # reformer model + torch_model_reformer = torch_model.reformer + + # word embeds + word_embeddings = np.asarray(weights[1]) + self._set_param(torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings)) + + # pos embeds + pos_embeddings = np.asarray(weights[3]) + self._set_param(torch_model_reformer.embeddings.position_embeddings, torch.tensor(pos_embeddings[0])) + + trax_layer_weights = weights[5] + assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len(trax_layer_weights), "HF and trax model do not have the same number of layers" + for layer_idx, layer in enumerate(torch_model_reformer.encoder.layer): + block_weights = trax_layer_weights[4 * layer_idx: 4 * (layer_idx + 1)] + self._set_block_weights_in_torch(block_weights, layer, hidden_size) + + # output weights + out_weights = weights[6] + + # output layer norm + layer_norm_out_weight = np.asarray(out_weights[0][0]) + layer_norm_out_bias = np.asarray(out_weights[0][1]) + self._set_param(torch_model_reformer.encoder.layer_norm, torch.tensor(layer_norm_out_weight), torch.tensor(layer_norm_out_bias)) + + # output embeddings + output_embed_weights = np.asarray(out_weights[2][0]) + output_embed_bias = np.asarray(out_weights[2][1]) + self._set_param(torch_model.lm_head.decoder, torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), torch.tensor(output_embed_bias)) + + pass def test_lsh_layer(self): # Remove residual connection in ReformerSelfOutput to test this layer only @@ -282,7 +389,7 @@ def test_lsh_block(self): hf_input = torch.tensor(np_input, dtype=torch.float) hf_block = ReformerLayer(config) - self._set_block_weights_in_torch(trax_weights, hf_block, config.hidden_size) + self._set_block_weights_in_torch(trax_weights[0], hf_block, config.hidden_size) hf_block.eval() hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] @@ -290,4 +397,24 @@ def test_lsh_block(self): self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) - pass + def test_reformer_lm_model(self): + config = ReformerConfig() + + shape = (2, 7) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.randint(0, config.vocab_size, size=shape) + np_zeros = np.zeros((shape[0], 1), dtype=np.int) + + trax_utils = TraxUtils(shape) + trax_model = trax_utils.get_model(config) + trax_output, trax_weights, trax_state = trax_utils.forward_model(np_input, model=trax_model) + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + + hf_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) + hf_model = ReformerModelWithLMHead(config) + self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) + hf_model.eval() + + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + + self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) From 6c8bad6877909ed8d1f9c4ce2433dffa1792836d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Apr 2020 11:18:26 +0200 Subject: [PATCH 037/162] add t5 copy paste tokenizer --- src/transformers/tokenization_reformer.py | 190 ++++++++++++++++++++++ 1 file changed, 190 insertions(+) create mode 100644 src/transformers/tokenization_reformer.py diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py new file mode 100644 index 000000000000..22d368e47514 --- /dev/null +++ b/src/transformers/tokenization_reformer.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2018 Reformer Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model Reformer.""" + + +import logging +import os +from shutil import copyfile + +from .tokenization_utils import PreTrainedTokenizer + + +logger = logging.getLogger(__name__) + +SPIECE_UNDERLINE = "▁" + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to file names for serializing Tokenizer instances +#################################################### +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to pretrained vocabulary URL for all the model shortcut names. +#################################################### +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + } +} + +#################################################### +# Mapping from model shortcut names to max length of inputs +#################################################### +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "t5-small": 512, + "t5-base": 512, + "t5-large": 512, + "t5-3b": 512, + "t5-11b": 512, +} + +# IMPORTANT: This is just a copy-paste from T5 and not compared to the original Reformer +# Tokenizer yet!!!! + + +class ReformerTokenizer(PreTrainedTokenizer): + """ + Constructs an Reformer tokenizer. Based on `SentencePiece `__ . + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`string`): + `SentencePiece `__ file (generally has a `.spm` extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (:obj:`string`, `optional`, defaults to ""): + The end of sequence token. + + .. note:: + + When building a sequence using special tokens, this is not the token that is used for the end + of sequence. The token used is the :obj:`sep_token`. + unk_token (:obj:`string`, `optional`, defaults to ""): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`string`, `optional`, defaults to ""): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + additional_special_tokens=None, + **kwargs + ): + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + self.max_len_single_sentence = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens + + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use ReformerTokenizer:" + "https://github.com/google/sentencepiece" + "pip install sentencepiece" + ) + raise + + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use ReformerTokenizer: https://github.com/google/sentencepiece" + "pip install sentencepiece" + ) + raise + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text, sample=False): + """ Take as input a string and return a list of strings (tokens) for words/sub-words + """ + if not sample: + pieces = self.sp_model.EncodeAsPieces(text) + else: + pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) + return pieces + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index < self.sp_model.get_piece_size(): + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = self.sp_model.decode_pieces(tokens) + return out_string + + def save_vocabulary(self, save_directory): + """ Save the sentencepiece vocabulary (copy original file) and special tokens file + to a directory. + """ + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) From b71ef16d0a0ad6579396d42c1e8eaa6396f577bd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Apr 2020 14:21:46 +0200 Subject: [PATCH 038/162] add chunking feed forward --- src/transformers/configuration_reformer.py | 23 ++++---- src/transformers/modeling_reformer.py | 63 ++++++++++++++++++++-- tests/test_modeling_reformer.py | 4 ++ 3 files changed, 76 insertions(+), 14 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 34f96b88e137..dfc0c38f4dd4 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -58,8 +58,8 @@ class ReformerConfig(PretrainedConfig): TODO (PVP) num_chunks_after (:obj:`int`, optional, defaults to ): TODO (PVP) - intermediate_size (:obj:`int`, optional, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + feed_forward_size (:obj:`int`, optional, defaults to 3072): + Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the Transformer encoder. hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. @@ -108,7 +108,9 @@ def __init__( chunk_length=7, num_chunks_before=1, num_chunks_after=0, - intermediate_size=32, + chunk_size_lm_head=1, + chunk_size_feed_forward=1, + feed_forward_size=32, hidden_act="gelu", hidden_dropout_prob=0.3, attention_probs_dropout_prob=0.3, @@ -116,11 +118,10 @@ def __init__( type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, - pad_token_id=0, sinusoidal_pos_embds=True, **kwargs ): - super().__init__(pad_token_id=pad_token_id, **kwargs) + super().__init__(**kwargs) # TODO: not added yet # predict_mem_len=2048 @@ -129,11 +130,11 @@ def __init__( # TO CHANGE LATER: self.seed = 0 - self.num_attention_chunks = 1 - self.num_chunks = 0 - self.ff_chunk_size = 0 + self.num_attention_chunks = 1 # this is not used in LSHSelfAttention + self.num_chunks = 1 + self.ff_chunk_size = chunk_size_feed_forward self.d_model = hidden_size - self.d_ff = intermediate_size + self.d_ff = feed_forward_size self.is_decoder = False # self.ff_activation = # GELU @@ -147,7 +148,7 @@ def __init__( self.num_chunks_after = num_chunks_after self.num_chunks_before = num_chunks_before self.hidden_act = hidden_act - self.intermediate_size = intermediate_size + self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings @@ -155,3 +156,5 @@ def __init__( self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 8568bca722c2..834b1a57aeab 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -19,6 +19,8 @@ import torch import numpy as np +import inspect + from torch import nn from torch.nn import CrossEntropyLoss @@ -26,6 +28,8 @@ from .configuration_reformer import ReformerConfig from .modeling_utils import PreTrainedModel +from typing import Callable + # DELETE later import numpy @@ -58,6 +62,45 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out.requires_grad = False +def apply_chunking_to_forward(chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. + It then applies a layer `forward_fn` to each chunk independently to save memory. + If the `forward_fn` is independent across the `chunk_dim` this function will yield the + same result as not applying it. + + Args: + chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size` + chunk_dim: int - the dimension over which the input_tensors should be chunked + forward_fn: fn - the forward fn of the model + input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked + Returns: + a Tensor with the same shape the foward_fn would have given if applied + """ + + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) + tensor_shape = input_tensors[0].shape + assert all(input_tensor.shape == tensor_shape for input_tensor in input_tensors), "All input tenors have to be of the same shape" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + assert num_args_in_forward_chunk_fn == len(input_tensors), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(num_args_in_forward_chunk_fn, len(input_tensors)) + + if chunk_size > 0: + assert input_tensors[0].shape[chunk_dim] % chunk_size == 0, "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(input_tensors[0][chunk_dim], chunk_size) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) + + class ReformerEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ @@ -428,7 +471,7 @@ def forward( class ReformerFeedForwardDense(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) self.dropout = config.hidden_dropout_prob if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] @@ -445,7 +488,7 @@ def forward(self, hidden_states): class ReformerFeedForwardOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) self.dropout = config.hidden_dropout_prob def forward(self, hidden_states, input_tensor): @@ -455,14 +498,20 @@ def forward(self, hidden_states, input_tensor): return output -class ReformerFeedForward(nn.Module): +class ChunkReformerFeedForward(nn.Module): def __init__(self, config): super().__init__() + self.seq_len_dim = 1 + self.chunk_size_feed_forward = config.chunk_size_feed_forward self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) + # TODO(PVP): Does this work with backpropagation? def forward(self, attention_output, prev_attention_output): + return apply_chunking_to_forward(self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, prev_attention_output) + + def forward_chunk(self, attention_output, prev_attention_output): norm_attention_output = self.layer_norm(attention_output) dense_output = self.dense(norm_attention_output) output = self.output(dense_output, prev_attention_output) @@ -473,7 +522,7 @@ class ReformerLayer(nn.Module): def __init__(self, config): super().__init__() self.attention = ReformerAttention(config) - self.feed_forward = ReformerFeedForward(config) + self.feed_forward = ChunkReformerFeedForward(config) def forward( self, @@ -686,6 +735,8 @@ def __init__(self, config): super().__init__() # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size + self.seq_len_dim = 1 + self.chunk_size_lm_head = config.chunk_size_lm_head self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) @@ -693,6 +744,10 @@ def __init__(self, config): self.decoder.bias = self.bias def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + + # TODO(PVP): Does this work with backpropagation? + def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 204c840832a4..664ab82274c9 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -304,6 +304,10 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): # intermediate weighs intermediate_weights = weights[2][0][2][2] + # Chunked Feed Forward + if len(intermediate_weights) == 4: + intermediate_weights = intermediate_weights[2] + # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) From 99427c65229f05e98f07098bbb7a75e3907a5f00 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 09:53:34 +0200 Subject: [PATCH 039/162] clean config --- src/transformers/__init__.py | 1 + src/transformers/configuration_reformer.py | 22 +++++++++------------- src/transformers/tokenization_reformer.py | 13 +++++-------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 556cb5c05575..92f4457bd714 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -139,6 +139,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast +from .tokenization_reformer import ReformerTokenizer from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index dfc0c38f4dd4..1935ee7e420f 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -70,8 +70,6 @@ class ReformerConfig(PretrainedConfig): max_position_embeddings (:obj:`int`, optional, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - type_vocab_size (:obj:`int`, optional, defaults to 2): - The vocabulary size of the `token_type_ids` passed into :class:`~transformers.ReformerModel`. initializer_range (:obj:`float`, optional, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): @@ -99,23 +97,22 @@ class ReformerConfig(PretrainedConfig): def __init__( self, - vocab_size=10, - hidden_size=16, - num_hidden_layers=3, + vocab_size=200, + hidden_size=128, + num_hidden_layers=2, num_attention_heads=2, num_buckets=2, - num_hashes=2, + num_hashes=4, chunk_length=7, num_chunks_before=1, num_chunks_after=0, - chunk_size_lm_head=1, - chunk_size_feed_forward=1, - feed_forward_size=32, + chunk_size_lm_head=0, + chunk_size_feed_forward=0, + feed_forward_size=128, hidden_act="gelu", hidden_dropout_prob=0.3, attention_probs_dropout_prob=0.3, - max_position_embeddings=20, - type_vocab_size=2, + max_position_embeddings=512, initializer_range=0.02, layer_norm_eps=1e-12, sinusoidal_pos_embds=True, @@ -131,7 +128,7 @@ def __init__( # TO CHANGE LATER: self.seed = 0 self.num_attention_chunks = 1 # this is not used in LSHSelfAttention - self.num_chunks = 1 + self.num_chunks = 0 self.ff_chunk_size = chunk_size_feed_forward self.d_model = hidden_size self.d_ff = feed_forward_size @@ -152,7 +149,6 @@ def __init__( self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.sinusoidal_pos_embds = sinusoidal_pos_embds diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index 22d368e47514..c47467ab1adc 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -26,6 +26,11 @@ SPIECE_UNDERLINE = "▁" + +# IMPORTANT: This is just a copy-paste from T5 and not compared to the original Reformer +# Tokenizer yet!!!! + + #################################################### # Mapping from the keyword arguments names of Tokenizer `__init__` # to file names for serializing Tokenizer instances @@ -45,16 +50,8 @@ # Mapping from model shortcut names to max length of inputs #################################################### PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "t5-small": 512, - "t5-base": 512, - "t5-large": 512, - "t5-3b": 512, - "t5-11b": 512, } -# IMPORTANT: This is just a copy-paste from T5 and not compared to the original Reformer -# Tokenizer yet!!!! - class ReformerTokenizer(PreTrainedTokenizer): """ From 4ffa9255076f89527c67a9dc7e47b900a31e9cb6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 12:14:32 +0200 Subject: [PATCH 040/162] add improved assert statements --- src/transformers/modeling_reformer.py | 29 +++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 834b1a57aeab..f075192f14ae 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -109,6 +109,8 @@ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.max_position_embeddings = config.max_position_embeddings + # TODO(PVP): check sinusoidal position embeddings -> not 100% the same as done in trax if config.sinusoidal_pos_embds: create_sinusoidal_embeddings( @@ -118,6 +120,7 @@ def __init__(self, config): self.dropout = config.hidden_dropout_prob def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: input_shape = input_ids.size() else: @@ -132,6 +135,8 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) + assert position_ids.shape[-1] <= self.max_position_embeddings, "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format(position_ids.shape[-1], self.max_position_embeddings) + position_embeddings = self.position_embeddings(position_ids) # TODO(PVP): In https://github.com/google/trax/blob/d947ebc184b94bc48d7e48cc63156a047f00992e/trax/models/reformer/reformer.py#L775 there are two dropouts right after each other as @@ -157,6 +162,7 @@ def __init__(self, config): self.num_chunks_after = config.num_chunks_after self.output_attentions = config.output_attentions self.is_decoder = config.is_decoder + self.max_position_embeddings = config.max_position_embeddings self.attention_head_size = int(self.hidden_size / self.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -191,6 +197,10 @@ def forward( assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size + if self.num_buckets is None: + # set `num_buckets` on the fly + self._set_num_buckets_on_the_fly(sequence_length) + # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors) assert int(buckets.shape[-1]) == self.num_hashes * sequence_length @@ -200,8 +210,6 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) value_vectors = self._gather_by_expansion(value_vectors, ticker) - # q_info = ticker - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) ticker = self._split_dim_by(ticker, -1, self.chunk_length) @@ -225,6 +233,7 @@ def forward( out_vectors = self._transpose_for_output(out_vectors) outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + return outputs def _hash_vectors(self, vectors): @@ -304,6 +313,18 @@ def _get_ticker_and_undo_ticker(self, sequence_length, buckets): sorted_ticker = (sorted_ticker % sequence_length) return sorted_ticker, undo_sorted_ticker + def _set_num_buckets_on_the_fly(self, sequence_length): + # recommended `num_buckets` from paper + num_buckets = 2 * sequence_length // self.chunk_length + + # factorize `num_buckets` if `num_buckets` becomes too large + num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chuck_length) + if num_buckets > 2 * num_buckets_limit: + num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] + + logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets)) + self.num_buckets = num_buckets + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): key_vectors = self._look_adjacent(key_vectors) value_vectors = self._look_adjacent(value_vectors) @@ -636,6 +657,7 @@ class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config + self.chunk_length = config.chunk_length self.embeddings = ReformerEmbeddings(config) self.encoder = ReformerEncoder(config) @@ -696,6 +718,8 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 + assert input_shape[-1] % self.chunk_length == 0, "Sequence Length {} has to be a multiple of config.chunk_length {}. Please consider padding the input to a length of {}.".format(input_shape[-1], self.chunk_length, input_shape[-1] + (self.chunk_length - input_shape[-1] % self.chunk_length)) + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -777,6 +801,7 @@ def forward( inputs_embeds=None, lm_labels=None, ): + reformer_outputs = self.reformer( input_ids, position_ids=position_ids, From b116e3c22bbce25cbcdab7ad88a3e7431003d1b0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 12:38:18 +0200 Subject: [PATCH 041/162] make tokenizer work --- src/transformers/tokenization_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index c47467ab1adc..c0df33f1a5ab 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -90,7 +90,7 @@ def __init__( eos_token="", unk_token="", pad_token="", - additional_special_tokens=None, + additional_special_tokens=[], **kwargs ): super().__init__( From 79a0babda18332d770c972be66c763a566803bfc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 13:45:55 +0200 Subject: [PATCH 042/162] improve test --- tests/test_modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 664ab82274c9..2a5a2f9638e3 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -361,7 +361,7 @@ def test_lsh_layer(self): # Remove residual connection in ReformerSelfOutput to test this layer only # Remove layer norm in ReformerAttention to test this layer only config = ReformerConfig() - shape = (2, 7, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) From aceb586e7403e62ded5fc09915bb6c2a01d20a6b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 13:52:16 +0200 Subject: [PATCH 043/162] correct typo --- src/transformers/modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index f075192f14ae..699e63dd1eed 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -318,7 +318,7 @@ def _set_num_buckets_on_the_fly(self, sequence_length): num_buckets = 2 * sequence_length // self.chunk_length # factorize `num_buckets` if `num_buckets` becomes too large - num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chuck_length) + num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length) if num_buckets > 2 * num_buckets_limit: num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] From a4814bd9280331265af07f505ae4a44f41d89f8c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 19:16:42 +0200 Subject: [PATCH 044/162] extend config --- src/transformers/configuration_reformer.py | 6 ++++++ src/transformers/modeling_reformer.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 1935ee7e420f..c4ec0f2550f5 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -116,6 +116,9 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, sinusoidal_pos_embds=True, + axial_pos_embds=False, + axial_pos_shape=(7, 2), + axial_pos_embds_dim=(64, 64), **kwargs ): super().__init__(**kwargs) @@ -152,5 +155,8 @@ def __init__( self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = axial_pos_shape + self.axial_pos_embds_dim = axial_pos_embds_dim self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 699e63dd1eed..c8692b618ab1 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -629,7 +629,8 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.weight.requires_grad: + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) From 5eeeb2546d93f688cb45c2b19d4ec3a5bad1e9f5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 20:37:59 +0200 Subject: [PATCH 045/162] add complexer test --- tests/test_modeling_reformer.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 2a5a2f9638e3..c0e860fd1eb7 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -162,7 +162,7 @@ def get_block( n_buckets=config.num_buckets, use_reference_code=use_reference_code, hash_seed=config.seed, - path_to_save_weights=path_to_save_weights + path_to_save_weights=path_to_save_weights, ) block = tl.Serial(tl.ReversibleSerial([list_of_layers])) @@ -200,8 +200,6 @@ def get_model( use_reference_code=True, share_qk=True, ff_use_sru=0, - axial_pos_shape=(), - d_axial_pos_embs=None, mode="eval", path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): @@ -222,8 +220,8 @@ def get_model( n_attention_chunks=config.num_attention_chunks, attention_type=tl.LSHSelfAttention, share_qk=share_qk, - axial_pos_shape=axial_pos_shape, - d_axial_pos_embs=d_axial_pos_embs, + axial_pos_shape=config.axial_pos_shape, + d_axial_pos_embs=config.axial_pos_embds_dim, ff_activation=tl.Gelu, ff_use_sru=ff_use_sru, ff_chunk_size=config.ff_chunk_size, @@ -332,10 +330,6 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): word_embeddings = np.asarray(weights[1]) self._set_param(torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings)) - # pos embeds - pos_embeddings = np.asarray(weights[3]) - self._set_param(torch_model_reformer.embeddings.position_embeddings, torch.tensor(pos_embeddings[0])) - trax_layer_weights = weights[5] assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len(trax_layer_weights), "HF and trax model do not have the same number of layers" for layer_idx, layer in enumerate(torch_model_reformer.encoder.layer): @@ -404,7 +398,7 @@ def test_lsh_block(self): def test_reformer_lm_model(self): config = ReformerConfig() - shape = (2, 7) # Batch x SeqLen x ModelDimPerHead + shape = (2, 14) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) From 0ee5db4eebdc77fcd740e0bfc487259e2dde33b5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 10:19:42 +0200 Subject: [PATCH 046/162] add new axial position embeddings --- src/transformers/configuration_reformer.py | 6 ++-- src/transformers/modeling_reformer.py | 41 ++++++++++++++++++++-- tests/test_modeling_reformer.py | 7 ++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index c4ec0f2550f5..0f6f2b9ae858 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -114,9 +114,10 @@ def __init__( attention_probs_dropout_prob=0.3, max_position_embeddings=512, initializer_range=0.02, + axial_norm_std=1.0, layer_norm_eps=1e-12, - sinusoidal_pos_embds=True, - axial_pos_embds=False, + sinusoidal_pos_embds=False, + axial_pos_embds=True, axial_pos_shape=(7, 2), axial_pos_embds_dim=(64, 64), **kwargs @@ -158,5 +159,6 @@ def __init__( self.axial_pos_embds = axial_pos_embds self.axial_pos_shape = axial_pos_shape self.axial_pos_embds_dim = axial_pos_embds_dim + self.axial_norm_std = axial_norm_std self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index c8692b618ab1..db1f91756506 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -101,6 +101,38 @@ def apply_chunking_to_forward(chunk_size: int, chunk_dim: int, forward_fn: Calla return forward_fn(*input_tensors) +class AxialEmbeddings(nn.Module): + + def __init__(self, config): + super().__init__() + self.axial_pos_shape = config.axial_pos_shape + self.axial_pos_embds_dim = config.axial_pos_embds_dim + self.weights = nn.ParameterList() + + assert sum(self.axial_pos_embds_dim) == config.hidden_size, "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format(self.axial_pos_embds_dim, config.hidden_size) + + # create weights + for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): + # create shape + ax_shape = [1] * len(self.axial_pos_shape) + ax_shape[axis] = self.axial_pos_shape[axis] + ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,) + + # create tenser and init + self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) + + def forward(self, position_ids): + assert np.prod(self.axial_pos_shape) == position_ids.shape[-1], "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, position_ids.shape[-1]) + # broadcast weights to correct shape + broadcasted_weights = [weight.expand((position_ids.shape[0],) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights] + + # TODO (PVP): Add dropout when training + # get correct position encodings + position_encodings = torch.cat([torch.reshape(weight, position_ids.shape + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) + + return position_encodings + + class ReformerEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ @@ -111,11 +143,13 @@ def __init__(self, config): self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.max_position_embeddings = config.max_position_embeddings - # TODO(PVP): check sinusoidal position embeddings -> not 100% the same as done in trax + assert not (config.sinusoidal_pos_embds and config.axial_pos_embds), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" if config.sinusoidal_pos_embds: create_sinusoidal_embeddings( n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.position_embeddings.weight ) + elif config.axial_pos_embds: + self.position_embeddings = AxialEmbeddings(config) self.dropout = config.hidden_dropout_prob @@ -626,7 +660,10 @@ class ReformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, AxialEmbeddings): + for weight in module.weights: + torch.nn.init.normal_(weight, std=self.config.axial_norm_std) + elif isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c0e860fd1eb7..ee113d8f5da0 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -330,6 +330,13 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): word_embeddings = np.asarray(weights[1]) self._set_param(torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings)) + if isinstance(weights[3], tuple): + position_embeddings = torch_model_reformer.embeddings.position_embeddings + for emb_idx in range(len(position_embeddings.weights)): + emb_weights = np.asarray(weights[3][emb_idx][0]) + assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format(position_embeddings[emb_idx]) + position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + trax_layer_weights = weights[5] assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len(trax_layer_weights), "HF and trax model do not have the same number of layers" for layer_idx, layer in enumerate(torch_model_reformer.encoder.layer): From 938aa8b0071ecd63176aa3ada595237455364dd8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 11:55:15 +0200 Subject: [PATCH 047/162] add local block attention layer --- src/transformers/configuration_reformer.py | 2 + src/transformers/modeling_reformer.py | 247 +++++++++++++++------ tests/test_modeling_reformer.py | 74 +++++- 3 files changed, 252 insertions(+), 71 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 0f6f2b9ae858..61d5ecae26eb 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -120,6 +120,7 @@ def __init__( axial_pos_embds=True, axial_pos_shape=(7, 2), axial_pos_embds_dim=(64, 64), + attn_type="lsh", **kwargs ): super().__init__(**kwargs) @@ -162,3 +163,4 @@ def __init__( self.axial_norm_std = axial_norm_std self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward + self.attn_type = attn_type diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index db1f91756506..f31e07a9c2f8 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -16,6 +16,7 @@ """PyTorch REFORMER model. """ import logging +import itertools import torch import numpy as np @@ -182,7 +183,64 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): return embeddings -class LSHSelfAttention(nn.Module): +class EfficientAttentionUtils(object): + """ + A few utilities for nn.Modules in Reformer, to be used as a mixin. + """ + + def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): + """ Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). + """ + if num_chunks_before == 0 and num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-num_chunks_before, num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _transpose_for_scores(self, x, num_attn_heads, attn_head_size): + new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) + x = x.view(*new_x_shape) + return x.transpose(2, 1) + + def _transpose_for_output(self, x, num_attn_heads, attn_head_size): + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) + + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size): + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (attn_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + + def _merge_by_middle_dim(self, vectors, num_attn_heads): + batch_size = vectors.shape[0] + new_dim_shape = (batch_size, num_attn_heads, -1) + + if len(vectors.shape) == 5: + return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) + elif len(vectors.shape) == 4: + return torch.reshape(vectors, new_dim_shape) + else: + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + + +class LSHSelfAttention(nn.Module, EfficientAttentionUtils): def __init__(self, config): super().__init__() @@ -217,7 +275,6 @@ def forward( hidden_states, head_mask=None, ): - # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] @@ -225,8 +282,8 @@ def forward( mixed_query_key_vectors = self.query_key(hidden_states) mixed_value_vectors = self.value(hidden_states) - query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors) - value_vectors = self._transpose_for_scores(mixed_value_vectors) + query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._transpose_for_scores(mixed_value_vectors, self.num_attention_heads, self.attention_head_size) assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size @@ -244,11 +301,10 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) value_vectors = self._gather_by_expansion(value_vectors, ticker) - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length) - ticker = self._split_dim_by(ticker, -1, self.chunk_length) + query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) - # Optionally include adjacent chunks. if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 @@ -257,15 +313,15 @@ def forward( out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) if self.num_hashes > 1: - out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length) - logits = self._split_dim_by(logits, self.num_hashes, sequence_length).unsqueeze(-1) + out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size) + logits = self._split_dim_by(logits, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size).unsqueeze(-1) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) - out_vectors = self._transpose_for_output(out_vectors) + out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) return outputs @@ -327,7 +383,7 @@ def _hash_vectors(self, vectors): # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets) + offset_buckets = self._merge_by_middle_dim(buckets + offsets, self.num_attention_heads) return offset_buckets @@ -360,15 +416,15 @@ def _set_num_buckets_on_the_fly(self, sequence_length): self.num_buckets = num_buckets def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): - key_vectors = self._look_adjacent(key_vectors) - value_vectors = self._look_adjacent(value_vectors) + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # TODO(PVP): better naming query_info = ticker - key_value_info = self._look_adjacent(ticker) + key_value_info = self._look_adjacent(ticker, self.num_chunks_before, self.num_chunks_after) # Causal mask if self.is_decoder: @@ -408,8 +464,8 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker out_vectors = torch.matmul(dots, value_vectors) # merge chunk length - logits = self._merge_by_middle_dim(logits).squeeze(-1) - out_vectors = self._merge_by_middle_dim(out_vectors) + logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) @@ -417,26 +473,6 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker return out_vectors, logits, dots - def _look_adjacent(self, vectors): - """ Used to implement attention between consecutive chunks. - - Args: - vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if self.num_chunks_before == 0 and self.num_chunks_after == 0: - return vectors - - slices = [] - for i in range(-self.num_chunks_before, self.num_chunks_after + 1): - if i == 0: - slices.append(vectors) - else: - slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) - return torch.cat(slices, dim=3) - def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) @@ -452,36 +488,107 @@ def _gather_by_expansion(self, vectors, idxs): vectors = vectors.repeat(1, 1, self.num_hashes, 1) return torch.gather(vectors, 2, expanded_idxs) - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.transpose(2, 1) - def _transpose_for_output(self, x): - x = x.permute(0, 2, 1, 3) - return torch.reshape(x, (x.size()[0], -1, self.num_attention_heads * self.attention_head_size)) +class LocalSelfAttention(nn.Module, EfficientAttentionUtils): + def __init__(self, config): + super().__init__() - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2): - batch_size = vectors.shape[0] - split_dim_shape = (batch_size, self.num_attention_heads, dim_factor_1, dim_factor_2) + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.chunk_length = config.chunk_length + self.num_chunks_before = config.num_chunks_before + self.num_chunks_after = config.num_chunks_after + self.output_attentions = config.output_attentions + self.is_decoder = config.is_decoder - if len(vectors.shape) == 4: - return torch.reshape(vectors, split_dim_shape + (self.attention_head_size,)) - elif len(vectors.shape) == 3: - return torch.reshape(vectors, split_dim_shape) - else: - raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size - def _merge_by_middle_dim(self, vectors): - batch_size = vectors.shape[0] - new_dim_shape = (batch_size, self.num_attention_heads, -1) + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) - if len(vectors.shape) == 5: - return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) - elif len(vectors.shape) == 4: - return torch.reshape(vectors, new_dim_shape) - else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + self.dropout = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states, + head_mask=None, + ): + # get SeqLen and BatchSize + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + mixed_query_vectors = self.query(hidden_states) + mixed_key_vectors = self.key(hidden_states) + mixed_value_vectors = self.value(hidden_states) + + query_vectors = self._transpose_for_scores(mixed_query_vectors, self.num_attention_heads, self.attention_head_size) + key_vectors = self._transpose_for_scores(mixed_key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._transpose_for_scores(mixed_value_vectors, self.num_attention_heads, self.attention_head_size) + + assert query_vectors.shape[-1] == self.attention_head_size + assert key_vectors.shape[-1] == self.attention_head_size + assert value_vectors.shape[-1] == self.attention_head_size + + if self.chunk_length is None: + assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + + key_vectors = key_vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32)) + + # chunk vectors + query_vectors = self._split_dim_by(query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + key_vectors = self._split_dim_by(key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + + # chunk indices + indices = torch.arange(sequence_length, device=query_vectors.device).repeat(batch_size, self.num_attention_heads, 1) + query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + + # append chunks before and after + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + key_value_indices = self._look_adjacent(key_value_indices, self.num_chunks_before, self.num_chunks_after) + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # Causal mask + if self.is_decoder: + # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called + mask = torch.lt(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).long().to(query_indices.device) + query_key_dots = query_key_dots - mask * 1e9 + + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + attention_probs = torch.exp(query_key_dots - logits) + + # dropout + attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # attend values + out_vectors = torch.matmul(attention_probs, value_vectors) + + # merge chunk length + logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) + out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) + + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + + out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) + outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + + return outputs class ReformerSelfOutput(nn.Module): @@ -499,10 +606,24 @@ def forward(self, hidden_states, input_tensor): class ReformerAttention(nn.Module): + + layer_id_iter = itertools.count() + def __init__(self, config): super().__init__() + self.layer_id = next(self.layer_id_iter) self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attention = LSHSelfAttention(config) + + if config.attn_type == "lsh": + self.self_attention = LSHSelfAttention(config) + elif config.attn_type == "local": + self.self_attention = LocalSelfAttention(config) + elif config.attn_type == "mixed": + if self.layer_id % 2 == 0: + self.self_attention = LocalSelfAttention(config) + else: + self.self_attention = LSHSelfAttention(config) + self.output = ReformerSelfOutput(config) def forward( diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index ee113d8f5da0..b6a9738736fa 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -22,6 +22,7 @@ import jax from trax.layers.research.efficient_attention_v2 import ( LSHSelfAttention as TraxLSHSelfAttention, + SelfAttention as TraxSelfAttention ) from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock from trax.models.reformer.reformer import ReformerLM as TraxReformer @@ -72,7 +73,33 @@ def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): input_signature = trax_ShapeDtype(shape, dtype) return input_signature - def get_layer( + def get_local_layer( + self, + config, + use_reference_code=True, + mode="eval", + path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", + **kwargs + ): + + with trax_math.use_backend("jax"): + hidden_size_per_head = config.hidden_size // config.num_attention_heads + layer = TraxSelfAttention( + n_heads=config.num_attention_heads, + d_qk=hidden_size_per_head, + d_v=hidden_size_per_head, + chunk_len=config.chunk_length, + n_chunks_before=config.num_chunks_before, + n_chunks_after=config.num_chunks_after, + attention_dropout=config.attention_probs_dropout_prob, + output_dropout=config.hidden_dropout_prob, + causal=config.is_decoder, + use_reference_code=use_reference_code, + mode=mode, + ) + return layer + + def get_lsh_layer( self, config, use_reference_code=True, @@ -100,7 +127,6 @@ def get_layer( mode=mode, path_to_save_weights=path_to_save_weights ) - return layer def forward_layer( @@ -277,7 +303,7 @@ def _set_param(self, torch_layer, weight, bias=None): assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) torch_layer.bias = torch.nn.Parameter(bias) - def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): + def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison np_query_key = np.asarray(weights[0]) np_value = np.asarray(weights[1]) @@ -287,6 +313,19 @@ def _set_layer_weights_in_torch(self, weights, torch_layer, hidden_size): self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + + np_query = np.asarray(weights[0]) + np_key = np.asarray(weights[1]) + np_value = np.asarray(weights[2]) + np_dense = np.asarray(weights[3]) + + self._set_param(torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) + self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): # layernorm 1 @@ -297,7 +336,7 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): # lsh weights + output lsh_weights = weights[0][1] - self._set_layer_weights_in_torch(lsh_weights, torch_block.attention, hidden_size) + self._set_layer_weights_in_torch_lsh(lsh_weights, torch_block.attention, hidden_size) # intermediate weighs intermediate_weights = weights[2][0][2][2] @@ -359,19 +398,38 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): pass def test_lsh_layer(self): - # Remove residual connection in ReformerSelfOutput to test this layer only - # Remove layer norm in ReformerAttention to test this layer only config = ReformerConfig() shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) - trax_layer = trax_utils.get_layer(config) + trax_layer = trax_utils.get_lsh_layer(config) + trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) + + hf_input = torch.tensor(np_input, dtype=torch.float) + hf_layer = ReformerAttention(config) + self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) + hf_layer.eval() + + hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) + + trax_torch_output = torch.tensor(np.asarray(trax_output)) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) + + def test_local_layer(self): + config = ReformerConfig() + shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize + np_input = np.random.rand(*shape) + + trax_utils = TraxUtils(shape) + trax_layer = trax_utils.get_local_layer(config) trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) hf_input = torch.tensor(np_input, dtype=torch.float) + config.attn_type = "local" hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size) + self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] From 4d7c23bf9188edb46d7e9f37add4683135128f34 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 12:05:47 +0200 Subject: [PATCH 048/162] clean tests --- src/transformers/configuration_reformer.py | 12 +----------- tests/test_modeling_reformer.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 61d5ecae26eb..10e5a3c351f0 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -125,21 +125,11 @@ def __init__( ): super().__init__(**kwargs) -# TODO: not added yet -# predict_mem_len=2048 -# predict_drop_len=256 -# n_parallel_heads=1 # TO CHANGE LATER: self.seed = 0 - self.num_attention_chunks = 1 # this is not used in LSHSelfAttention - self.num_chunks = 0 - self.ff_chunk_size = chunk_size_feed_forward - self.d_model = hidden_size - self.d_ff = feed_forward_size - self.is_decoder = False -# self.ff_activation = # GELU + self.is_decoder = True self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index b6a9738736fa..9f4eda7df214 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -167,18 +167,18 @@ def get_block( with jax.disable_jit(): hidden_size_per_head = config.hidden_size // config.num_attention_heads list_of_layers = TraxLSHAttentionBlock( - d_model=config.d_model, - d_ff=config.d_ff, + d_model=config.hidden_size, + d_ff=config.feed_forward_size, d_attention_key=hidden_size_per_head, d_attention_value=hidden_size_per_head, n_heads=config.num_attention_heads, - n_attention_chunks=config.num_attention_chunks, + n_attention_chunks=None, attention_type=tl.LSHSelfAttention, dropout=config.hidden_dropout_prob, share_qk=share_qk, ff_activation=tl.Gelu, ff_use_sru=ff_use_sru, - ff_chunk_size=config.ff_chunk_size, + ff_chunk_size=config.chunk_size_feed_forward, mode=mode, causal=config.is_decoder, chunk_len=config.chunk_length, @@ -227,6 +227,7 @@ def get_model( share_qk=True, ff_use_sru=0, mode="eval", + num_chunks=0, path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): with trax_math.use_backend("jax"): @@ -234,23 +235,23 @@ def get_model( hidden_size_per_head = config.hidden_size // config.num_attention_heads model = TraxReformer( vocab_size=config.vocab_size, - d_model=config.d_model, - d_ff=config.d_ff, + d_model=config.hidden_size, + d_ff=config.feed_forward_size, d_attention_key=hidden_size_per_head, d_attention_value=hidden_size_per_head, n_layers=config.num_hidden_layers, n_heads=config.num_attention_heads, dropout=config.hidden_dropout_prob, max_len=config.max_position_embeddings, - n_chunks=config.num_chunks, - n_attention_chunks=config.num_attention_chunks, + n_chunks=num_chunks, + n_attention_chunks=None, attention_type=tl.LSHSelfAttention, share_qk=share_qk, axial_pos_shape=config.axial_pos_shape, d_axial_pos_embs=config.axial_pos_embds_dim, ff_activation=tl.Gelu, ff_use_sru=ff_use_sru, - ff_chunk_size=config.ff_chunk_size, + ff_chunk_size=config.chunk_size_feed_forward, mode=mode, causal=config.is_decoder, chunk_len=config.chunk_length, From 50276de3eeb377fe74b3b848473af2784f364bf6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 12:46:07 +0200 Subject: [PATCH 049/162] refactor --- src/transformers/configuration_reformer.py | 2 ++ src/transformers/modeling_reformer.py | 25 +++++++--------------- tests/test_modeling_reformer.py | 24 +++++++++------------ 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 10e5a3c351f0..ac682d9e073a 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -98,6 +98,7 @@ class ReformerConfig(PretrainedConfig): def __init__( self, vocab_size=200, + attention_head_size=32, hidden_size=128, num_hidden_layers=2, num_attention_heads=2, @@ -131,6 +132,7 @@ def __init__( self.is_decoder = True self.vocab_size = vocab_size + self.attention_head_size = attention_head_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index f31e07a9c2f8..f2afec73d54d 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -245,7 +245,6 @@ def __init__(self, config): super().__init__() self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size self.hash_seed = config.seed self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets @@ -256,14 +255,9 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) + self.hidden_size = config.hidden_size self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -494,21 +488,15 @@ def __init__(self, config): super().__init__() self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size self.chunk_length = config.chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after self.output_attentions = config.output_attentions self.is_decoder = config.is_decoder - self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size - - if self.hidden_size % self.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) + self.hidden_size = config.hidden_size self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -594,7 +582,8 @@ def forward( class ReformerSelfOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + all_head_size = config.num_attention_heads * config.attention_head_size + self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) self.dropout = config.hidden_dropout_prob def forward(self, hidden_states, input_tensor): @@ -623,6 +612,8 @@ def __init__(self, config): self.self_attention = LocalSelfAttention(config) else: self.self_attention = LSHSelfAttention(config) + else: + raise NotImplementedError("config.attn_type: {} does not exist. Select one of ['lsh', 'local', 'mixed'].".format(config.attn_type)) self.output = ReformerSelfOutput(config) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 9f4eda7df214..7bc096e420db 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -83,11 +83,10 @@ def get_local_layer( ): with trax_math.use_backend("jax"): - hidden_size_per_head = config.hidden_size // config.num_attention_heads layer = TraxSelfAttention( n_heads=config.num_attention_heads, - d_qk=hidden_size_per_head, - d_v=hidden_size_per_head, + d_qk=config.attention_head_size, + d_v=config.attention_head_size, chunk_len=config.chunk_length, n_chunks_before=config.num_chunks_before, n_chunks_after=config.num_chunks_after, @@ -109,11 +108,10 @@ def get_lsh_layer( ): with trax_math.use_backend("jax"): - hidden_size_per_head = config.hidden_size // config.num_attention_heads layer = TraxLSHSelfAttention( n_heads=config.num_attention_heads, - d_qk=hidden_size_per_head, - d_v=hidden_size_per_head, + d_qk=config.attention_head_size, + d_v=config.attention_head_size, chunk_len=config.chunk_length, n_chunks_before=config.num_chunks_before, n_chunks_after=config.num_chunks_after, @@ -165,12 +163,11 @@ def get_block( with trax_math.use_backend("jax"): with jax.disable_jit(): - hidden_size_per_head = config.hidden_size // config.num_attention_heads list_of_layers = TraxLSHAttentionBlock( d_model=config.hidden_size, d_ff=config.feed_forward_size, - d_attention_key=hidden_size_per_head, - d_attention_value=hidden_size_per_head, + d_attention_key=config.attention_head_size, + d_attention_value=config.attention_head_size, n_heads=config.num_attention_heads, n_attention_chunks=None, attention_type=tl.LSHSelfAttention, @@ -232,13 +229,12 @@ def get_model( ): with trax_math.use_backend("jax"): with jax.disable_jit(): - hidden_size_per_head = config.hidden_size // config.num_attention_heads model = TraxReformer( vocab_size=config.vocab_size, d_model=config.hidden_size, d_ff=config.feed_forward_size, - d_attention_key=hidden_size_per_head, - d_attention_value=hidden_size_per_head, + d_attention_key=config.attention_head_size, + d_attention_value=config.attention_head_size, n_layers=config.num_hidden_layers, n_heads=config.num_attention_heads, dropout=config.hidden_dropout_prob, @@ -298,10 +294,10 @@ class ReformerIntegrationTests(unittest.TestCase): def _set_param(self, torch_layer, weight, bias=None): with torch.no_grad(): - assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch.layer) + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: - assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch.layer) + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) torch_layer.bias = torch.nn.Parameter(bias) def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): From 37a2b002a512c8c7cb1f61cd662154cef7e6a6be Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 19:16:59 +0200 Subject: [PATCH 050/162] better testing --- src/transformers/configuration_reformer.py | 8 +- tests/test_modeling_reformer.py | 318 +++++++++++++++++---- 2 files changed, 271 insertions(+), 55 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ac682d9e073a..94fe0a622871 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -119,8 +119,8 @@ def __init__( layer_norm_eps=1e-12, sinusoidal_pos_embds=False, axial_pos_embds=True, - axial_pos_shape=(7, 2), - axial_pos_embds_dim=(64, 64), + axial_pos_shape=[7, 2], + axial_pos_embds_dim=[64, 64], attn_type="lsh", **kwargs ): @@ -150,8 +150,8 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.sinusoidal_pos_embds = sinusoidal_pos_embds self.axial_pos_embds = axial_pos_embds - self.axial_pos_shape = axial_pos_shape - self.axial_pos_embds_dim = axial_pos_embds_dim + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) self.axial_norm_std = axial_norm_std self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 7bc096e420db..da8ea1406eba 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -17,18 +17,27 @@ import numpy as np # trax imports - to be deleted later +import trax from trax import math as trax_math from trax.shapes import ShapeDtype as trax_ShapeDtype + +import gin + import jax from trax.layers.research.efficient_attention_v2 import ( LSHSelfAttention as TraxLSHSelfAttention, - SelfAttention as TraxSelfAttention + SelfAttention as TraxSelfAttention, ) from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock from trax.models.reformer.reformer import ReformerLM as TraxReformer from trax import layers as tl -from transformers import ReformerAttention, ReformerLayer, ReformerConfig, ReformerModelWithLMHead +from transformers import ( + ReformerAttention, + ReformerLayer, + ReformerConfig, + ReformerModelWithLMHead, +) from transformers import is_torch_available # noqa: F401 @@ -39,7 +48,9 @@ import torch # noqa: F401 # from transformers.modeling_reformer import () -PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" +PATH_TO_SAVE_WEIGHTS = ( + "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" +) class TraxUtils(object): @@ -123,16 +134,12 @@ def get_lsh_layer( causal=config.is_decoder, use_reference_code=use_reference_code, mode=mode, - path_to_save_weights=path_to_save_weights + path_to_save_weights=path_to_save_weights, ) return layer def forward_layer( - self, - np_input_data, - layer, - input_signature=None, - random_number_generator=None, + self, np_input_data, layer, input_signature=None, random_number_generator=None, ): with trax_math.use_backend("jax"): input_data = self.convert_to_jax_array(np_input_data) @@ -192,11 +199,7 @@ def get_block( return block def forward_block( - self, - np_input_data, - block, - input_signature=None, - random_number_generator=None, + self, np_input_data, block, input_signature=None, random_number_generator=None, ): with trax_math.use_backend("jax"): input_data = self.convert_to_jax_array(np_input_data) @@ -221,12 +224,20 @@ def get_model( self, config, use_reference_code=True, - share_qk=True, + share_qk=False, ff_use_sru=0, mode="eval", num_chunks=0, + n_buckets=None, + chunk_size_feed_forward=None, path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): + n_buckets = n_buckets if n_buckets is not None else config.num_buckets + chunk_size_feed_forward = ( + chunk_size_feed_forward + if chunk_size_feed_forward + else config.chunk_size_feed_forward + ) with trax_math.use_backend("jax"): with jax.disable_jit(): model = TraxReformer( @@ -247,17 +258,17 @@ def get_model( d_axial_pos_embs=config.axial_pos_embds_dim, ff_activation=tl.Gelu, ff_use_sru=ff_use_sru, - ff_chunk_size=config.chunk_size_feed_forward, + ff_chunk_size=chunk_size_feed_forward, mode=mode, causal=config.is_decoder, chunk_len=config.chunk_length, n_chunks_before=config.num_chunks_before, n_chunks_after=config.num_chunks_after, n_hashes=config.num_hashes, - n_buckets=config.num_buckets, + n_buckets=n_buckets, use_reference_code=use_reference_code, hash_seed=config.seed, - path_to_save_weights=path_to_save_weights + path_to_save_weights=path_to_save_weights, ) return model @@ -268,6 +279,9 @@ def forward_model( model, input_signature=None, random_number_generator=None, + weights=None, + state=None, + only_init=False, ): with trax_math.use_backend("jax"): input_data = self.convert_to_jax_array(np_input_data) @@ -277,7 +291,11 @@ def forward_model( input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) input_signature = (input_signature, input_signature) - weights, state = model.init(input_signature) + if weights is None and state is None: + weights, state = model.init(input_signature) + + if only_init is True: + return if random_number_generator is None: random_number_generator = model.new_rngs(1)[0] @@ -291,13 +309,16 @@ def forward_model( @require_torch class ReformerIntegrationTests(unittest.TestCase): - def _set_param(self, torch_layer, weight, bias=None): with torch.no_grad(): - assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) + assert ( + torch_layer.weight.shape == weight.shape + ), "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: - assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) + assert ( + torch_layer.bias.shape == bias.shape + ), "{} layer.bias does not match".format(torch_layer) torch_layer.bias = torch.nn.Parameter(bias) def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): @@ -306,34 +327,62 @@ def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) - self._set_param(torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + self._set_param( + torch_layer.self_attention.query_key, + torch.tensor(np_query_key) + .transpose(1, 2) + .contiguous() + .view(-1, hidden_size), + ) + self._set_param( + torch_layer.self_attention.value, + torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + self._set_param( + torch_layer.output.dense, + torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison - np_query = np.asarray(weights[0]) np_key = np.asarray(weights[1]) np_value = np.asarray(weights[2]) np_dense = np.asarray(weights[3]) - self._set_param(torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size)) - self._set_param(torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1)) + self._set_param( + torch_layer.self_attention.query, + torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + self._set_param( + torch_layer.self_attention.key, + torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + self._set_param( + torch_layer.self_attention.value, + torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + self._set_param( + torch_layer.output.dense, + torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): - # layernorm 1 layer_norm_1 = weights[0][0][0] layer_norm_1_weight = np.asarray(layer_norm_1[0]) layer_norm_1_bias = np.asarray(layer_norm_1[1]) - self._set_param(torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias)) + self._set_param( + torch_block.attention.layer_norm, + torch.tensor(layer_norm_1_weight), + torch.tensor(layer_norm_1_bias), + ) # lsh weights + output lsh_weights = weights[0][1] - self._set_layer_weights_in_torch_lsh(lsh_weights, torch_block.attention, hidden_size) + self._set_layer_weights_in_torch_lsh( + lsh_weights, torch_block.attention, hidden_size + ) # intermediate weighs intermediate_weights = weights[2][0][2][2] @@ -345,38 +394,58 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - self._set_param(torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias)) + self._set_param( + torch_block.feed_forward.layer_norm, + torch.tensor(layer_norm_2_weight), + torch.tensor(layer_norm_2_bias), + ) # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - self._set_param(torch_block.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias)) + self._set_param( + torch_block.feed_forward.dense.dense, + torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(inter_dense_bias), + ) # intermediate out out_dense_weight = np.asarray(intermediate_weights[4][0]) out_dense_bias = np.asarray(intermediate_weights[4][1]) - self._set_param(torch_block.feed_forward.output.dense, torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), torch.tensor(out_dense_bias)) + self._set_param( + torch_block.feed_forward.output.dense, + torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(out_dense_bias), + ) def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): - # reformer model torch_model_reformer = torch_model.reformer # word embeds word_embeddings = np.asarray(weights[1]) - self._set_param(torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings)) + self._set_param( + torch_model_reformer.embeddings.word_embeddings, + torch.tensor(word_embeddings), + ) if isinstance(weights[3], tuple): position_embeddings = torch_model_reformer.embeddings.position_embeddings for emb_idx in range(len(position_embeddings.weights)): emb_weights = np.asarray(weights[3][emb_idx][0]) - assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format(position_embeddings[emb_idx]) - position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + assert ( + position_embeddings.weights[emb_idx].shape == emb_weights.shape + ), "{} emb does not match".format(position_embeddings[emb_idx]) + position_embeddings.weights[emb_idx] = torch.nn.Parameter( + torch.tensor(emb_weights) + ) trax_layer_weights = weights[5] - assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len(trax_layer_weights), "HF and trax model do not have the same number of layers" + assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len( + trax_layer_weights + ), "HF and trax model do not have the same number of layers" for layer_idx, layer in enumerate(torch_model_reformer.encoder.layer): - block_weights = trax_layer_weights[4 * layer_idx: 4 * (layer_idx + 1)] + block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] self._set_block_weights_in_torch(block_weights, layer, hidden_size) # output weights @@ -385,12 +454,20 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): # output layer norm layer_norm_out_weight = np.asarray(out_weights[0][0]) layer_norm_out_bias = np.asarray(out_weights[0][1]) - self._set_param(torch_model_reformer.encoder.layer_norm, torch.tensor(layer_norm_out_weight), torch.tensor(layer_norm_out_bias)) + self._set_param( + torch_model_reformer.encoder.layer_norm, + torch.tensor(layer_norm_out_weight), + torch.tensor(layer_norm_out_bias), + ) # output embeddings output_embed_weights = np.asarray(out_weights[2][0]) output_embed_bias = np.asarray(out_weights[2][1]) - self._set_param(torch_model.lm_head.decoder, torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), torch.tensor(output_embed_bias)) + self._set_param( + torch_model.lm_head.decoder, + torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), + torch.tensor(output_embed_bias), + ) pass @@ -401,7 +478,9 @@ def test_lsh_layer(self): trax_utils = TraxUtils(shape) trax_layer = trax_utils.get_lsh_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) + trax_output, trax_weights, trax_state = trax_utils.forward_layer( + np_input, layer=trax_layer + ) hf_input = torch.tensor(np_input, dtype=torch.float) hf_layer = ReformerAttention(config) @@ -421,12 +500,16 @@ def test_local_layer(self): trax_utils = TraxUtils(shape) trax_layer = trax_utils.get_local_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer(np_input, layer=trax_layer) + trax_output, trax_weights, trax_state = trax_utils.forward_layer( + np_input, layer=trax_layer + ) hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_type = "local" hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) + self._set_layer_weights_in_torch_local( + trax_weights, hf_layer, config.hidden_size + ) hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] @@ -443,7 +526,9 @@ def test_lsh_block(self): trax_utils = TraxUtils(shape) trax_block = trax_utils.get_block(config) - trax_output, trax_weights, trax_state = trax_utils.forward_block(np_input, block=trax_block) + trax_output, trax_weights, trax_state = trax_utils.forward_block( + np_input, block=trax_block + ) trax_torch_output_1 = torch.tensor(np.asarray(trax_output[0])) trax_torch_output_2 = torch.tensor(np.asarray(trax_output[1])) @@ -466,10 +551,14 @@ def test_reformer_lm_model(self): trax_utils = TraxUtils(shape) trax_model = trax_utils.get_model(config) - trax_output, trax_weights, trax_state = trax_utils.forward_model(np_input, model=trax_model) + trax_output, trax_weights, trax_state = trax_utils.forward_model( + np_input, model=trax_model + ) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - hf_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) hf_model = ReformerModelWithLMHead(config) self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) hf_model.eval() @@ -477,4 +566,131 @@ def test_reformer_lm_model(self): hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) + self.assertTrue( + torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) + ) + + def test_pretrained_crime_and_punishment_lm_model(self): + hf_model = ReformerModelWithLMHead.from_pretrained( + "patrickvonplaten/reformer-crime-and-punish" + ) + config = hf_model.config + + trax_model_path = ( + "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" + ) + + shape = (3, 1024) + np_input = np.random.randint(0, config.vocab_size, size=shape) + np_zeros = np.zeros((shape[0], 1), dtype=np.int) + + trax_utils = TraxUtils(shape) + trax_model = self.load_model(trax_model_path, trax_utils.get_input_signature(trax_math.numpy.int32)) + + trax_model = trax_utils.get_model( + config, mode="predict", n_buckets=[64, 128], chunk_size_feed_forward=0 + ) + + weights, state = trax_model.init( + trax_utils.get_input_signature(dtype=trax_math.numpy.int32) + ) + + trax_output = trax_model(np_input) + + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + + self.assertTrue( + torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) + ) + + def load_model(self, trax_model_path, input_signature): + gin.parse_config( + """ + import trax.layers + import trax.models + import trax.optimizers + import trax.supervised.inputs + import trax.supervised.trainer_lib + + # Parameters that will vary between experiments: + # ============================================================================== + train.model = @trax.models.ReformerLM + # Our model will have 6 layers, alternating between the LSH attention proposed + # in the Reformer paper and local attention within a certain context window. + n_layers = 6 + attn_type = [ + @SelfAttention, + @LSHSelfAttention, + @SelfAttention, + @LSHSelfAttention, + @SelfAttention, + @LSHSelfAttention, + ] + share_qk = False # LSH attention ignores this flag and always shares q & k + n_heads = 2 + attn_kv = 64 + dropout = 0.05 + n_tokens = 524288 + + # Parameters for MultifactorSchedule: + # ============================================================================== + MultifactorSchedule.constant = 0.01 + MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay' + MultifactorSchedule.warmup_steps = 100 + MultifactorSchedule.steps_per_cycle = 900 + + # Parameters for Adam: + # ============================================================================== + Adam.weight_decay_rate=0.0 + Adam.b1 = 0.86 + Adam.b2 = 0.92 + Adam.eps = 1e-9 + + # Parameters for SelfAttention: + # ============================================================================== + SelfAttention.attention_dropout = 0.05 + SelfAttention.chunk_len = 64 + SelfAttention.n_chunks_before = 1 + SelfAttention.n_parallel_heads = 1 + + # Parameters for LSHSelfAttention: + # ============================================================================== + LSHSelfAttention.attention_dropout = 0.0 + LSHSelfAttention.chunk_len = 64 + LSHSelfAttention.n_buckets = [64, 128] + LSHSelfAttention.n_chunks_after = 0 + LSHSelfAttention.n_chunks_before = 1 + LSHSelfAttention.n_hashes = 1 + LSHSelfAttention.n_parallel_heads = 1 + LSHSelfAttention.predict_drop_len = 128 + LSHSelfAttention.predict_mem_len = 1024 + + # Parameters for ReformerLM: + # ============================================================================== + ReformerLM.attention_type = %attn_type + ReformerLM.d_attention_key = %attn_kv + ReformerLM.d_attention_value = %attn_kv + ReformerLM.d_model = 256 + ReformerLM.d_ff = 512 + ReformerLM.dropout = %dropout + ReformerLM.ff_activation = @trax.layers.Relu + ReformerLM.max_len = %n_tokens + ReformerLM.mode = 'train' + ReformerLM.n_heads = %n_heads + ReformerLM.n_layers = %n_layers + ReformerLM.vocab_size = 320 + ReformerLM.share_qk = %share_qk + ReformerLM.axial_pos_shape = (512, 1024) + ReformerLM.d_axial_pos_embs= (64, 192) + """ + ) + trax_model = trax.models.ReformerLM(mode="predict") + trax_model.init(input_signature) + trax_model.init_from_file(trax_model_path) + return trax_model From 07c0c72b93c27b97c48bc4bed8f102724f5a7e5d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Apr 2020 20:03:38 +0200 Subject: [PATCH 051/162] save intermediate progress --- tests/test_modeling_reformer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index da8ea1406eba..6599ba5cbe77 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -585,18 +585,16 @@ def test_pretrained_crime_and_punishment_lm_model(self): np_zeros = np.zeros((shape[0], 1), dtype=np.int) trax_utils = TraxUtils(shape) - trax_model = self.load_model(trax_model_path, trax_utils.get_input_signature(trax_math.numpy.int32)) - - trax_model = trax_utils.get_model( - config, mode="predict", n_buckets=[64, 128], chunk_size_feed_forward=0 - ) + with trax_math.use_backend("jax"): + trax_input = trax_utils.convert_to_jax_array(np_input) + trax_input = (trax_input,) * 2 - weights, state = trax_model.init( - trax_utils.get_input_signature(dtype=trax_math.numpy.int32) - ) + input_signature = trax_utils.get_input_signature(dtype=np.int32) +# input_signature = (input_signature, input_signature) - trax_output = trax_model(np_input) + trax_model = self.load_model(trax_model_path, input_signature) + trax_output = trax_model(np_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) hf_input = torch.cat( From 060a6915f24a0c4c497a151dc641ea98fe1e84e9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 12:45:11 +0200 Subject: [PATCH 052/162] clean test file --- tests/test_modeling_reformer.py | 687 ++++++++++++++++---------------- 1 file changed, 351 insertions(+), 336 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 6599ba5cbe77..65ddfdcde643 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -24,12 +24,7 @@ import gin import jax -from trax.layers.research.efficient_attention_v2 import ( - LSHSelfAttention as TraxLSHSelfAttention, - SelfAttention as TraxSelfAttention, -) from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock -from trax.models.reformer.reformer import ReformerLM as TraxReformer from trax import layers as tl from transformers import ( @@ -84,60 +79,6 @@ def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): input_signature = trax_ShapeDtype(shape, dtype) return input_signature - def get_local_layer( - self, - config, - use_reference_code=True, - mode="eval", - path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", - **kwargs - ): - - with trax_math.use_backend("jax"): - layer = TraxSelfAttention( - n_heads=config.num_attention_heads, - d_qk=config.attention_head_size, - d_v=config.attention_head_size, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - attention_dropout=config.attention_probs_dropout_prob, - output_dropout=config.hidden_dropout_prob, - causal=config.is_decoder, - use_reference_code=use_reference_code, - mode=mode, - ) - return layer - - def get_lsh_layer( - self, - config, - use_reference_code=True, - mode="eval", - path_to_save_weights="/home/patrick/hugging_face/experiments/reformer/intermediate_weights", - **kwargs - ): - - with trax_math.use_backend("jax"): - layer = TraxLSHSelfAttention( - n_heads=config.num_attention_heads, - d_qk=config.attention_head_size, - d_v=config.attention_head_size, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - attention_dropout=config.attention_probs_dropout_prob, - output_dropout=config.hidden_dropout_prob, - hash_seed=config.seed, - causal=config.is_decoder, - use_reference_code=use_reference_code, - mode=mode, - path_to_save_weights=path_to_save_weights, - ) - return layer - def forward_layer( self, np_input_data, layer, input_signature=None, random_number_generator=None, ): @@ -158,6 +99,39 @@ def forward_layer( return output, weights, state + def forward_model( + self, + np_input_data, + model, + input_signature=None, + random_number_generator=None, + weights=None, + state=None, + only_init=False, + ): + with trax_math.use_backend("jax"): + input_data = self.convert_to_jax_array(np_input_data) + input_data = (input_data,) * 2 + + if input_signature is None: + input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) + input_signature = (input_signature, input_signature) + + if weights is None and state is None: + weights, state = model.init(input_signature) + + if only_init is True: + return + + if random_number_generator is None: + random_number_generator = model.new_rngs(1)[0] + + output = model( + input_data, weights=weights, state=state, rng=random_number_generator + ) + + return output, weights, state + def get_block( self, config, @@ -167,7 +141,7 @@ def get_block( mode="eval", path_to_save_weights=PATH_TO_SAVE_WEIGHTS, ): - + # only works on old master branch with trax_math.use_backend("jax"): with jax.disable_jit(): list_of_layers = TraxLSHAttentionBlock( @@ -220,95 +194,333 @@ def forward_block( return output, weights, state - def get_model( - self, - config, - use_reference_code=True, - share_qk=False, - ff_use_sru=0, - mode="eval", - num_chunks=0, - n_buckets=None, - chunk_size_feed_forward=None, - path_to_save_weights=PATH_TO_SAVE_WEIGHTS, - ): - n_buckets = n_buckets if n_buckets is not None else config.num_buckets - chunk_size_feed_forward = ( - chunk_size_feed_forward - if chunk_size_feed_forward - else config.chunk_size_feed_forward + +@require_torch +class ReformerIntegrationTests(unittest.TestCase): + + def test_lsh_layer(self): + config = ReformerConfig() + shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize + np_input = np.random.rand(*shape) + + trax_utils = TraxUtils(shape) + trax_layer = self.load_lsh_layer(config) + trax_output, trax_weights, trax_state = trax_utils.forward_layer( + np_input, layer=trax_layer ) - with trax_math.use_backend("jax"): - with jax.disable_jit(): - model = TraxReformer( - vocab_size=config.vocab_size, - d_model=config.hidden_size, - d_ff=config.feed_forward_size, - d_attention_key=config.attention_head_size, - d_attention_value=config.attention_head_size, - n_layers=config.num_hidden_layers, - n_heads=config.num_attention_heads, - dropout=config.hidden_dropout_prob, - max_len=config.max_position_embeddings, - n_chunks=num_chunks, - n_attention_chunks=None, - attention_type=tl.LSHSelfAttention, - share_qk=share_qk, - axial_pos_shape=config.axial_pos_shape, - d_axial_pos_embs=config.axial_pos_embds_dim, - ff_activation=tl.Gelu, - ff_use_sru=ff_use_sru, - ff_chunk_size=chunk_size_feed_forward, - mode=mode, - causal=config.is_decoder, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=n_buckets, - use_reference_code=use_reference_code, - hash_seed=config.seed, - path_to_save_weights=path_to_save_weights, - ) - return model + hf_input = torch.tensor(np_input, dtype=torch.float) + config.attn_type = "lsh" + hf_layer = ReformerAttention(config) + self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) + hf_layer.eval() - def forward_model( - self, - np_input_data, - model, - input_signature=None, - random_number_generator=None, - weights=None, - state=None, - only_init=False, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - input_data = (input_data,) * 2 + hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - if input_signature is None: - input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) - input_signature = (input_signature, input_signature) + trax_torch_output = torch.tensor(np.asarray(trax_output)) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - if weights is None and state is None: - weights, state = model.init(input_signature) + def test_local_layer(self): + config = ReformerConfig() + shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize + np_input = np.random.rand(*shape) - if only_init is True: - return + trax_utils = TraxUtils(shape) + trax_layer = self.load_local_layer(config) + trax_output, trax_weights, trax_state = trax_utils.forward_layer( + np_input, layer=trax_layer + ) - if random_number_generator is None: - random_number_generator = model.new_rngs(1)[0] + hf_input = torch.tensor(np_input, dtype=torch.float) + config.attn_type = "local" + hf_layer = ReformerAttention(config) + self._set_layer_weights_in_torch_local( + trax_weights, hf_layer, config.hidden_size + ) + hf_layer.eval() - output = model( - input_data, weights=weights, state=state, rng=random_number_generator - ) + hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - return output, weights, state + trax_torch_output = torch.tensor(np.asarray(trax_output)) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) + def test_lsh_reformer_lm_model(self): + config = ReformerConfig() + + shape = (2, 14) # Batch x SeqLen x ModelDimPerHead + np_input = np.random.randint(0, config.vocab_size, size=shape) + np_zeros = np.zeros((shape[0], 1), dtype=np.int) + + trax_utils = TraxUtils(shape) + trax_model = self.load_lsh_reformer_lm_model(config) + trax_output, trax_weights, trax_state = trax_utils.forward_model( + np_input, model=trax_model + ) + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + config.attn_type = "lsh" + hf_model = ReformerModelWithLMHead(config) + self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) + hf_model.eval() + + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + + self.assertTrue( + torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) + ) + + def test_pretrained_crime_and_punishment_lm_model(self): + hf_model = ReformerModelWithLMHead.from_pretrained( + "patrickvonplaten/reformer-crime-and-punish" + ) + config = hf_model.config + + trax_model_path = ( + "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" + ) + + shape = (1, 192) + np_input = np.random.randint(0, config.vocab_size, size=shape) + np_zeros = np.zeros((shape[0], 1), dtype=np.int) + + trax_utils = TraxUtils(shape) + trax_input = trax_utils.convert_to_jax_array(np_input) + input_signature = trax_utils.get_input_signature(dtype=np.int32) + + trax_model = self.load_crime_and_punishment_model( + trax_model_path, input_signature + ) + trax_output = trax_model(trax_input) + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + + self.assertTrue( + torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) + ) + + def load_lsh_layer(self, config, mode="eval"): + gin_config = """ + import trax.layers + + # Parameters for LSHSelfAttention: + # ============================================================================== + LSHSelfAttention.n_heads={} + LSHSelfAttention.d_qk={} + LSHSelfAttention.d_v={} + LSHSelfAttention.chunk_len = {} + LSHSelfAttention.n_chunks_before = {} + LSHSelfAttention.n_chunks_after = {} + LSHSelfAttention.n_hashes = {} + LSHSelfAttention.n_buckets = {} + LSHSelfAttention.attention_dropout = {} + LSHSelfAttention.output_dropout = {} + LSHSelfAttention.lsh_seed = {} + LSHSelfAttention.causal= {} + """.format( + config.num_attention_heads, + config.attention_head_size, + config.attention_head_size, + config.chunk_length, + config.num_chunks_before, + config.num_chunks_after, + config.num_hashes, + config.num_buckets, + config.attention_probs_dropout_prob, + config.hidden_dropout_prob, + config.seed, + config.is_decoder, + ) + gin.parse_config(gin_config) + layer = trax.layers.LSHSelfAttention(mode=mode) + return layer + + def load_local_layer(self, config, mode="eval"): + gin_config = """ + import trax.layers + + # Parameters for SelfAttention: + # ============================================================================== + SelfAttention.n_heads={} + SelfAttention.d_qk={} + SelfAttention.d_v={} + SelfAttention.chunk_len = {} + SelfAttention.n_chunks_before = {} + SelfAttention.n_chunks_after = {} + SelfAttention.attention_dropout = {} + SelfAttention.output_dropout = {} + SelfAttention.causal= {} + """.format( + config.num_attention_heads, + config.attention_head_size, + config.attention_head_size, + config.chunk_length, + config.num_chunks_before, + config.num_chunks_after, + config.attention_probs_dropout_prob, + config.hidden_dropout_prob, + config.is_decoder, + ) + gin.parse_config(gin_config) + layer = trax.layers.SelfAttention(mode=mode) + return layer + + def load_lsh_reformer_lm_model(self, config, mode="eval"): + gin_config = """ + import trax.layers + import trax.models + + # Parameters for LSHSelfAttention: + # ============================================================================== + LSHSelfAttention.chunk_len = {} + LSHSelfAttention.n_chunks_before = {} + LSHSelfAttention.n_chunks_after = {} + LSHSelfAttention.n_hashes = {} + LSHSelfAttention.n_buckets = {} + LSHSelfAttention.lsh_seed = {} + LSHSelfAttention.causal= {} + + # Parameters for ReformerLM: + # ============================================================================== + ReformerLM.vocab_size = {} + ReformerLM.d_model = {} + ReformerLM.d_ff = {} + ReformerLM.d_attention_key = {} + ReformerLM.d_attention_value = {} + ReformerLM.n_layers = {} + ReformerLM.n_heads = {} + ReformerLM.max_len = {} + ReformerLM.axial_pos_shape = {} + ReformerLM.d_axial_pos_embs = {} + ReformerLM.ff_chunk_size = {} + + ReformerLM.n_chunks = 0 + ReformerLM.n_attention_chunks = None + ReformerLM.share_qk = False + ReformerLM.ff_use_sru = 0 + ReformerLM.attention_type = @trax.layers.LSHSelfAttention + ReformerLM.ff_activation = @trax.layers.Gelu + """.format( + config.chunk_length, + config.num_chunks_before, + config.num_chunks_after, + config.num_hashes, + config.num_buckets, + config.seed, + config.is_decoder, + config.vocab_size, + config.hidden_size, + config.feed_forward_size, + config.attention_head_size, + config.attention_head_size, + config.num_hidden_layers, + config.num_attention_heads, + config.max_position_embeddings, + config.axial_pos_shape, + config.axial_pos_embds_dim, + config.chunk_size_feed_forward, + ) + gin.parse_config(gin_config) + model = trax.models.ReformerLM(mode=mode) + return model + + def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode="predict"): + gin.parse_config( + """ + import trax.layers + import trax.models + import trax.optimizers + import trax.supervised.inputs + import trax.supervised.trainer_lib + + # Parameters that will vary between experiments: + # ============================================================================== + train.model = @trax.models.ReformerLM + # Our model will have 6 layers, alternating between the LSH attention proposed + # in the Reformer paper and local attention within a certain context window. + n_layers = 6 + attn_type = [ + @SelfAttention, + @LSHSelfAttention, + @SelfAttention, + @LSHSelfAttention, + @SelfAttention, + @LSHSelfAttention, + ] + share_qk = False # LSH attention ignores this flag and always shares q & k + n_heads = 2 + attn_kv = 64 + dropout = 0.05 + n_tokens = 524288 + + # Parameters for MultifactorSchedule: + # ============================================================================== + MultifactorSchedule.constant = 0.01 + MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay' + MultifactorSchedule.warmup_steps = 100 + MultifactorSchedule.steps_per_cycle = 900 + + # Parameters for Adam: + # ============================================================================== + Adam.weight_decay_rate=0.0 + Adam.b1 = 0.86 + Adam.b2 = 0.92 + Adam.eps = 1e-9 + + # Parameters for SelfAttention: + # ============================================================================== + SelfAttention.attention_dropout = 0.05 + SelfAttention.chunk_len = 64 + SelfAttention.n_chunks_before = 1 + SelfAttention.n_parallel_heads = 1 + + # Parameters for LSHSelfAttention: + # ============================================================================== + LSHSelfAttention.attention_dropout = 0.0 + LSHSelfAttention.chunk_len = 64 + LSHSelfAttention.n_buckets = [64, 128] + LSHSelfAttention.n_chunks_after = 0 + LSHSelfAttention.n_chunks_before = 1 + LSHSelfAttention.n_hashes = 1 + LSHSelfAttention.n_parallel_heads = 1 + LSHSelfAttention.predict_drop_len = 128 + LSHSelfAttention.predict_mem_len = 1024 + LSHSelfAttention.lsh_seed = 0 + + # Parameters for ReformerLM: + # ============================================================================== + ReformerLM.attention_type = %attn_type + ReformerLM.d_attention_key = %attn_kv + ReformerLM.d_attention_value = %attn_kv + ReformerLM.d_model = 256 + ReformerLM.d_ff = 512 + ReformerLM.dropout = %dropout + ReformerLM.ff_activation = @trax.layers.Relu + ReformerLM.max_len = %n_tokens + ReformerLM.mode = 'train' + ReformerLM.n_heads = %n_heads + ReformerLM.n_layers = %n_layers + ReformerLM.vocab_size = 320 + ReformerLM.share_qk = %share_qk + ReformerLM.axial_pos_shape = (512, 1024) + ReformerLM.d_axial_pos_embs= (64, 192) + """ + ) + trax_model = trax.models.ReformerLM(mode=mode) + trax_model.init(input_signature) + trax_model.init_from_file(trax_model_path, weights_only=True) + return trax_model -@require_torch -class ReformerIntegrationTests(unittest.TestCase): def _set_param(self, torch_layer, weight, bias=None): with torch.no_grad(): assert ( @@ -471,54 +683,8 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): pass - def test_lsh_layer(self): - config = ReformerConfig() - shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize - np_input = np.random.rand(*shape) - - trax_utils = TraxUtils(shape) - trax_layer = trax_utils.get_lsh_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer( - np_input, layer=trax_layer - ) - - hf_input = torch.tensor(np_input, dtype=torch.float) - hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) - hf_layer.eval() - - hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] - hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - - trax_torch_output = torch.tensor(np.asarray(trax_output)) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - - def test_local_layer(self): - config = ReformerConfig() - shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize - np_input = np.random.rand(*shape) - - trax_utils = TraxUtils(shape) - trax_layer = trax_utils.get_local_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer( - np_input, layer=trax_layer - ) - - hf_input = torch.tensor(np_input, dtype=torch.float) - config.attn_type = "local" - hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_local( - trax_weights, hf_layer, config.hidden_size - ) - hf_layer.eval() - - hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] - hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - - trax_torch_output = torch.tensor(np.asarray(trax_output)) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - def test_lsh_block(self): + # depreciated config = ReformerConfig() shape = (2, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead @@ -541,154 +707,3 @@ def test_lsh_block(self): self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) - - def test_reformer_lm_model(self): - config = ReformerConfig() - - shape = (2, 14) # Batch x SeqLen x ModelDimPerHead - np_input = np.random.randint(0, config.vocab_size, size=shape) - np_zeros = np.zeros((shape[0], 1), dtype=np.int) - - trax_utils = TraxUtils(shape) - trax_model = trax_utils.get_model(config) - trax_output, trax_weights, trax_state = trax_utils.forward_model( - np_input, model=trax_model - ) - trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) - hf_model = ReformerModelWithLMHead(config) - self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) - hf_model.eval() - - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - - self.assertTrue( - torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) - ) - - def test_pretrained_crime_and_punishment_lm_model(self): - hf_model = ReformerModelWithLMHead.from_pretrained( - "patrickvonplaten/reformer-crime-and-punish" - ) - config = hf_model.config - - trax_model_path = ( - "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" - ) - - shape = (3, 1024) - np_input = np.random.randint(0, config.vocab_size, size=shape) - np_zeros = np.zeros((shape[0], 1), dtype=np.int) - - trax_utils = TraxUtils(shape) - with trax_math.use_backend("jax"): - trax_input = trax_utils.convert_to_jax_array(np_input) - trax_input = (trax_input,) * 2 - - input_signature = trax_utils.get_input_signature(dtype=np.int32) -# input_signature = (input_signature, input_signature) - - trax_model = self.load_model(trax_model_path, input_signature) - - trax_output = trax_model(np_input) - trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - - self.assertTrue( - torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) - ) - - def load_model(self, trax_model_path, input_signature): - gin.parse_config( - """ - import trax.layers - import trax.models - import trax.optimizers - import trax.supervised.inputs - import trax.supervised.trainer_lib - - # Parameters that will vary between experiments: - # ============================================================================== - train.model = @trax.models.ReformerLM - # Our model will have 6 layers, alternating between the LSH attention proposed - # in the Reformer paper and local attention within a certain context window. - n_layers = 6 - attn_type = [ - @SelfAttention, - @LSHSelfAttention, - @SelfAttention, - @LSHSelfAttention, - @SelfAttention, - @LSHSelfAttention, - ] - share_qk = False # LSH attention ignores this flag and always shares q & k - n_heads = 2 - attn_kv = 64 - dropout = 0.05 - n_tokens = 524288 - - # Parameters for MultifactorSchedule: - # ============================================================================== - MultifactorSchedule.constant = 0.01 - MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay' - MultifactorSchedule.warmup_steps = 100 - MultifactorSchedule.steps_per_cycle = 900 - - # Parameters for Adam: - # ============================================================================== - Adam.weight_decay_rate=0.0 - Adam.b1 = 0.86 - Adam.b2 = 0.92 - Adam.eps = 1e-9 - - # Parameters for SelfAttention: - # ============================================================================== - SelfAttention.attention_dropout = 0.05 - SelfAttention.chunk_len = 64 - SelfAttention.n_chunks_before = 1 - SelfAttention.n_parallel_heads = 1 - - # Parameters for LSHSelfAttention: - # ============================================================================== - LSHSelfAttention.attention_dropout = 0.0 - LSHSelfAttention.chunk_len = 64 - LSHSelfAttention.n_buckets = [64, 128] - LSHSelfAttention.n_chunks_after = 0 - LSHSelfAttention.n_chunks_before = 1 - LSHSelfAttention.n_hashes = 1 - LSHSelfAttention.n_parallel_heads = 1 - LSHSelfAttention.predict_drop_len = 128 - LSHSelfAttention.predict_mem_len = 1024 - - # Parameters for ReformerLM: - # ============================================================================== - ReformerLM.attention_type = %attn_type - ReformerLM.d_attention_key = %attn_kv - ReformerLM.d_attention_value = %attn_kv - ReformerLM.d_model = 256 - ReformerLM.d_ff = 512 - ReformerLM.dropout = %dropout - ReformerLM.ff_activation = @trax.layers.Relu - ReformerLM.max_len = %n_tokens - ReformerLM.mode = 'train' - ReformerLM.n_heads = %n_heads - ReformerLM.n_layers = %n_layers - ReformerLM.vocab_size = 320 - ReformerLM.share_qk = %share_qk - ReformerLM.axial_pos_shape = (512, 1024) - ReformerLM.d_axial_pos_embs= (64, 192) - """ - ) - trax_model = trax.models.ReformerLM(mode="predict") - trax_model.init(input_signature) - trax_model.init_from_file(trax_model_path) - return trax_model From ace301f9161939021539da922d72bd269b797297 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 13:08:11 +0200 Subject: [PATCH 053/162] make shorter input length work for model --- src/transformers/activations.py | 3 --- src/transformers/modeling_reformer.py | 20 +++++++++++++++----- tests/test_modeling_reformer.py | 13 ++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 9fa175cfffb6..5429f2fe36ab 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -35,9 +35,6 @@ def gelu_new(x): gelu = F.gelu -gelu = _gelu_python - - ACT2FN = { "relu": F.relu, "swish": swish, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index f2afec73d54d..122c056396ed 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -123,13 +123,23 @@ def __init__(self, config): self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) def forward(self, position_ids): - assert np.prod(self.axial_pos_shape) == position_ids.shape[-1], "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, position_ids.shape[-1]) # broadcast weights to correct shape - broadcasted_weights = [weight.expand((position_ids.shape[0],) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights] + batch_size = position_ids.shape[0] + sequence_length = position_ids.shape[1] - # TODO (PVP): Add dropout when training - # get correct position encodings - position_encodings = torch.cat([torch.reshape(weight, position_ids.shape + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) + broadcasted_weights = [weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights] + + if self.training is True: + assert np.prod(self.axial_pos_shape) == sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, sequence_length) + + # TODO (PVP): Add dropout when training + # get correct position encodings + position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length) + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) + + else: + # reshape axial encodings and use only until sequence_length + position_encodings = torch.cat(broadcasted_weights, dim=-1) + position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[:, :sequence_length] return position_encodings diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 65ddfdcde643..c68ae8a60d25 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -289,22 +289,21 @@ def test_pretrained_crime_and_punishment_lm_model(self): np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + trax_utils = TraxUtils(shape) trax_input = trax_utils.convert_to_jax_array(np_input) input_signature = trax_utils.get_input_signature(dtype=np.int32) - trax_model = self.load_crime_and_punishment_model( trax_model_path, input_signature ) trax_output = trax_model(trax_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - self.assertTrue( torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) ) From 80d18db666d2a60842d2aba06bf81dad11dc7886 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 15:59:46 +0200 Subject: [PATCH 054/162] allow variable input length --- src/transformers/configuration_reformer.py | 15 +-- src/transformers/modeling_reformer.py | 37 ++++++- tests/test_modeling_reformer.py | 111 +++++++++++++-------- 3 files changed, 111 insertions(+), 52 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 94fe0a622871..2fe1815997c9 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -100,11 +100,11 @@ def __init__( vocab_size=200, attention_head_size=32, hidden_size=128, - num_hidden_layers=2, - num_attention_heads=2, + num_hidden_layers=1, + num_attention_heads=1, num_buckets=2, num_hashes=4, - chunk_length=7, + chunk_length=64, num_chunks_before=1, num_chunks_after=0, chunk_size_lm_head=0, @@ -119,12 +119,13 @@ def __init__( layer_norm_eps=1e-12, sinusoidal_pos_embds=False, axial_pos_embds=True, - axial_pos_shape=[7, 2], + axial_pos_shape=[32, 2], axial_pos_embds_dim=[64, 64], - attn_type="lsh", + attn_type="local", + pad_token_id=0, **kwargs ): - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, **kwargs) # TO CHANGE LATER: @@ -137,7 +138,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_hashes = num_hashes - self.num_buckets = num_buckets + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets self.chunk_length = chunk_length self.num_chunks_after = num_chunks_after self.num_chunks_before = num_chunks_before diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 122c056396ed..1353db24a578 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -130,13 +130,14 @@ def forward(self, position_ids): broadcasted_weights = [weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights] if self.training is True: - assert np.prod(self.axial_pos_shape) == sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, sequence_length) + assert int(np.prod(self.axial_pos_shape)) == sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, sequence_length) # TODO (PVP): Add dropout when training # get correct position encodings position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length) + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) else: + assert sequence_length >= int(np.prod(self.axial_pos_shape)), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): {}".format(self.axial_pos_shape, sequence_length) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[:, :sequence_length] @@ -503,6 +504,7 @@ def __init__(self, config): self.num_chunks_after = config.num_chunks_after self.output_attentions = config.output_attentions self.is_decoder = config.is_decoder + self.pad_token_id = config.pad_token_id self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -523,6 +525,17 @@ def forward( sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] + # needs padding +# if sequence_length % self.chunk_length != -1: +# assert not self.training, "In training, the sequnece length: {} has to be a multiple of config.chunk_length: {}".format(sequence_length, self.chunk_length) + # This should be done before +# assert self.is_decoder +# padding_length = self.chunk_length - sequence_length % self.chunk_length + # TODO: need to fix with attention mask here + # TODO: only allow for casaul masking or for all masking? +# hidden_states = torch.cat([hidden_states, torch.full((batch_size, padding_length, hidden_states.shape[-1]), 0, device=hidden_states.device)], dim=1) +# sequence_length = hidden_states.shape[1] + mixed_query_vectors = self.query(hidden_states) mixed_key_vectors = self.key(hidden_states) mixed_value_vectors = self.value(hidden_states) @@ -585,6 +598,7 @@ def forward( out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) +# outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors[:, :3, :],) return outputs @@ -817,7 +831,6 @@ class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config - self.chunk_length = config.chunk_length self.embeddings = ReformerEmbeddings(config) self.encoder = ReformerEncoder(config) @@ -877,8 +890,22 @@ def forward( raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 + real_sequence_length = input_shape[-1] + + # if needs padding + if input_shape[-1] % self.config.chunk_length != 0: + # TODO: should also allow this when self.is_decoder is False? + # TODO: need to improve attn mask input possibility here + assert self.training is False and self.config.is_decoder is True, "Sequence Length {} has to be a multiple of config.chunk_length {} if {}. Please consider padding the input to a length of {}.".format(input_shape[-2], self.config.chunk_length, "training" if self.training is True else "config.is_decoder = True", input_shape[-1] + (self.config.chunk_length - input_shape[-1] % self.config.chunk_length)) + + if input_ids is None: + raise NotImplementedError("Currently only supported for `input_ids`") - assert input_shape[-1] % self.chunk_length == 0, "Sequence Length {} has to be a multiple of config.chunk_length {}. Please consider padding the input to a length of {}.".format(input_shape[-1], self.chunk_length, input_shape[-1] + (self.chunk_length - input_shape[-1] % self.chunk_length)) + padding_length = self.config.chunk_length - input_shape[-1] % self.config.chunk_length + + # Extend `input_ids` with padding to match self.chunk_len + input_ids = torch.cat([input_ids, torch.full((input_shape[0], padding_length), self.config.pad_token_id, device=input_ids.device, dtype=torch.long)], dim=-1) + input_shape = input_ids.size() # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -909,6 +936,10 @@ def forward( ) sequence_output = encoder_outputs[0] + # if padding was applied + if real_sequence_length < input_shape[-1]: + sequence_output = sequence_output[:, :real_sequence_length] + # add hidden_states and attentions if they are here outputs = (sequence_output,) + encoder_outputs[1:] return outputs # sequence_output, pooled_output, (hidden_states), (attentions) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c68ae8a60d25..ce697750246d 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -83,7 +83,7 @@ def forward_layer( self, np_input_data, layer, input_signature=None, random_number_generator=None, ): with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) + input_data = np_input_data if input_signature is None: input_signature = self.get_input_signature() @@ -200,14 +200,15 @@ class ReformerIntegrationTests(unittest.TestCase): def test_lsh_layer(self): config = ReformerConfig() - shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (2, 192, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) - trax_layer = self.load_lsh_layer(config) + trax_layer = self.load_lsh_layer(config, mode="predict") trax_output, trax_weights, trax_state = trax_utils.forward_layer( np_input, layer=trax_layer ) + trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_type = "lsh" @@ -218,16 +219,18 @@ def test_lsh_layer(self): hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - trax_torch_output = torch.tensor(np.asarray(trax_output)) + import ipdb + ipdb.set_trace() + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_local_layer(self): config = ReformerConfig() - shape = (2, 14, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) - trax_layer = self.load_local_layer(config) + trax_layer = self.load_local_layer(config, mode="eval") trax_output, trax_weights, trax_state = trax_utils.forward_layer( np_input, layer=trax_layer ) @@ -246,15 +249,16 @@ def test_local_layer(self): trax_torch_output = torch.tensor(np.asarray(trax_output)) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - def test_lsh_reformer_lm_model(self): + def test_reformer_lm_model(self): config = ReformerConfig() - shape = (2, 14) # Batch x SeqLen x ModelDimPerHead + shape = (1, 64) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) trax_utils = TraxUtils(shape) - trax_model = self.load_lsh_reformer_lm_model(config) +# trax_model = self.load_reformer_lm_model(config, mode="predict") + trax_model = self.load_reformer_lm_model(config) trax_output, trax_weights, trax_state = trax_utils.forward_model( np_input, model=trax_model ) @@ -263,7 +267,6 @@ def test_lsh_reformer_lm_model(self): hf_input = torch.cat( [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 ) - config.attn_type = "lsh" hf_model = ReformerModelWithLMHead(config) self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) hf_model.eval() @@ -271,6 +274,9 @@ def test_lsh_reformer_lm_model(self): hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + import ipdb + ipdb.set_trace() + self.assertTrue( torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) ) @@ -292,8 +298,6 @@ def test_pretrained_crime_and_punishment_lm_model(self): hf_input = torch.cat( [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 ) - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) trax_utils = TraxUtils(shape) trax_input = trax_utils.convert_to_jax_array(np_input) @@ -301,6 +305,10 @@ def test_pretrained_crime_and_punishment_lm_model(self): trax_model = self.load_crime_and_punishment_model( trax_model_path, input_signature ) + + hf_output = hf_model(hf_input) + log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + trax_output = trax_model(trax_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) @@ -314,9 +322,9 @@ def load_lsh_layer(self, config, mode="eval"): # Parameters for LSHSelfAttention: # ============================================================================== - LSHSelfAttention.n_heads={} - LSHSelfAttention.d_qk={} - LSHSelfAttention.d_v={} + LSHSelfAttention.n_heads = {} + LSHSelfAttention.d_qk = {} + LSHSelfAttention.d_v = {} LSHSelfAttention.chunk_len = {} LSHSelfAttention.n_chunks_before = {} LSHSelfAttention.n_chunks_after = {} @@ -326,6 +334,7 @@ def load_lsh_layer(self, config, mode="eval"): LSHSelfAttention.output_dropout = {} LSHSelfAttention.lsh_seed = {} LSHSelfAttention.causal= {} + LSHSelfAttention.use_reference_code = True """.format( config.num_attention_heads, config.attention_head_size, @@ -350,15 +359,16 @@ def load_local_layer(self, config, mode="eval"): # Parameters for SelfAttention: # ============================================================================== - SelfAttention.n_heads={} - SelfAttention.d_qk={} - SelfAttention.d_v={} + SelfAttention.n_heads = {} + SelfAttention.d_qk = {} + SelfAttention.d_v = {} SelfAttention.chunk_len = {} SelfAttention.n_chunks_before = {} SelfAttention.n_chunks_after = {} SelfAttention.attention_dropout = {} SelfAttention.output_dropout = {} - SelfAttention.causal= {} + SelfAttention.causal = {} + SelfAttention.use_reference_code = True """.format( config.num_attention_heads, config.attention_head_size, @@ -374,7 +384,20 @@ def load_local_layer(self, config, mode="eval"): layer = trax.layers.SelfAttention(mode=mode) return layer - def load_lsh_reformer_lm_model(self, config, mode="eval"): + def load_reformer_lm_model(self, config, mode="eval"): + if config.hidden_act == "gelu": + hidden_act = "Gelu" + elif config.hidden_act == "relu": + hidden_act = "Relu" + else: + raise ValueError() + if config.attn_type == "lsh": + attn_type = "LSHSelfAttention" + elif config.attn_type == "local": + attn_type = "SelfAttention" + else: + raise ValueError() + gin_config = """ import trax.layers import trax.models @@ -388,6 +411,15 @@ def load_lsh_reformer_lm_model(self, config, mode="eval"): LSHSelfAttention.n_buckets = {} LSHSelfAttention.lsh_seed = {} LSHSelfAttention.causal= {} + LSHSelfAttention.use_reference_code = True + + # Parameters for SelfAttention: + # ============================================================================== + SelfAttention.chunk_len = {} + SelfAttention.n_chunks_before = {} + SelfAttention.n_chunks_after = {} + SelfAttention.causal= {} + SelfAttention.use_reference_code = True # Parameters for ReformerLM: # ============================================================================== @@ -402,13 +434,13 @@ def load_lsh_reformer_lm_model(self, config, mode="eval"): ReformerLM.axial_pos_shape = {} ReformerLM.d_axial_pos_embs = {} ReformerLM.ff_chunk_size = {} + ReformerLM.ff_activation = @trax.layers.{} + ReformerLM.attention_type = @trax.layers.{} ReformerLM.n_chunks = 0 ReformerLM.n_attention_chunks = None ReformerLM.share_qk = False ReformerLM.ff_use_sru = 0 - ReformerLM.attention_type = @trax.layers.LSHSelfAttention - ReformerLM.ff_activation = @trax.layers.Gelu """.format( config.chunk_length, config.num_chunks_before, @@ -417,6 +449,10 @@ def load_lsh_reformer_lm_model(self, config, mode="eval"): config.num_buckets, config.seed, config.is_decoder, + config.chunk_length, + config.num_chunks_before, + config.num_chunks_after, + config.is_decoder, config.vocab_size, config.hidden_size, config.feed_forward_size, @@ -428,6 +464,8 @@ def load_lsh_reformer_lm_model(self, config, mode="eval"): config.axial_pos_shape, config.axial_pos_embds_dim, config.chunk_size_feed_forward, + hidden_act, + attn_type ) gin.parse_config(gin_config) model = trax.models.ReformerLM(mode=mode) @@ -462,30 +500,14 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode dropout = 0.05 n_tokens = 524288 - # Parameters for MultifactorSchedule: - # ============================================================================== - MultifactorSchedule.constant = 0.01 - MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay' - MultifactorSchedule.warmup_steps = 100 - MultifactorSchedule.steps_per_cycle = 900 - - # Parameters for Adam: - # ============================================================================== - Adam.weight_decay_rate=0.0 - Adam.b1 = 0.86 - Adam.b2 = 0.92 - Adam.eps = 1e-9 - # Parameters for SelfAttention: # ============================================================================== - SelfAttention.attention_dropout = 0.05 SelfAttention.chunk_len = 64 SelfAttention.n_chunks_before = 1 SelfAttention.n_parallel_heads = 1 # Parameters for LSHSelfAttention: # ============================================================================== - LSHSelfAttention.attention_dropout = 0.0 LSHSelfAttention.chunk_len = 64 LSHSelfAttention.n_buckets = [64, 128] LSHSelfAttention.n_chunks_after = 0 @@ -590,10 +612,15 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): ) # lsh weights + output - lsh_weights = weights[0][1] - self._set_layer_weights_in_torch_lsh( - lsh_weights, torch_block.attention, hidden_size - ) + attn_weights = weights[0][1] + if len(attn_weights) < 4: + self._set_layer_weights_in_torch_lsh( + attn_weights, torch_block.attention, hidden_size + ) + else: + self._set_layer_weights_in_torch_local( + attn_weights, torch_block.attention, hidden_size + ) # intermediate weighs intermediate_weights = weights[2][0][2][2] From 86f4ac4d78897a77f844f20f98fa8e9dd872a98b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 17:16:26 +0200 Subject: [PATCH 055/162] refactor --- src/transformers/configuration_reformer.py | 6 ++-- src/transformers/modeling_reformer.py | 13 ++------ tests/test_modeling_reformer.py | 39 +++++++++++----------- 3 files changed, 25 insertions(+), 33 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 2fe1815997c9..33f2460d5e82 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -100,8 +100,8 @@ def __init__( vocab_size=200, attention_head_size=32, hidden_size=128, - num_hidden_layers=1, - num_attention_heads=1, + num_hidden_layers=4, + num_attention_heads=2, num_buckets=2, num_hashes=4, chunk_length=64, @@ -121,7 +121,7 @@ def __init__( axial_pos_embds=True, axial_pos_shape=[32, 2], axial_pos_embds_dim=[64, 64], - attn_type="local", + attn_type="lsh", pad_token_id=0, **kwargs ): diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1353db24a578..1160e84c43b6 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -189,6 +189,7 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): # it is implemented here. Check if this is ok. embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) + embeddings = embeddings + position_embeddings embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) return embeddings @@ -521,21 +522,11 @@ def forward( hidden_states, head_mask=None, ): + # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - # needs padding -# if sequence_length % self.chunk_length != -1: -# assert not self.training, "In training, the sequnece length: {} has to be a multiple of config.chunk_length: {}".format(sequence_length, self.chunk_length) - # This should be done before -# assert self.is_decoder -# padding_length = self.chunk_length - sequence_length % self.chunk_length - # TODO: need to fix with attention mask here - # TODO: only allow for casaul masking or for all masking? -# hidden_states = torch.cat([hidden_states, torch.full((batch_size, padding_length, hidden_states.shape[-1]), 0, device=hidden_states.device)], dim=1) -# sequence_length = hidden_states.shape[1] - mixed_query_vectors = self.query(hidden_states) mixed_key_vectors = self.key(hidden_states) mixed_value_vectors = self.value(hidden_states) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index ce697750246d..a864e2f20c3c 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -204,7 +204,7 @@ def test_lsh_layer(self): np_input = np.random.rand(*shape) trax_utils = TraxUtils(shape) - trax_layer = self.load_lsh_layer(config, mode="predict") + trax_layer = self.load_lsh_layer(config) trax_output, trax_weights, trax_state = trax_utils.forward_layer( np_input, layer=trax_layer ) @@ -219,9 +219,6 @@ def test_lsh_layer(self): hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) - import ipdb - ipdb.set_trace() - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_local_layer(self): @@ -252,21 +249,27 @@ def test_local_layer(self): def test_reformer_lm_model(self): config = ReformerConfig() - shape = (1, 64) # Batch x SeqLen x ModelDimPerHead + shape = (1, 4) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) trax_utils = TraxUtils(shape) -# trax_model = self.load_reformer_lm_model(config, mode="predict") - trax_model = self.load_reformer_lm_model(config) + + mode = "predict" + trax_model = self.load_reformer_lm_model(config, mode=mode) + trax_output, trax_weights, trax_state = trax_utils.forward_model( np_input, model=trax_model ) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) + if mode != "predict": + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + else: + hf_input = torch.tensor(np_input) + hf_model = ReformerModelWithLMHead(config) self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) hf_model.eval() @@ -274,9 +277,6 @@ def test_reformer_lm_model(self): hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - import ipdb - ipdb.set_trace() - self.assertTrue( torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) ) @@ -293,11 +293,8 @@ def test_pretrained_crime_and_punishment_lm_model(self): shape = (1, 192) np_input = np.random.randint(0, config.vocab_size, size=shape) - np_zeros = np.zeros((shape[0], 1), dtype=np.int) - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) + hf_input = torch.tensor(np_input) trax_utils = TraxUtils(shape) trax_input = trax_utils.convert_to_jax_array(np_input) @@ -405,6 +402,8 @@ def load_reformer_lm_model(self, config, mode="eval"): # Parameters for LSHSelfAttention: # ============================================================================== LSHSelfAttention.chunk_len = {} + LSHSelfAttention.predict_mem_len = {} + LSHSelfAttention.predict_drop_len = {} LSHSelfAttention.n_chunks_before = {} LSHSelfAttention.n_chunks_after = {} LSHSelfAttention.n_hashes = {} @@ -443,6 +442,8 @@ def load_reformer_lm_model(self, config, mode="eval"): ReformerLM.ff_use_sru = 0 """.format( config.chunk_length, + config.chunk_length, + config.chunk_length // 2, config.num_chunks_before, config.num_chunks_after, config.num_hashes, @@ -514,8 +515,8 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode LSHSelfAttention.n_chunks_before = 1 LSHSelfAttention.n_hashes = 1 LSHSelfAttention.n_parallel_heads = 1 - LSHSelfAttention.predict_drop_len = 128 - LSHSelfAttention.predict_mem_len = 1024 + LSHSelfAttention.predict_drop_len = 32 # different from original to make code equal + LSHSelfAttention.predict_mem_len = 64 # different from original to make code equal LSHSelfAttention.lsh_seed = 0 # Parameters for ReformerLM: From e571849e39a71e9ae9df0efb49b24293980f3748 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 17:51:25 +0200 Subject: [PATCH 056/162] make forward pass for pretrained model work --- src/transformers/configuration_reformer.py | 2 +- src/transformers/modeling_reformer.py | 3 +- tests/test_modeling_reformer.py | 217 ++------------------- 3 files changed, 20 insertions(+), 202 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 33f2460d5e82..2a865041a880 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -119,7 +119,7 @@ def __init__( layer_norm_eps=1e-12, sinusoidal_pos_embds=False, axial_pos_embds=True, - axial_pos_shape=[32, 2], + axial_pos_shape=[64, 3], axial_pos_embds_dim=[64, 64], attn_type="lsh", pad_token_id=0, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1160e84c43b6..e7bc31909351 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -108,6 +108,7 @@ def __init__(self, config): super().__init__() self.axial_pos_shape = config.axial_pos_shape self.axial_pos_embds_dim = config.axial_pos_embds_dim + self.chunk_length = config.chunk_length self.weights = nn.ParameterList() assert sum(self.axial_pos_embds_dim) == config.hidden_size, "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format(self.axial_pos_embds_dim, config.hidden_size) @@ -137,7 +138,7 @@ def forward(self, position_ids): position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length) + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) else: - assert sequence_length >= int(np.prod(self.axial_pos_shape)), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): {}".format(self.axial_pos_shape, sequence_length) + assert int(np.prod(self.axial_pos_shape)) >= sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format(self.axial_pos_shape, sequence_length, self.chunk_length) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[:, :sequence_length] diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index a864e2f20c3c..13df994355b9 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -18,18 +18,12 @@ # trax imports - to be deleted later import trax -from trax import math as trax_math from trax.shapes import ShapeDtype as trax_ShapeDtype - import gin -import jax -from trax.models.reformer.reformer import DecoderBlock as TraxLSHAttentionBlock -from trax import layers as tl from transformers import ( ReformerAttention, - ReformerLayer, ReformerConfig, ReformerModelWithLMHead, ) @@ -48,153 +42,6 @@ ) -class TraxUtils(object): - """ class that will help for testing in the beginning - should be deleted step-by-step - - README (HOW-TO-INSTALL TRAX): - 1) git clone https://github.com/patrickvonplaten/trax.git - - - I had to do one tiny change to make the imports work, - see: https://github.com/patrickvonplaten/trax/commit/6c23e88afe7f1c57b0c38eeaa4d450e5f912590c) - 2) link your PYTHON_PATH to ~/trax/trax - 3) pip install all the missing packages HINT: the package gin is installed - - - HINT: the package gin is installed with pip install gin-config==0.1.4 - and not pip install gin. - - The other packages can just be installed with pip install form - error message " missing" - """ - - def __init__(self, shape): - self._shape = shape - - def convert_to_jax_array(self, np_array): - return jax.numpy.asarray(np_array) - - def get_input_signature(self, shape=None, dtype=trax_math.numpy.float32): - with trax_math.use_backend("jax"): - if shape is None: - shape = self._shape - input_signature = trax_ShapeDtype(shape, dtype) - return input_signature - - def forward_layer( - self, np_input_data, layer, input_signature=None, random_number_generator=None, - ): - with trax_math.use_backend("jax"): - input_data = np_input_data - - if input_signature is None: - input_signature = self.get_input_signature() - - weights, state = layer.init(input_signature) - - if random_number_generator is None: - random_number_generator = layer.new_rngs(1)[0] - - output = layer( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - def forward_model( - self, - np_input_data, - model, - input_signature=None, - random_number_generator=None, - weights=None, - state=None, - only_init=False, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - input_data = (input_data,) * 2 - - if input_signature is None: - input_signature = self.get_input_signature(dtype=trax_math.numpy.int32) - input_signature = (input_signature, input_signature) - - if weights is None and state is None: - weights, state = model.init(input_signature) - - if only_init is True: - return - - if random_number_generator is None: - random_number_generator = model.new_rngs(1)[0] - - output = model( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - def get_block( - self, - config, - use_reference_code=True, - share_qk=True, - ff_use_sru=0, - mode="eval", - path_to_save_weights=PATH_TO_SAVE_WEIGHTS, - ): - # only works on old master branch - with trax_math.use_backend("jax"): - with jax.disable_jit(): - list_of_layers = TraxLSHAttentionBlock( - d_model=config.hidden_size, - d_ff=config.feed_forward_size, - d_attention_key=config.attention_head_size, - d_attention_value=config.attention_head_size, - n_heads=config.num_attention_heads, - n_attention_chunks=None, - attention_type=tl.LSHSelfAttention, - dropout=config.hidden_dropout_prob, - share_qk=share_qk, - ff_activation=tl.Gelu, - ff_use_sru=ff_use_sru, - ff_chunk_size=config.chunk_size_feed_forward, - mode=mode, - causal=config.is_decoder, - chunk_len=config.chunk_length, - n_chunks_before=config.num_chunks_before, - n_chunks_after=config.num_chunks_after, - n_hashes=config.num_hashes, - n_buckets=config.num_buckets, - use_reference_code=use_reference_code, - hash_seed=config.seed, - path_to_save_weights=path_to_save_weights, - ) - block = tl.Serial(tl.ReversibleSerial([list_of_layers])) - - return block - - def forward_block( - self, np_input_data, block, input_signature=None, random_number_generator=None, - ): - with trax_math.use_backend("jax"): - input_data = self.convert_to_jax_array(np_input_data) - input_data = (input_data,) * 2 - - if input_signature is None: - input_signature = self.get_input_signature() - input_signature = (input_signature, input_signature) - - weights, state = block.init(input_signature) - - if random_number_generator is None: - random_number_generator = block.new_rngs(1)[0] - - output = block( - input_data, weights=weights, state=state, rng=random_number_generator - ) - - return output, weights, state - - @require_torch class ReformerIntegrationTests(unittest.TestCase): @@ -203,11 +50,12 @@ def test_lsh_layer(self): shape = (2, 192, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) - trax_utils = TraxUtils(shape) trax_layer = self.load_lsh_layer(config) - trax_output, trax_weights, trax_state = trax_utils.forward_layer( - np_input, layer=trax_layer - ) + input_signature = trax_ShapeDtype(shape, np.float32) + trax_weights, trax_state = trax_layer.init(input_signature) + + trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) + trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) @@ -226,11 +74,10 @@ def test_local_layer(self): shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) - trax_utils = TraxUtils(shape) - trax_layer = self.load_local_layer(config, mode="eval") - trax_output, trax_weights, trax_state = trax_utils.forward_layer( - np_input, layer=trax_layer - ) + trax_layer = self.load_local_layer(config) + input_signature = trax_ShapeDtype(shape, np.float32) + trax_weights, trax_state = trax_layer.init(input_signature) + trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_type = "local" @@ -249,18 +96,17 @@ def test_local_layer(self): def test_reformer_lm_model(self): config = ReformerConfig() - shape = (1, 4) # Batch x SeqLen x ModelDimPerHead + shape = (1, 192) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) - trax_utils = TraxUtils(shape) - mode = "predict" trax_model = self.load_reformer_lm_model(config, mode=mode) - trax_output, trax_weights, trax_state = trax_utils.forward_model( - np_input, model=trax_model - ) + input_signature = trax_ShapeDtype(shape, np.int32) + trax_weights, trax_state = trax_model.init(input_signature) + trax_output = trax_model(np_input, weights=trax_weights, state=trax_state) + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) if mode != "predict": @@ -291,14 +137,12 @@ def test_pretrained_crime_and_punishment_lm_model(self): "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" ) - shape = (1, 192) + shape = (1, 128) np_input = np.random.randint(0, config.vocab_size, size=shape) hf_input = torch.tensor(np_input) - trax_utils = TraxUtils(shape) - trax_input = trax_utils.convert_to_jax_array(np_input) - input_signature = trax_utils.get_input_signature(dtype=np.int32) + input_signature = trax_ShapeDtype(shape, np.int32) trax_model = self.load_crime_and_punishment_model( trax_model_path, input_signature ) @@ -306,7 +150,7 @@ def test_pretrained_crime_and_punishment_lm_model(self): hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - trax_output = trax_model(trax_input) + trax_output = trax_model(np_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) self.assertTrue( @@ -707,30 +551,3 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), torch.tensor(output_embed_bias), ) - - pass - - def test_lsh_block(self): - # depreciated - config = ReformerConfig() - - shape = (2, 7, config.hidden_size) # Batch x SeqLen x ModelDimPerHead - np_input = np.random.rand(*shape) - - trax_utils = TraxUtils(shape) - trax_block = trax_utils.get_block(config) - trax_output, trax_weights, trax_state = trax_utils.forward_block( - np_input, block=trax_block - ) - trax_torch_output_1 = torch.tensor(np.asarray(trax_output[0])) - trax_torch_output_2 = torch.tensor(np.asarray(trax_output[1])) - - hf_input = torch.tensor(np_input, dtype=torch.float) - hf_block = ReformerLayer(config) - self._set_block_weights_in_torch(trax_weights[0], hf_block, config.hidden_size) - hf_block.eval() - - hf_output_1, hf_output_2 = hf_block(hf_input, hf_input)[:2] - - self.assertTrue(torch.allclose(hf_output_1, trax_torch_output_1, atol=1e-3)) - self.assertTrue(torch.allclose(hf_output_2, trax_torch_output_2, atol=1e-3)) From d5e1363869bd69e00b59d33e70e1bb14d5f230d6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Apr 2020 09:43:03 +0200 Subject: [PATCH 057/162] add generation possibility --- src/transformers/modeling_reformer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e7bc31909351..3bd6dc636e24 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1006,3 +1006,7 @@ def forward( outputs = (loss,) + outputs return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) + + def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + # TODO(PVP): Add smart caching + return {"input_ids": input_ids} From 562d530ab517478f3fb2553c5f6112c7e6fb1d59 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Apr 2020 11:22:42 +0200 Subject: [PATCH 058/162] finish dropout and init --- src/transformers/modeling_reformer.py | 78 ++++++++++++++++++--------- tests/test_modeling_reformer.py | 2 +- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3bd6dc636e24..fc66513aef06 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -44,8 +44,6 @@ # "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", # "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", -PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" - def mish(x): return x * torch.tanh(nn.functional.softplus(x)) @@ -102,7 +100,10 @@ def apply_chunking_to_forward(chunk_size: int, chunk_dim: int, forward_fn: Calla return forward_fn(*input_tensors) -class AxialEmbeddings(nn.Module): +class AxialPositionEmbeddings(nn.Module): + """Constructs axial position embeddings. Useful for very long input + sequences to stay memory and computational efficient + """ def __init__(self, config): super().__init__() @@ -110,6 +111,7 @@ def __init__(self, config): self.axial_pos_embds_dim = config.axial_pos_embds_dim self.chunk_length = config.chunk_length self.weights = nn.ParameterList() + self.dropout = config.hidden_dropout_prob assert sum(self.axial_pos_embds_dim) == config.hidden_size, "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format(self.axial_pos_embds_dim, config.hidden_size) @@ -132,10 +134,16 @@ def forward(self, position_ids): if self.training is True: assert int(np.prod(self.axial_pos_shape)) == sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, sequence_length) - - # TODO (PVP): Add dropout when training - # get correct position encodings - position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length) + weight.shape[-1:]) for weight in broadcasted_weights], dim=-1) + if self.dropout > 0: + weights = torch.cat(broadcasted_weights, dim=-1) + # permute weights so that 2D correctly drops dims 1 and 2 + perm_weigthts = weights.permute(0, 3, 2, 1) + # drop entire matrix of last two dims (prev dims 1 and 2) + drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) + drop_weights = drop_perm_weights.permute(0, 3, 2, 1) + position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) + else: + position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], dim=-1) else: assert int(np.prod(self.axial_pos_shape)) >= sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format(self.axial_pos_shape, sequence_length, self.chunk_length) @@ -146,6 +154,26 @@ def forward(self, position_ids): return position_encodings +class PositionEmbeddings(nn.Module): + """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`. + """ + + def __init__(self, config): + self.max_position_embeddings = config.max_position_embeddings + self.dropout = config.hidden_dropout_prob + self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + if self.config.sinusoidal_pos_embds is True: + create_sinusoidal_embeddings( + n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.embedding.weight + ) + + def forward(self, position_ids): + position_embeddings = self.embedding(position_ids) + if self.config.sinusoidal_pos_embds is False: + position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) + return position_embeddings + + class ReformerEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ @@ -153,17 +181,9 @@ class ReformerEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.max_position_embeddings = config.max_position_embeddings - + self.position_embeddings = AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) assert not (config.sinusoidal_pos_embds and config.axial_pos_embds), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" - if config.sinusoidal_pos_embds: - create_sinusoidal_embeddings( - n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.position_embeddings.weight - ) - elif config.axial_pos_embds: - self.position_embeddings = AxialEmbeddings(config) - + self.max_position_embeddings = config.max_position_embeddings self.dropout = config.hidden_dropout_prob def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): @@ -183,14 +203,10 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): inputs_embeds = self.word_embeddings(input_ids) assert position_ids.shape[-1] <= self.max_position_embeddings, "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format(position_ids.shape[-1], self.max_position_embeddings) + embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) position_embeddings = self.position_embeddings(position_ids) - # TODO(PVP): In https://github.com/google/trax/blob/d947ebc184b94bc48d7e48cc63156a047f00992e/trax/models/reformer/reformer.py#L775 there are two dropouts right after each other as - # it is implemented here. Check if this is ok. - - embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) - embeddings = embeddings + position_embeddings embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) return embeddings @@ -458,8 +474,11 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # where a token has no other valid attention targets (e.g. the first token in a sequence) " logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (num_chunks_before + num_chunks_after)]` dots = torch.exp(query_key_dots - logits) + # TODO(PVP): discuss with thom. Trax uses special dropout here where same dropout mask is applied for all "num_hashes * seq_len // chunk_length" dim. + # should be fine with normal dropout, no? # dropout dots = nn.functional.dropout(dots, self.dropout, self.training) @@ -788,19 +807,26 @@ class ReformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """ Initialize the weights """ - if isinstance(module, AxialEmbeddings): + if isinstance(module, AxialPositionEmbeddings): for weight in module.weights: torch.nn.init.normal_(weight, std=self.config.axial_norm_std) - elif isinstance(module, (nn.Linear, nn.Embedding)): + elif isinstance(module, nn.Embedding): + if module.weight.requires_grad: + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + torch.nn.init.xavier_uniform(module.weight) + # TODO(PVP): discuss with Thom if necessary here to use different init +# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.data.normal_(mean=0.0, std=1e-6) + # TODO(PVP): discuss with Thom if necessary here to use different init +# module.bias.data.zero_() class ReformerModel(ReformerPreTrainedModel): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 13df994355b9..db59a277c266 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -100,7 +100,7 @@ def test_reformer_lm_model(self): np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) - mode = "predict" + mode = "train" trax_model = self.load_reformer_lm_model(config, mode=mode) input_signature = trax_ShapeDtype(shape, np.int32) From c98eafeb817cae879820ca886eedaf6bb1a02632 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Apr 2020 11:23:38 +0200 Subject: [PATCH 059/162] make style --- src/transformers/__init__.py | 9 +- src/transformers/configuration_reformer.py | 3 +- src/transformers/modeling_reformer.py | 269 +++++++++++++-------- src/transformers/tokenization_reformer.py | 8 +- tests/test_modeling_reformer.py | 98 +++----- 5 files changed, 202 insertions(+), 185 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 92f4457bd714..6fcad4e06db0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -192,7 +192,7 @@ BertForQuestionAnswering, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, - BertLayer + BertLayer, ) from .modeling_openai import ( OpenAIGPTPreTrainedModel, @@ -323,12 +323,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ) - from .modeling_reformer import ( - ReformerAttention, - ReformerLayer, - ReformerModel, - ReformerModelWithLMHead - ) + from .modeling_reformer import ReformerAttention, ReformerLayer, ReformerModel, ReformerModelWithLMHead # Optimization from .optimization import ( diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 2a865041a880..d3f88f52f192 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -127,8 +127,7 @@ def __init__( ): super().__init__(pad_token_id=pad_token_id, **kwargs) - -# TO CHANGE LATER: + # TO CHANGE LATER: self.seed = 0 self.is_decoder = True diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index fc66513aef06..e0f5115b1c1b 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -15,13 +15,14 @@ # limitations under the License. """PyTorch REFORMER model. """ -import logging +import inspect import itertools +import logging +from typing import Callable -import torch +# DELETE later import numpy as np -import inspect - +import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -29,11 +30,6 @@ from .configuration_reformer import ReformerConfig from .modeling_utils import PreTrainedModel -from typing import Callable - -# DELETE later -import numpy - logger = logging.getLogger(__name__) @@ -61,7 +57,9 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out.requires_grad = False -def apply_chunking_to_forward(chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors) -> torch.Tensor: +def apply_chunking_to_forward( + chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors +) -> torch.Tensor: """ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory. @@ -79,14 +77,24 @@ def apply_chunking_to_forward(chunk_size: int, chunk_dim: int, forward_fn: Calla assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) tensor_shape = input_tensors[0].shape - assert all(input_tensor.shape == tensor_shape for input_tensor in input_tensors), "All input tenors have to be of the same shape" + assert all( + input_tensor.shape == tensor_shape for input_tensor in input_tensors + ), "All input tenors have to be of the same shape" # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) - assert num_args_in_forward_chunk_fn == len(input_tensors), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(num_args_in_forward_chunk_fn, len(input_tensors)) + assert num_args_in_forward_chunk_fn == len( + input_tensors + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( + num_args_in_forward_chunk_fn, len(input_tensors) + ) if chunk_size > 0: - assert input_tensors[0].shape[chunk_dim] % chunk_size == 0, "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(input_tensors[0][chunk_dim], chunk_size) + assert ( + input_tensors[0].shape[chunk_dim] % chunk_size == 0 + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( + input_tensors[0][chunk_dim], chunk_size + ) num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size @@ -113,7 +121,11 @@ def __init__(self, config): self.weights = nn.ParameterList() self.dropout = config.hidden_dropout_prob - assert sum(self.axial_pos_embds_dim) == config.hidden_size, "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format(self.axial_pos_embds_dim, config.hidden_size) + assert ( + sum(self.axial_pos_embds_dim) == config.hidden_size + ), "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format( + self.axial_pos_embds_dim, config.hidden_size + ) # create weights for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): @@ -130,10 +142,16 @@ def forward(self, position_ids): batch_size = position_ids.shape[0] sequence_length = position_ids.shape[1] - broadcasted_weights = [weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights] + broadcasted_weights = [ + weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights + ] if self.training is True: - assert int(np.prod(self.axial_pos_shape)) == sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(self.axial_pos_shape, sequence_length) + assert ( + int(np.prod(self.axial_pos_shape)) == sequence_length + ), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format( + self.axial_pos_shape, sequence_length + ) if self.dropout > 0: weights = torch.cat(broadcasted_weights, dim=-1) # permute weights so that 2D correctly drops dims 1 and 2 @@ -143,13 +161,22 @@ def forward(self, position_ids): drop_weights = drop_perm_weights.permute(0, 3, 2, 1) position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) else: - position_encodings = torch.cat([torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], dim=-1) + position_encodings = torch.cat( + [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], + dim=-1, + ) else: - assert int(np.prod(self.axial_pos_shape)) >= sequence_length, "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format(self.axial_pos_shape, sequence_length, self.chunk_length) + assert ( + int(np.prod(self.axial_pos_shape)) >= sequence_length + ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format( + self.axial_pos_shape, sequence_length, self.chunk_length + ) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) - position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[:, :sequence_length] + position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[ + :, :sequence_length + ] return position_encodings @@ -181,8 +208,12 @@ class ReformerEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) - assert not (config.sinusoidal_pos_embds and config.axial_pos_embds), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" + self.position_embeddings = ( + AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) + ) + assert not ( + config.sinusoidal_pos_embds and config.axial_pos_embds + ), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" self.max_position_embeddings = config.max_position_embeddings self.dropout = config.hidden_dropout_prob @@ -202,7 +233,11 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - assert position_ids.shape[-1] <= self.max_position_embeddings, "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format(position_ids.shape[-1], self.max_position_embeddings) + assert ( + position_ids.shape[-1] <= self.max_position_embeddings + ), "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format( + position_ids.shape[-1], self.max_position_embeddings + ) embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) position_embeddings = self.position_embeddings(position_ids) @@ -294,9 +329,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, - hidden_states, - head_mask=None, + self, hidden_states, head_mask=None, ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -305,8 +338,12 @@ def forward( mixed_query_key_vectors = self.query_key(hidden_states) mixed_value_vectors = self.value(hidden_states) - query_key_vectors = self._transpose_for_scores(mixed_query_key_vectors, self.num_attention_heads, self.attention_head_size) - value_vectors = self._transpose_for_scores(mixed_value_vectors, self.num_attention_heads, self.attention_head_size) + query_key_vectors = self._transpose_for_scores( + mixed_query_key_vectors, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._transpose_for_scores( + mixed_value_vectors, self.num_attention_heads, self.attention_head_size + ) assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size @@ -324,8 +361,12 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) value_vectors = self._gather_by_expansion(value_vectors, ticker) - query_key_vectors = self._split_dim_by(query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + query_key_vectors = self._split_dim_by( + query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._split_dim_by( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) if self.chunk_length is None: @@ -333,11 +374,17 @@ def forward( key_vectors = self._len_and_dim_norm(query_key_vectors) - out_vectors, logits, attention_probs = self._attend(query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask) + out_vectors, logits, attention_probs = self._attend( + query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask + ) if self.num_hashes > 1: - out_vectors = self._split_dim_by(out_vectors, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size) - logits = self._split_dim_by(logits, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size).unsqueeze(-1) + out_vectors = self._split_dim_by( + out_vectors, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + ) + logits = self._split_dim_by( + logits, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + ).unsqueeze(-1) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) @@ -356,14 +403,18 @@ def _hash_vectors(self, vectors): # We sample a different random rotation for each round of hashing to # decrease the probability of hash misses. if isinstance(self.num_buckets, int): - assert self.num_buckets % 2 == 0, 'There should be an even number of bucktes, but `self.num_bucktes`: {}'.format(self.num_buckets) + assert ( + self.num_buckets % 2 == 0 + ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format(self.num_buckets) rotation_size = self.num_buckets num_buckets = self.num_buckets else: # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, 'The number of buckets should be even, but `num_bucket`: {}'.format(num_bucket) + assert num_bucket % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( + num_bucket + ) rotation_size += num_bucket num_buckets *= num_bucket @@ -371,15 +422,17 @@ def _hash_vectors(self, vectors): # TODO: delete later when integration tests are ok # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 -# random_rotations = torch.randn(rotations_shape, device=vectors.device) + # random_rotations = torch.randn(rotations_shape, device=vectors.device) numpy.random.seed(self.hash_seed) - random_rotations = torch.tensor(numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device) + random_rotations = torch.tensor( + numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device + ) # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # TODO: IMPORTANT: At the moment we use the same random rotation over all batches # and heads -> is that bad? It seems like in original reformer a different random # rotation is used over heads and batches - rotated_vectors = torch.einsum('bmtd,dhr->bmhtr', vectors, random_rotations) + rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -388,7 +441,7 @@ def _hash_vectors(self, vectors): # Get the buckets for them and combine. buckets, cur_sum, cur_product = None, 0, 1 for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum:cur_sum + (num_bucket // 2)] + rotated_vectors = rotated_vectors[..., cur_sum : cur_sum + (num_bucket // 2)] cur_sum += num_bucket // 2 rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -423,7 +476,7 @@ def _get_ticker_and_undo_ticker(self, sequence_length, buckets): sorted_ticker = torch.argsort(buckets_and_t, dim=-1) undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) - sorted_ticker = (sorted_ticker % sequence_length) + sorted_ticker = sorted_ticker % sequence_length return sorted_ticker, undo_sorted_ticker def _set_num_buckets_on_the_fly(self, sequence_length): @@ -501,11 +554,13 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) - vectors = vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32)) + vectors = vectors / torch.sqrt( + torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32) + ) return vectors def _len_norm(self, x, epsilon=1e-6): - variance = torch.mean(x**2, -1, keepdim=True) + variance = torch.mean(x ** 2, -1, keepdim=True) norm_x = x / torch.sqrt(variance + epsilon) return norm_x @@ -538,9 +593,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, - hidden_states, - head_mask=None, + self, hidden_states, head_mask=None, ): # get SeqLen and BatchSize @@ -551,9 +604,13 @@ def forward( mixed_key_vectors = self.key(hidden_states) mixed_value_vectors = self.value(hidden_states) - query_vectors = self._transpose_for_scores(mixed_query_vectors, self.num_attention_heads, self.attention_head_size) + query_vectors = self._transpose_for_scores( + mixed_query_vectors, self.num_attention_heads, self.attention_head_size + ) key_vectors = self._transpose_for_scores(mixed_key_vectors, self.num_attention_heads, self.attention_head_size) - value_vectors = self._transpose_for_scores(mixed_value_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._transpose_for_scores( + mixed_value_vectors, self.num_attention_heads, self.attention_head_size + ) assert query_vectors.shape[-1] == self.attention_head_size assert key_vectors.shape[-1] == self.attention_head_size @@ -562,17 +619,31 @@ def forward( if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 - key_vectors = key_vectors / torch.sqrt(torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32)) + key_vectors = key_vectors / torch.sqrt( + torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32) + ) # chunk vectors - query_vectors = self._split_dim_by(query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) - key_vectors = self._split_dim_by(key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) - value_vectors = self._split_dim_by(value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + query_vectors = self._split_dim_by( + query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) + key_vectors = self._split_dim_by( + key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._split_dim_by( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) # chunk indices - indices = torch.arange(sequence_length, device=query_vectors.device).repeat(batch_size, self.num_attention_heads, 1) - query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) - key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + indices = torch.arange(sequence_length, device=query_vectors.device).repeat( + batch_size, self.num_attention_heads, 1 + ) + query_indices = self._split_dim_by( + indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) + key_value_indices = self._split_dim_by( + indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + ) # append chunks before and after key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) @@ -585,7 +656,9 @@ def forward( # Causal mask if self.is_decoder: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.lt(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).long().to(query_indices.device) + mask = ( + torch.lt(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).long().to(query_indices.device) + ) query_key_dots = query_key_dots - mask * 1e9 logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) @@ -609,7 +682,7 @@ def forward( out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) -# outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors[:, :3, :],) + # outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors[:, :3, :],) return outputs @@ -625,7 +698,7 @@ def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) # residual connection - output = (hidden_states + input_tensor) + output = hidden_states + input_tensor return output @@ -648,20 +721,19 @@ def __init__(self, config): else: self.self_attention = LSHSelfAttention(config) else: - raise NotImplementedError("config.attn_type: {} does not exist. Select one of ['lsh', 'local', 'mixed'].".format(config.attn_type)) + raise NotImplementedError( + "config.attn_type: {} does not exist. Select one of ['lsh', 'local', 'mixed'].".format( + config.attn_type + ) + ) self.output = ReformerSelfOutput(config) def forward( - self, - hidden_states, - prev_attention_output, - head_mask=None, + self, hidden_states, prev_attention_output, head_mask=None, ): norm_hidden_states = self.layer_norm(hidden_states) - self_attention_outputs = self.self_attention( - norm_hidden_states, head_mask - ) + self_attention_outputs = self.self_attention(norm_hidden_states, head_mask) # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # X_1 = prev_attention_output; X_2 = hidden_states @@ -696,7 +768,7 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - output = (hidden_states + input_tensor) + output = hidden_states + input_tensor return output @@ -711,7 +783,9 @@ def __init__(self, config): # TODO(PVP): Does this work with backpropagation? def forward(self, attention_output, prev_attention_output): - return apply_chunking_to_forward(self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, prev_attention_output) + return apply_chunking_to_forward( + self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, prev_attention_output + ) def forward_chunk(self, attention_output, prev_attention_output): norm_attention_output = self.layer_norm(attention_output) @@ -727,10 +801,7 @@ def __init__(self, config): self.feed_forward = ChunkReformerFeedForward(config) def forward( - self, - prev_attention_output, - hidden_states, - head_mask=None, + self, prev_attention_output, hidden_states, head_mask=None, ): # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) @@ -757,9 +828,7 @@ def __init__(self, config): self.dropout = config.hidden_dropout_prob def forward( - self, - hidden_states, - head_mask=None, + self, hidden_states, head_mask=None, ): all_hidden_states = () all_attentions = () @@ -819,13 +888,15 @@ def _init_weights(self, module): if module.weight.requires_grad: torch.nn.init.xavier_uniform(module.weight) # TODO(PVP): discuss with Thom if necessary here to use different init -# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.normal_(mean=0.0, std=1e-6) # TODO(PVP): discuss with Thom if necessary here to use different init + + # module.bias.data.zero_() @@ -870,12 +941,7 @@ def _prune_heads(self, heads_to_prune): self.encoder.layer[layer].attention.prune_heads(heads) def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, + self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, ): r""" Return: @@ -914,7 +980,14 @@ def forward( if input_shape[-1] % self.config.chunk_length != 0: # TODO: should also allow this when self.is_decoder is False? # TODO: need to improve attn mask input possibility here - assert self.training is False and self.config.is_decoder is True, "Sequence Length {} has to be a multiple of config.chunk_length {} if {}. Please consider padding the input to a length of {}.".format(input_shape[-2], self.config.chunk_length, "training" if self.training is True else "config.is_decoder = True", input_shape[-1] + (self.config.chunk_length - input_shape[-1] % self.config.chunk_length)) + assert ( + self.training is False and self.config.is_decoder is True + ), "Sequence Length {} has to be a multiple of config.chunk_length {} if {}. Please consider padding the input to a length of {}.".format( + input_shape[-2], + self.config.chunk_length, + "training" if self.training is True else "config.is_decoder = True", + input_shape[-1] + (self.config.chunk_length - input_shape[-1] % self.config.chunk_length), + ) if input_ids is None: raise NotImplementedError("Currently only supported for `input_ids`") @@ -922,7 +995,18 @@ def forward( padding_length = self.config.chunk_length - input_shape[-1] % self.config.chunk_length # Extend `input_ids` with padding to match self.chunk_len - input_ids = torch.cat([input_ids, torch.full((input_shape[0], padding_length), self.config.pad_token_id, device=input_ids.device, dtype=torch.long)], dim=-1) + input_ids = torch.cat( + [ + input_ids, + torch.full( + (input_shape[0], padding_length), + self.config.pad_token_id, + device=input_ids.device, + dtype=torch.long, + ), + ], + dim=-1, + ) input_shape = input_ids.size() # Prepare head mask if needed @@ -944,14 +1028,9 @@ def forward( else: head_mask = [None] * self.config.num_hidden_layers - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds - ) + embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) - encoder_outputs = self.encoder( - embedding_output, - head_mask=head_mask, - ) + encoder_outputs = self.encoder(embedding_output, head_mask=head_mask,) sequence_output = encoder_outputs[0] # if padding was applied @@ -1003,19 +1082,11 @@ def tie_weights(self): pass def forward( - self, - input_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - lm_labels=None, + self, input_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, lm_labels=None, ): reformer_outputs = self.reformer( - input_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, + input_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, ) sequence_output = reformer_outputs[0] diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index c0df33f1a5ab..2ac96b877541 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -41,16 +41,12 @@ # Mapping from the keyword arguments names of Tokenizer `__init__` # to pretrained vocabulary URL for all the model shortcut names. #################################################### -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - } -} +PRETRAINED_VOCAB_FILES_MAP = {"vocab_file": {}} #################################################### # Mapping from model shortcut names to max length of inputs #################################################### -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { -} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} class ReformerTokenizer(PreTrainedTokenizer): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index db59a277c266..fb979bd6b3cb 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -14,22 +14,16 @@ import unittest + +import gin import numpy as np # trax imports - to be deleted later import trax +from transformers import is_torch_available # noqa: F401 +from transformers import ReformerAttention, ReformerConfig, ReformerModelWithLMHead from trax.shapes import ShapeDtype as trax_ShapeDtype -import gin - -from transformers import ( - ReformerAttention, - ReformerConfig, - ReformerModelWithLMHead, -) - - -from transformers import is_torch_available # noqa: F401 from .utils import require_torch, torch_device # noqa: F401 @@ -37,14 +31,11 @@ import torch # noqa: F401 # from transformers.modeling_reformer import () -PATH_TO_SAVE_WEIGHTS = ( - "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" -) +PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" @require_torch class ReformerIntegrationTests(unittest.TestCase): - def test_lsh_layer(self): config = ReformerConfig() shape = (2, 192, config.hidden_size) # Batch x SeqLen x hiddenSize @@ -82,9 +73,7 @@ def test_local_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_type = "local" hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_local( - trax_weights, hf_layer, config.hidden_size - ) + self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] @@ -110,9 +99,7 @@ def test_reformer_lm_model(self): trax_torch_output = torch.tensor(np.asarray(trax_output[0])) if mode != "predict": - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) + hf_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) else: hf_input = torch.tensor(np_input) @@ -123,19 +110,13 @@ def test_reformer_lm_model(self): hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - self.assertTrue( - torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) - ) + self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) def test_pretrained_crime_and_punishment_lm_model(self): - hf_model = ReformerModelWithLMHead.from_pretrained( - "patrickvonplaten/reformer-crime-and-punish" - ) + hf_model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish") config = hf_model.config - trax_model_path = ( - "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" - ) + trax_model_path = "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" shape = (1, 128) np_input = np.random.randint(0, config.vocab_size, size=shape) @@ -143,9 +124,7 @@ def test_pretrained_crime_and_punishment_lm_model(self): hf_input = torch.tensor(np_input) input_signature = trax_ShapeDtype(shape, np.int32) - trax_model = self.load_crime_and_punishment_model( - trax_model_path, input_signature - ) + trax_model = self.load_crime_and_punishment_model(trax_model_path, input_signature) hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) @@ -153,9 +132,7 @@ def test_pretrained_crime_and_punishment_lm_model(self): trax_output = trax_model(np_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - self.assertTrue( - torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) - ) + self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) def load_lsh_layer(self, config, mode="eval"): gin_config = """ @@ -310,7 +287,7 @@ def load_reformer_lm_model(self, config, mode="eval"): config.axial_pos_embds_dim, config.chunk_size_feed_forward, hidden_act, - attn_type + attn_type, ) gin.parse_config(gin_config) model = trax.models.ReformerLM(mode=mode) @@ -389,14 +366,10 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode def _set_param(self, torch_layer, weight, bias=None): with torch.no_grad(): - assert ( - torch_layer.weight.shape == weight.shape - ), "{} layer.weight does not match".format(torch_layer) + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: - assert ( - torch_layer.bias.shape == bias.shape - ), "{} layer.bias does not match".format(torch_layer) + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) torch_layer.bias = torch.nn.Parameter(bias) def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): @@ -407,18 +380,14 @@ def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): self._set_param( torch_layer.self_attention.query_key, - torch.tensor(np_query_key) - .transpose(1, 2) - .contiguous() - .view(-1, hidden_size), + torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), ) self._set_param( torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), ) self._set_param( - torch_layer.output.dense, - torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), ) def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size): @@ -433,16 +402,14 @@ def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size): torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), ) self._set_param( - torch_layer.self_attention.key, - torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), + torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), ) self._set_param( torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), ) self._set_param( - torch_layer.output.dense, - torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), ) def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): @@ -451,21 +418,15 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): layer_norm_1_weight = np.asarray(layer_norm_1[0]) layer_norm_1_bias = np.asarray(layer_norm_1[1]) self._set_param( - torch_block.attention.layer_norm, - torch.tensor(layer_norm_1_weight), - torch.tensor(layer_norm_1_bias), + torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), ) # lsh weights + output attn_weights = weights[0][1] if len(attn_weights) < 4: - self._set_layer_weights_in_torch_lsh( - attn_weights, torch_block.attention, hidden_size - ) + self._set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size) else: - self._set_layer_weights_in_torch_local( - attn_weights, torch_block.attention, hidden_size - ) + self._set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size) # intermediate weighs intermediate_weights = weights[2][0][2][2] @@ -478,9 +439,7 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) self._set_param( - torch_block.feed_forward.layer_norm, - torch.tensor(layer_norm_2_weight), - torch.tensor(layer_norm_2_bias), + torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), ) # intermediate dense @@ -508,20 +467,17 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): # word embeds word_embeddings = np.asarray(weights[1]) self._set_param( - torch_model_reformer.embeddings.word_embeddings, - torch.tensor(word_embeddings), + torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), ) if isinstance(weights[3], tuple): position_embeddings = torch_model_reformer.embeddings.position_embeddings for emb_idx in range(len(position_embeddings.weights)): emb_weights = np.asarray(weights[3][emb_idx][0]) - assert ( - position_embeddings.weights[emb_idx].shape == emb_weights.shape - ), "{} emb does not match".format(position_embeddings[emb_idx]) - position_embeddings.weights[emb_idx] = torch.nn.Parameter( - torch.tensor(emb_weights) + assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( + position_embeddings[emb_idx] ) + position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) trax_layer_weights = weights[5] assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len( From 9c9fab9c014608aeaef5b778158403396bee2dea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Apr 2020 11:26:48 +0200 Subject: [PATCH 060/162] refactor --- src/transformers/modeling_reformer.py | 4 ++-- tests/test_modeling_reformer.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e0f5115b1c1b..e2993bf52652 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -423,9 +423,9 @@ def _hash_vectors(self, vectors): # TODO: delete later when integration tests are ok # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 # random_rotations = torch.randn(rotations_shape, device=vectors.device) - numpy.random.seed(self.hash_seed) + np.random.seed(self.hash_seed) random_rotations = torch.tensor( - numpy.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device + np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device ) # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index fb979bd6b3cb..02bd9ccdef83 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -21,17 +21,15 @@ # trax imports - to be deleted later import trax from transformers import is_torch_available # noqa: F401 -from transformers import ReformerAttention, ReformerConfig, ReformerModelWithLMHead + from trax.shapes import ShapeDtype as trax_ShapeDtype from .utils import require_torch, torch_device # noqa: F401 if is_torch_available(): - import torch # noqa: F401 -# from transformers.modeling_reformer import () - -PATH_TO_SAVE_WEIGHTS = "/home/patrick/hugging_face/experiments/reformer/intermediate_weights" + import torch + from transformers import ReformerAttention, ReformerConfig, ReformerModelWithLMHead @require_torch @@ -89,7 +87,7 @@ def test_reformer_lm_model(self): np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) - mode = "train" + mode = "predict" trax_model = self.load_reformer_lm_model(config, mode=mode) input_signature = trax_ShapeDtype(shape, np.int32) From a188a391f2f6cd341ce0e1a355756d299fb44f51 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Apr 2020 14:27:45 +0200 Subject: [PATCH 061/162] add first version of RevNet Layers --- src/transformers/modeling_reformer.py | 160 ++++++++++++++++++-------- 1 file changed, 111 insertions(+), 49 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e2993bf52652..0878e9fe3d51 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -22,9 +22,14 @@ # DELETE later import numpy as np + +from operator import mul +from functools import reduce + import torch from torch import nn from torch.nn import CrossEntropyLoss +from torch.autograd.function import Function from .activations import gelu, gelu_new, swish from .configuration_reformer import ReformerConfig @@ -148,7 +153,7 @@ def forward(self, position_ids): if self.training is True: assert ( - int(np.prod(self.axial_pos_shape)) == sequence_length + reduce(mul, self.axial_pos_shape) == sequence_length ), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format( self.axial_pos_shape, sequence_length ) @@ -168,7 +173,7 @@ def forward(self, position_ids): else: assert ( - int(np.prod(self.axial_pos_shape)) >= sequence_length + reduce(mul, self.axial_pos_shape) >= sequence_length ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format( self.axial_pos_shape, sequence_length, self.chunk_length ) @@ -694,12 +699,13 @@ def __init__(self, config): self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) self.dropout = config.hidden_dropout_prob - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - # residual connection - output = hidden_states + input_tensor - return output + return hidden_states + +# output = hidden_states + input_tensor +# return output class ReformerAttention(nn.Module): @@ -729,15 +735,11 @@ def __init__(self, config): self.output = ReformerSelfOutput(config) - def forward( - self, hidden_states, prev_attention_output, head_mask=None, - ): + def forward(self, hidden_states, head_mask=None): norm_hidden_states = self.layer_norm(hidden_states) self_attention_outputs = self.self_attention(norm_hidden_states, head_mask) - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - attention_output = self.output(self_attention_outputs[0], prev_attention_output) + attention_output = self.output(self_attention_outputs[0]) outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them return outputs @@ -765,11 +767,10 @@ def __init__(self, config): self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) self.dropout = config.hidden_dropout_prob - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - output = hidden_states + input_tensor - return output + return hidden_states class ChunkReformerFeedForward(nn.Module): @@ -782,15 +783,14 @@ def __init__(self, config): self.output = ReformerFeedForwardOutput(config) # TODO(PVP): Does this work with backpropagation? - def forward(self, attention_output, prev_attention_output): + def forward(self, attention_output): return apply_chunking_to_forward( - self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, prev_attention_output - ) + self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output) - def forward_chunk(self, attention_output, prev_attention_output): + def forward_chunk(self, attention_output): norm_attention_output = self.layer_norm(attention_output) dense_output = self.dense(norm_attention_output) - output = self.output(dense_output, prev_attention_output) + output = self.output(dense_output) return output @@ -801,27 +801,60 @@ def __init__(self, config): self.feed_forward = ChunkReformerFeedForward(config) def forward( - self, prev_attention_output, hidden_states, head_mask=None, + self, prev_attn_output, hidden_states, head_mask=None, ): - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) - # X_1 = prev_attention_output; X_2 = hidden_states - # Y_1 = attention_output; Y_2 = output - attention_outputs = self.attention(hidden_states, prev_attention_output, head_mask) - attention_output = attention_outputs[0] - output = self.feed_forward(attention_output, hidden_states) + with torch.no_grad(): + # is this _init_seed needed? in https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py + attn_outputs = self.attention(hidden_states, head_mask) + attn_output = attn_outputs[0] + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # Y_1 = X_1 + f(X_2) + attn_output = prev_attn_output + attn_output + # Y_2 = X_2 + g(Y_1) + output = hidden_states = self.feed_forward(attn_output) - outputs = (attention_output, output) + attention_outputs[1:] + outputs = (attn_output, output) + attn_outputs[1:] return outputs + def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states, head_mask=None): + + with torch.enable_grad(): + next_attn_output.requires_grad = True + attn_output = self.attention(next_attn_output, head_mask) + torch.autograd.backward(attn_output, grad_hidden_states) + + with torch.no_grad(): + hidden_states = hidden_states - attn_output + del attn_output + + grad_next_attn_output = grad_next_attn_output + next_attn_output.grad + next_attn_output.grad = None + + with torch.enable_grad(): + hidden_states.requires_grad = True + output = self.feed_forward(hidden_states) + torch.autograd.backward(output, grad_next_attn_output, retain_graph=True) + + with torch.no_grad(): + next_attn_output = next_attn_output - output + del output + + grad_hidden_states = grad_hidden_states + hidden_states.grad + hidden_states.grad = None + hidden_states = hidden_states.detach() + + return next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states + class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) @@ -830,26 +863,8 @@ def __init__(self, config): def forward( self, hidden_states, head_mask=None, ): - all_hidden_states = () - all_attentions = () - attention_output = hidden_states - - for i, layer_module in enumerate(self.layer): - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (attention_output, hidden_states) - - layer_outputs = layer_module(attention_output, hidden_states, head_mask[i]) - attention_output, hidden_states = layer_outputs[:2] - - if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) - - # Add last layer - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (attention_output, hidden_states) - - # Concatenate 2 RevNet outputs - hidden_states = torch.cat([attention_output, hidden_states], dim=-1) + # Make this work + hidden_states, all_hidden_states, all_attentions = _ReversibleFunction.apply(hidden_states, self.layers, head_mask) # Apply layer norm to concatenated hidden states hidden_states = self.layer_norm(hidden_states) @@ -865,6 +880,53 @@ def forward( return outputs # last-layer hidden state, (all hidden states), (all attentions) +class _ReversibleFunction(Function): + + # Make this work + @staticmethod + def forward(ctx, hidden_states, layers, head_mask, output_hidden_states, output_attentions): + all_hidden_states = () + all_attentions = () + ctx.head_mask = head_mask + + attn_output = hidden_states + for i, layer in enumerate(layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (attn_output, hidden_states) + + layer_outputs = layer(attn_output, hidden_states, head_mask[i]) + attn_output, hidden_states = layer_outputs[:2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (attn_output, hidden_states) + + ctx.attn_output = attn_output.detach() + ctx.hidden_states = hidden_states.detach() + ctx.layers = layers + + # Concatenate 2 RevNet outputs + hidden_states = torch.cat([attn_output, hidden_states], dim=-1) + return hidden_states, all_hidden_states, all_attentions + + @staticmethod + def backward(ctx, grad_hidden_states): + grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) + attn_output = ctx.attn_output + hidden_states = ctx.hidden_states + head_mask = ctx.head_mask + + for i, layer in enumerate(ctx.layers[::-1]): + attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask[i]) + + grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) + + return grad_hidden_states, None, None + + class ReformerPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. From 8047573a1cd20ff990ce56f87e8523f2de1a5008 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 18 Apr 2020 21:09:55 +0200 Subject: [PATCH 062/162] make forward pass work and add convert file --- src/transformers/configuration_reformer.py | 2 +- ...ert_reformer_trax_checkpoint_to_pytorch.py | 215 ++++++++++++++++++ src/transformers/modeling_reformer.py | 121 +++++----- 3 files changed, 275 insertions(+), 63 deletions(-) create mode 100755 src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index d3f88f52f192..3eb8734d3887 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -121,7 +121,7 @@ def __init__( axial_pos_embds=True, axial_pos_shape=[64, 3], axial_pos_embds_dim=[64, 64], - attn_type="lsh", + attn_type="local", pad_token_id=0, **kwargs ): diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..bfea6935e9d6 --- /dev/null +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BERT checkpoint.""" + + +import argparse +import logging + +import numpy as np +import torch + +from tensorflow.compat.v1.io.gfile import GFile +import pickle + +from transformers import ReformerConfig, ReformerModelWithLMHead + + +logging.basicConfig(level=logging.INFO) + + +def set_param(torch_layer, weight, bias=None): + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) + torch_layer.weight = torch.nn.Parameter(weight) + if bias is not None: + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) + torch_layer.bias = torch.nn.Parameter(bias) + + +def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + np_dense = np.asarray(weights[2]) + + set_param( + torch_layer.self_attention.query_key, + torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.value, + torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) + + +def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + np_query = np.asarray(weights[0]) + np_key = np.asarray(weights[1]) + np_value = np.asarray(weights[2]) + np_dense = np.asarray(weights[3]) + + set_param( + torch_layer.self_attention.query, + torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.value, + torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) + + +def set_block_weights_in_torch(weights, torch_block, hidden_size): + # layernorm 1 + layer_norm_1 = weights[0][0][0] + layer_norm_1_weight = np.asarray(layer_norm_1[0]) + layer_norm_1_bias = np.asarray(layer_norm_1[1]) + set_param( + torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), + ) + + # lsh weights + output + attn_weights = weights[0][1] + if len(attn_weights) < 4: + set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size) + else: + set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size) + + # intermediate weighs + intermediate_weights = weights[2][0][2][2] + + # Chunked Feed Forward + if len(intermediate_weights) == 4: + intermediate_weights = intermediate_weights[2] + + # layernorm 2 + layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) + layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) + set_param( + torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), + ) + + # intermediate dense + inter_dense_weight = np.asarray(intermediate_weights[1][0]) + inter_dense_bias = np.asarray(intermediate_weights[1][1]) + set_param( + torch_block.feed_forward.dense.dense, + torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(inter_dense_bias), + ) + + # intermediate out + out_dense_weight = np.asarray(intermediate_weights[4][0]) + out_dense_bias = np.asarray(intermediate_weights[4][1]) + set_param( + torch_block.feed_forward.output.dense, + torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(out_dense_bias), + ) + + +def set_model_weights_in_torch(weights, torch_model, hidden_size): + # reformer model + torch_model_reformer = torch_model.reformer + + # word embeds + word_embeddings = np.asarray(weights[1]) + set_param( + torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), + ) + + if isinstance(weights[3], tuple): + position_embeddings = torch_model_reformer.embeddings.position_embeddings + for emb_idx in range(len(position_embeddings.weights)): + emb_weights = np.asarray(weights[3][emb_idx][0]) + assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( + position_embeddings[emb_idx] + ) + position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + + trax_layer_weights = weights[5] + assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( + trax_layer_weights + ), "HF and trax model do not have the same number of layers" + for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers): + block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] + set_block_weights_in_torch(block_weights, layer, hidden_size) + + # output weights + out_weights = weights[6] + + # output layer norm + layer_norm_out_weight = np.asarray(out_weights[0][0]) + layer_norm_out_bias = np.asarray(out_weights[0][1]) + set_param( + torch_model_reformer.encoder.layer_norm, + torch.tensor(layer_norm_out_weight), + torch.tensor(layer_norm_out_bias), + ) + + # output embeddings + output_embed_weights = np.asarray(out_weights[2][0]) + output_embed_bias = np.asarray(out_weights[2][1]) + set_param( + torch_model.lm_head.decoder, + torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), + torch.tensor(output_embed_bias), + ) + + +def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = ReformerConfig.from_json_file(config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = ReformerModelWithLMHead(config) + + with GFile(trax_model_pkl_path, 'rb') as f: + model_weights = pickle.load(f)['weights'] + + set_model_weights_in_torch(model_weights, model, config.hidden_size) + + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained Reformer model. \n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 0878e9fe3d51..ceae2b8d98fc 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -320,7 +320,6 @@ def __init__(self, config): self.chunk_length = config.chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after - self.output_attentions = config.output_attentions self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings @@ -334,7 +333,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, + self, hidden_states, head_mask=None, do_output_attentions=False ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -397,7 +396,7 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) - outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) + outputs = (out_vectors, attention_probs) if do_output_attentions else (out_vectors,) return outputs @@ -583,7 +582,6 @@ def __init__(self, config): self.chunk_length = config.chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after - self.output_attentions = config.output_attentions self.is_decoder = config.is_decoder self.pad_token_id = config.pad_token_id @@ -598,7 +596,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, + self, hidden_states, head_mask=None, do_output_attentions=False, ): # get SeqLen and BatchSize @@ -686,8 +684,7 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) - outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors,) - # outputs = (out_vectors, attention_probs) if self.output_attentions else (out_vectors[:, :3, :],) + outputs = (out_vectors, attention_probs) if do_output_attentions else (out_vectors,) return outputs @@ -735,9 +732,9 @@ def __init__(self, config): self.output = ReformerSelfOutput(config) - def forward(self, hidden_states, head_mask=None): + def forward(self, hidden_states, head_mask=None, do_output_attentions=False): norm_hidden_states = self.layer_norm(hidden_states) - self_attention_outputs = self.self_attention(norm_hidden_states, head_mask) + self_attention_outputs = self.self_attention(norm_hidden_states, head_mask, do_output_attentions) attention_output = self.output(self_attention_outputs[0]) outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them @@ -801,19 +798,19 @@ def __init__(self, config): self.feed_forward = ChunkReformerFeedForward(config) def forward( - self, prev_attn_output, hidden_states, head_mask=None, + self, prev_attn_output, hidden_states, head_mask=None, do_output_attentions=False ): with torch.no_grad(): # is this _init_seed needed? in https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py - attn_outputs = self.attention(hidden_states, head_mask) + attn_outputs = self.attention(hidden_states, head_mask, do_output_attentions) attn_output = attn_outputs[0] # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # Y_1 = X_1 + f(X_2) attn_output = prev_attn_output + attn_output # Y_2 = X_2 + g(Y_1) - output = hidden_states = self.feed_forward(attn_output) + output = hidden_states + self.feed_forward(attn_output) outputs = (attn_output, output) + attn_outputs[1:] @@ -849,60 +846,27 @@ def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, return next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states -class ReformerEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.output_attentions = config.output_attentions - self.output_hidden_states = config.output_hidden_states - self.layers = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) - # Reformer is using Rev Nets, thus last layer outputs are concatenated and - # Layer Norm is done over 2 * hidden_size - self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) - self.dropout = config.hidden_dropout_prob - - def forward( - self, hidden_states, head_mask=None, - ): - # Make this work - hidden_states, all_hidden_states, all_attentions = _ReversibleFunction.apply(hidden_states, self.layers, head_mask) - - # Apply layer norm to concatenated hidden states - hidden_states = self.layer_norm(hidden_states) - - # Apply dropout - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - - outputs = (hidden_states,) - if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) - if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) - - class _ReversibleFunction(Function): - - # Make this work @staticmethod - def forward(ctx, hidden_states, layers, head_mask, output_hidden_states, output_attentions): - all_hidden_states = () - all_attentions = () + def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): ctx.head_mask = head_mask attn_output = hidden_states for i, layer in enumerate(layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (attn_output, hidden_states) + if do_output_hidden_states is True: + all_hidden_states.append(attn_output) + all_hidden_states.append(hidden_states) - layer_outputs = layer(attn_output, hidden_states, head_mask[i]) + layer_outputs = layer(attn_output, hidden_states, head_mask[i], do_output_attentions) attn_output, hidden_states = layer_outputs[:2] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) + if do_output_attentions: + all_attentions.append(layer_outputs[2]) # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (attn_output, hidden_states) + if do_output_hidden_states is True: + all_hidden_states.append(attn_output) + all_hidden_states.append(hidden_states) ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() @@ -910,7 +874,7 @@ def forward(ctx, hidden_states, layers, head_mask, output_hidden_states, output_ # Concatenate 2 RevNet outputs hidden_states = torch.cat([attn_output, hidden_states], dim=-1) - return hidden_states, all_hidden_states, all_attentions + return hidden_states @staticmethod def backward(ctx, grad_hidden_states): @@ -927,6 +891,39 @@ def backward(ctx, grad_hidden_states): return grad_hidden_states, None, None +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.hidden_dropout_prob + + def forward( + self, hidden_states, head_mask=None, do_output_hidden_states=False, do_output_attentions=False, + ): + # hidden_states and attention lists to be filled if wished + all_hidden_states = [] + all_attentions = [] + + # Make this work + hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) + + # Apply layer norm to concatenated hidden states + hidden_states = self.layer_norm(hidden_states) + + # Apply dropout + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + + outputs = (hidden_states,) + if do_output_hidden_states: + outputs = outputs + (tuple(all_hidden_states),) + if do_output_attentions: + outputs = outputs + (tuple(all_attentions),) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + class ReformerPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -948,7 +945,7 @@ def _init_weights(self, module): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: - torch.nn.init.xavier_uniform(module.weight) + torch.nn.init.xavier_uniform_(module.weight) # TODO(PVP): discuss with Thom if necessary here to use different init # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, ReformerLayerNorm): @@ -1003,7 +1000,7 @@ def _prune_heads(self, heads_to_prune): self.encoder.layer[layer].attention.prune_heads(heads) def forward( - self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, + self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, do_output_hidden_states=False, do_output_attentions=False ): r""" Return: @@ -1015,7 +1012,7 @@ def forward( of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. @@ -1092,7 +1089,7 @@ def forward( embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) - encoder_outputs = self.encoder(embedding_output, head_mask=head_mask,) + encoder_outputs = self.encoder(embedding_output, head_mask, do_output_hidden_states, do_output_attentions) sequence_output = encoder_outputs[0] # if padding was applied @@ -1144,11 +1141,11 @@ def tie_weights(self): pass def forward( - self, input_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, lm_labels=None, + self, input_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False ): reformer_outputs = self.reformer( - input_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, + input_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions ) sequence_output = reformer_outputs[0] From 31a596b40b8e96d34ebc7fb265e0d8ab3035eb40 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 18 Apr 2020 21:25:41 +0200 Subject: [PATCH 063/162] make uploaded model forward pass work --- tests/test_modeling_reformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 02bd9ccdef83..128caf68e18a 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -54,7 +54,7 @@ def test_lsh_layer(self): hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] - hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) + hf_output = hf_layer.output(hf_attention_all_heads) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) @@ -75,7 +75,7 @@ def test_local_layer(self): hf_layer.eval() hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] - hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input)) + hf_output = hf_layer.output(hf_attention_all_heads) trax_torch_output = torch.tensor(np.asarray(trax_output)) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) @@ -478,10 +478,10 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) trax_layer_weights = weights[5] - assert len(torch_model_reformer.encoder.layer) * 4 + 1 == len( + assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( trax_layer_weights ), "HF and trax model do not have the same number of layers" - for layer_idx, layer in enumerate(torch_model_reformer.encoder.layer): + for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers): block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] self._set_block_weights_in_torch(block_weights, layer, hidden_size) From bae0700f36b2d828d1345e14a7432e70d03ef9b1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 18 Apr 2020 21:25:53 +0200 Subject: [PATCH 064/162] make uploaded model forward pass work --- src/transformers/modeling_reformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ceae2b8d98fc..6e859d11936f 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -954,8 +954,6 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.normal_(mean=0.0, std=1e-6) # TODO(PVP): discuss with Thom if necessary here to use different init - - # module.bias.data.zero_() From 831dcec9bbb40d21f209774b528a649b5a7e7422 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sat, 18 Apr 2020 23:41:47 +0200 Subject: [PATCH 065/162] refactor code --- src/transformers/modeling_reformer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 6e859d11936f..8e7de6098c56 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -820,7 +820,7 @@ def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, with torch.enable_grad(): next_attn_output.requires_grad = True - attn_output = self.attention(next_attn_output, head_mask) + attn_output = self.attention(next_attn_output, head_mask)[0] torch.autograd.backward(attn_output, grad_hidden_states) with torch.no_grad(): @@ -849,9 +849,8 @@ def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, class _ReversibleFunction(Function): @staticmethod def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): - ctx.head_mask = head_mask - attn_output = hidden_states + hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) for i, layer in enumerate(layers): if do_output_hidden_states is True: all_hidden_states.append(attn_output) @@ -868,6 +867,8 @@ def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attent all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) + # attach params to ctx for backward + ctx.head_mask = head_mask ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() ctx.layers = layers @@ -879,16 +880,21 @@ def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attent @staticmethod def backward(ctx, grad_hidden_states): grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) + + # retrieve params from ctx for backward attn_output = ctx.attn_output hidden_states = ctx.hidden_states head_mask = ctx.head_mask + layers = ctx.layers - for i, layer in enumerate(ctx.layers[::-1]): + for i, layer in enumerate(layers[::-1]): attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask[i]) grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) - return grad_hidden_states, None, None + # num of return vars has to match num of forward() args + # return gradient for hidden_states arg and None for other args + return grad_hidden_states, None, None, None, None, None, None class ReformerEncoder(nn.Module): @@ -908,6 +914,7 @@ def forward( all_attentions = [] # Make this work + hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) # Apply layer norm to concatenated hidden states From 57ee09c0625cccd1d78688e6af3196051aa0758a Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sun, 19 Apr 2020 12:34:38 +0200 Subject: [PATCH 066/162] add namedtuples and cache buckets --- src/transformers/configuration_reformer.py | 6 +- src/transformers/modeling_reformer.py | 86 +++++++++++++++------- tests/test_modeling_reformer.py | 10 ++- 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 3eb8734d3887..bce72731ab6b 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -117,11 +117,11 @@ def __init__( initializer_range=0.02, axial_norm_std=1.0, layer_norm_eps=1e-12, - sinusoidal_pos_embds=False, - axial_pos_embds=True, + sinusoidal_pos_embds=True, + axial_pos_embds=False, axial_pos_shape=[64, 3], axial_pos_embds_dim=[64, 64], - attn_type="local", + attn_type="lsh", pad_token_id=0, **kwargs ): diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 8e7de6098c56..2022dcedbf66 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -26,6 +26,8 @@ from operator import mul from functools import reduce +from collections import namedtuple + import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -54,6 +56,12 @@ def mish(x): ReformerLayerNorm = torch.nn.LayerNorm +LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) +AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) + + def create_sinusoidal_embeddings(n_pos, dim, out): position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) @@ -191,7 +199,8 @@ class PositionEmbeddings(nn.Module): """ def __init__(self, config): - self.max_position_embeddings = config.max_position_embeddings + super().__init__() + self.config = config self.dropout = config.hidden_dropout_prob self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) if self.config.sinusoidal_pos_embds is True: @@ -333,7 +342,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, do_output_attentions=False + self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -356,8 +365,13 @@ def forward( # set `num_buckets` on the fly self._set_num_buckets_on_the_fly(sequence_length) - # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors) + # use cached buckets for backprop + buckets = kwargs["buckets"] if "buckets" in kwargs else None + + if buckets is None: + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors) + assert int(buckets.shape[-1]) == self.num_hashes * sequence_length ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) @@ -396,9 +410,11 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) - outputs = (out_vectors, attention_probs) if do_output_attentions else (out_vectors,) - return outputs + if do_output_attentions is False: + attention_probs = () + + return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) def _hash_vectors(self, vectors): batch_size = vectors.shape[0] @@ -596,7 +612,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, do_output_attentions=False, + self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs ): # get SeqLen and BatchSize @@ -684,9 +700,11 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) - outputs = (out_vectors, attention_probs) if do_output_attentions else (out_vectors,) - return outputs + if do_output_attentions is False: + attention_probs = () + + return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) class ReformerSelfOutput(nn.Module): @@ -701,9 +719,6 @@ def forward(self, hidden_states): hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) return hidden_states -# output = hidden_states + input_tensor -# return output - class ReformerAttention(nn.Module): @@ -732,13 +747,21 @@ def __init__(self, config): self.output = ReformerSelfOutput(config) - def forward(self, hidden_states, head_mask=None, do_output_attentions=False): + def forward(self, hidden_states, head_mask=None, do_output_attentions=False, buckets=None): norm_hidden_states = self.layer_norm(hidden_states) - self_attention_outputs = self.self_attention(norm_hidden_states, head_mask, do_output_attentions) - attention_output = self.output(self_attention_outputs[0]) - outputs = (attention_output,) + self_attention_outputs[1:] # add attentions if we output them - return outputs + # use cached buckets for backprob if buckets not None for LSHSelfAttention + self_attention_outputs = self.self_attention(norm_hidden_states, head_mask, do_output_attentions, buckets=buckets) + + attention_output = self.output(self_attention_outputs.hidden_states) + + # add buckets if necessary + if hasattr(self_attention_outputs, "buckets"): + buckets = self_attention_outputs.buckets + else: + buckets = None + + return AttentionOutput(hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets) class ReformerFeedForwardDense(nn.Module): @@ -804,23 +827,23 @@ def forward( with torch.no_grad(): # is this _init_seed needed? in https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py attn_outputs = self.attention(hidden_states, head_mask, do_output_attentions) - attn_output = attn_outputs[0] + attn_output = attn_outputs.hidden_states # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # Y_1 = X_1 + f(X_2) attn_output = prev_attn_output + attn_output # Y_2 = X_2 + g(Y_1) - output = hidden_states + self.feed_forward(attn_output) - - outputs = (attn_output, output) + attn_outputs[1:] + hidden_states = hidden_states + self.feed_forward(attn_output) - return outputs + return ReformerOutput(attn_output=attn_output, hidden_states=hidden_states, attention_probs=attn_outputs.attention_probs, buckets=attn_outputs.buckets) - def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states, head_mask=None): + def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states, head_mask=None, buckets=None): with torch.enable_grad(): next_attn_output.requires_grad = True - attn_output = self.attention(next_attn_output, head_mask)[0] + + # use cached buckets for backprob if buckets not None for LSHSelfAttention + attn_output = self.attention(next_attn_output, head_mask, buckets=buckets).hidden_states torch.autograd.backward(attn_output, grad_hidden_states) with torch.no_grad(): @@ -850,6 +873,7 @@ class _ReversibleFunction(Function): @staticmethod def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): + all_buckets = () hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) for i, layer in enumerate(layers): if do_output_hidden_states is True: @@ -857,10 +881,12 @@ def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attent all_hidden_states.append(hidden_states) layer_outputs = layer(attn_output, hidden_states, head_mask[i], do_output_attentions) - attn_output, hidden_states = layer_outputs[:2] + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + all_buckets = all_buckets + (layer_outputs.buckets,) if do_output_attentions: - all_attentions.append(layer_outputs[2]) + all_attentions.append(layer_outputs.attention_probs) # Add last layer if do_output_hidden_states is True: @@ -872,6 +898,7 @@ def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attent ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() ctx.layers = layers + ctx.all_buckets = all_buckets # Concatenate 2 RevNet outputs hidden_states = torch.cat([attn_output, hidden_states], dim=-1) @@ -886,9 +913,14 @@ def backward(ctx, grad_hidden_states): hidden_states = ctx.hidden_states head_mask = ctx.head_mask layers = ctx.layers + all_buckets = ctx.all_buckets for i, layer in enumerate(layers[::-1]): - attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask[i]) + # pop last buckets + buckets = all_buckets[-1] + all_buckets = all_buckets[:-1] + + attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=head_mask[i], buckets=buckets) grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 128caf68e18a..85647f5ff05c 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -213,6 +213,12 @@ def load_reformer_lm_model(self, config, mode="eval"): attn_type = "SelfAttention" else: raise ValueError() + if config.sinusoidal_pos_embds is True: + axial_pos_shape = () + d_axial_pos_embs = None + else: + axial_pos_shape = config.axial_pos_shape + d_axial_pos_embs = config.axial_pos_embds_dim gin_config = """ import trax.layers @@ -281,8 +287,8 @@ def load_reformer_lm_model(self, config, mode="eval"): config.num_hidden_layers, config.num_attention_heads, config.max_position_embeddings, - config.axial_pos_shape, - config.axial_pos_embds_dim, + axial_pos_shape, + d_axial_pos_embs, config.chunk_size_feed_forward, hidden_act, attn_type, From 2d23fadee962bb9a5f75cd1fd14aaa75705b4576 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 19 Apr 2020 12:56:22 +0200 Subject: [PATCH 067/162] correct head masks --- src/transformers/modeling_reformer.py | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 2022dcedbf66..082cf154c44f 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -871,16 +871,16 @@ def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, class _ReversibleFunction(Function): @staticmethod - def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): + def forward(ctx, hidden_states, layers, head_masks, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): all_buckets = () hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) - for i, layer in enumerate(layers): + for layer, head_mask in zip(layers, head_masks): if do_output_hidden_states is True: all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) - layer_outputs = layer(attn_output, hidden_states, head_mask[i], do_output_attentions) + layer_outputs = layer(attn_output, hidden_states, head_mask, do_output_attentions) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states all_buckets = all_buckets + (layer_outputs.buckets,) @@ -894,7 +894,7 @@ def forward(ctx, hidden_states, layers, head_mask, all_hidden_states, all_attent all_hidden_states.append(hidden_states) # attach params to ctx for backward - ctx.head_mask = head_mask + ctx.head_masks = head_masks ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() ctx.layers = layers @@ -911,16 +911,16 @@ def backward(ctx, grad_hidden_states): # retrieve params from ctx for backward attn_output = ctx.attn_output hidden_states = ctx.hidden_states - head_mask = ctx.head_mask + head_masks = ctx.head_masks layers = ctx.layers all_buckets = ctx.all_buckets - for i, layer in enumerate(layers[::-1]): + for layer, head_mask in zip(layers[::-1], head_masks[::-1]): # pop last buckets buckets = all_buckets[-1] all_buckets = all_buckets[:-1] - attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=head_mask[i], buckets=buckets) + attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=head_mask, buckets=buckets) grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) @@ -939,7 +939,7 @@ def __init__(self, config): self.dropout = config.hidden_dropout_prob def forward( - self, hidden_states, head_mask=None, do_output_hidden_states=False, do_output_attentions=False, + self, hidden_states, head_masks=None, do_output_hidden_states=False, do_output_attentions=False, ): # hidden_states and attention lists to be filled if wished all_hidden_states = [] @@ -947,7 +947,7 @@ def forward( # Make this work hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) - hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_mask, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) + hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_masks, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) # Apply layer norm to concatenated hidden states hidden_states = self.layer_norm(hidden_states) @@ -1037,7 +1037,7 @@ def _prune_heads(self, heads_to_prune): self.encoder.layer[layer].attention.prune_heads(heads) def forward( - self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, inputs_embeds=None, do_output_hidden_states=False, do_output_attentions=False + self, input_ids=None, attention_mask=None, position_ids=None, head_masks=None, inputs_embeds=None, do_output_hidden_states=False, do_output_attentions=False ): r""" Return: @@ -1110,23 +1110,23 @@ def forward( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = ( - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_mask for each layer - head_mask = head_mask.to( + if head_masks is not None: + if head_masks.dim() == 1: + head_masks = head_masks.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_masks = head_masks.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_masks.dim() == 2: + head_masks = ( + head_masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_masks for each layer + head_masks = head_masks.to( dtype=next(self.parameters()).dtype ) # switch to fload if need + fp16 compatibility else: - head_mask = [None] * self.config.num_hidden_layers + head_masks = [None] * self.config.num_hidden_layers embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) - encoder_outputs = self.encoder(embedding_output, head_mask, do_output_hidden_states, do_output_attentions) + encoder_outputs = self.encoder(embedding_output, head_masks, do_output_hidden_states, do_output_attentions) sequence_output = encoder_outputs[0] # if padding was applied @@ -1178,11 +1178,11 @@ def tie_weights(self): pass def forward( - self, input_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False + self, input_ids=None, position_ids=None, head_masks=None, inputs_embeds=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False ): reformer_outputs = self.reformer( - input_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions + input_ids, position_ids=position_ids, head_masks=head_masks, inputs_embeds=inputs_embeds, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions ) sequence_output = reformer_outputs[0] From 0c35bbffe35f5c14e94f47be60925ec35d225ecb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 19 Apr 2020 14:11:13 +0200 Subject: [PATCH 068/162] refactor --- src/transformers/modeling_reformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 082cf154c44f..a6445f951458 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -342,7 +342,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs + self, hidden_states, head_mask=None, do_output_attentions=False, buckets=None, **kwargs ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -365,9 +365,7 @@ def forward( # set `num_buckets` on the fly self._set_num_buckets_on_the_fly(sequence_length) - # use cached buckets for backprop - buckets = kwargs["buckets"] if "buckets" in kwargs else None - + # use cached buckets for backprop only if buckets is None: # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors) From 232463e6815e94d1b027ba462606a852f9d2f7bf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 20 Apr 2020 00:56:05 +0200 Subject: [PATCH 069/162] made reformer more flexible --- src/transformers/configuration_reformer.py | 13 +- src/transformers/modeling_reformer.py | 168 ++++++++++++--------- tests/test_modeling_reformer.py | 31 ++-- 3 files changed, 122 insertions(+), 90 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index bce72731ab6b..734a7fb64c57 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -100,11 +100,11 @@ def __init__( vocab_size=200, attention_head_size=32, hidden_size=128, - num_hidden_layers=4, num_attention_heads=2, num_buckets=2, num_hashes=4, - chunk_length=64, + lsh_attn_chunk_length=64, + local_attn_chunk_length=64, num_chunks_before=1, num_chunks_after=0, chunk_size_lm_head=0, @@ -121,7 +121,7 @@ def __init__( axial_pos_embds=False, axial_pos_shape=[64, 3], axial_pos_embds_dim=[64, 64], - attn_type="lsh", + attn_layers=["lsh", "lsh", "lsh", "lsh"], pad_token_id=0, **kwargs ): @@ -134,11 +134,12 @@ def __init__( self.vocab_size = vocab_size self.attention_head_size = attention_head_size self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_hashes = num_hashes + self.num_hidden_layers = len(attn_layers) self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets - self.chunk_length = chunk_length + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.local_attn_chunk_length = local_attn_chunk_length self.num_chunks_after = num_chunks_after self.num_chunks_before = num_chunks_before self.hidden_act = hidden_act @@ -155,4 +156,4 @@ def __init__( self.axial_norm_std = axial_norm_std self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward - self.attn_type = attn_type + self.attn_layers = attn_layers diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index a6445f951458..8f42dc3091a1 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -16,7 +16,6 @@ """PyTorch REFORMER model. """ import inspect -import itertools import logging from typing import Callable @@ -70,6 +69,18 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out.requires_grad = False +def _get_least_common_mult_chunk_len(config): + attn_types = config.attn_layers + if len(set(attn_types)) == 1 and attn_types[0] == "lsh": + return config.lsh_attn_chunk_length + elif len(set(attn_types)) == 1 and attn_types[0] == "local": + return config.local_attn_chunk_length + elif len(set(attn_types)) == 2 and set(attn_types) == set(["lsh", "local"]): + return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) + else: + raise NotImplementedError("Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(config.attn_layers)) + + def apply_chunking_to_forward( chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors ) -> torch.Tensor: @@ -130,7 +141,7 @@ def __init__(self, config): super().__init__() self.axial_pos_shape = config.axial_pos_shape self.axial_pos_embds_dim = config.axial_pos_embds_dim - self.chunk_length = config.chunk_length + self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config) self.weights = nn.ParameterList() self.dropout = config.hidden_dropout_prob @@ -182,8 +193,8 @@ def forward(self, position_ids): else: assert ( reduce(mul, self.axial_pos_shape) >= sequence_length - ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, config.chunk_length): max({}, {})".format( - self.axial_pos_shape, sequence_length, self.chunk_length + ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( + self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length ) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) @@ -326,7 +337,7 @@ def __init__(self, config): self.hash_seed = config.seed self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets - self.chunk_length = config.chunk_length + self.chunk_length = config.lsh_attn_chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after self.is_decoder = config.is_decoder @@ -342,11 +353,13 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, do_output_attentions=False, buckets=None, **kwargs + self, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False, buckets=None, **kwargs ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] + # num hashes can optionally be overwritten by user + num_hashes = num_hashes if num_hashes is not None else self.num_hashes mixed_query_key_vectors = self.query_key(hidden_states) mixed_value_vectors = self.value(hidden_states) @@ -368,14 +381,14 @@ def forward( # use cached buckets for backprop only if buckets is None: # hash query key vectors into buckets - buckets = self._hash_vectors(query_key_vectors) + buckets = self._hash_vectors(query_key_vectors, num_hashes) - assert int(buckets.shape[-1]) == self.num_hashes * sequence_length + assert int(buckets.shape[-1]) == num_hashes * sequence_length - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets) + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets, num_hashes) - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker) - value_vectors = self._gather_by_expansion(value_vectors, ticker) + query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker, num_hashes) + value_vectors = self._gather_by_expansion(value_vectors, ticker, num_hashes) query_key_vectors = self._split_dim_by( query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size @@ -394,12 +407,12 @@ def forward( query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask ) - if self.num_hashes > 1: + if num_hashes > 1: out_vectors = self._split_dim_by( - out_vectors, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size ) logits = self._split_dim_by( - logits, self.num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size ).unsqueeze(-1) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) @@ -414,7 +427,7 @@ def forward( return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) - def _hash_vectors(self, vectors): + def _hash_vectors(self, vectors, num_hashes): batch_size = vectors.shape[0] # See https://arxiv.org/pdf/1509.02897.pdf @@ -436,10 +449,10 @@ def _hash_vectors(self, vectors): rotation_size += num_bucket num_buckets *= num_bucket - rotations_shape = (vectors.shape[-1], self.num_hashes, rotation_size // 2) + rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) # TODO: delete later when integration tests are ok - # create a random self.attention_head_size x self.num_hashes x self.num_buckets/2 + # create a random self.attention_head_size x num_hashes x self.num_buckets/2 # random_rotations = torch.randn(rotations_shape, device=vectors.device) np.random.seed(self.hash_seed) random_rotations = torch.tensor( @@ -472,7 +485,7 @@ def _hash_vectors(self, vectors): # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. - offsets = torch.arange(self.num_hashes, device=vectors.device) + offsets = torch.arange(num_hashes, device=vectors.device) offsets = torch.reshape(offsets * num_buckets, (-1, 1)) # repeat same values for Batch_size and Num_Attn_Heads @@ -481,11 +494,11 @@ def _hash_vectors(self, vectors): return offset_buckets - def _get_ticker_and_undo_ticker(self, sequence_length, buckets): + def _get_ticker_and_undo_ticker(self, sequence_length, buckets, num_hashes): batch_size = buckets.shape[0] # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(self.num_hashes * sequence_length, device=buckets.device) + ticker = torch.arange(num_hashes * sequence_length, device=buckets.device) ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) buckets_and_t = sequence_length * buckets + (ticker % sequence_length) @@ -582,9 +595,9 @@ def _len_norm(self, x, epsilon=1e-6): norm_x = x / torch.sqrt(variance + epsilon) return norm_x - def _gather_by_expansion(self, vectors, idxs): + def _gather_by_expansion(self, vectors, idxs, num_hashes): expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - vectors = vectors.repeat(1, 1, self.num_hashes, 1) + vectors = vectors.repeat(1, 1, num_hashes, 1) return torch.gather(vectors, 2, expanded_idxs) @@ -593,7 +606,7 @@ def __init__(self, config): super().__init__() self.num_attention_heads = config.num_attention_heads - self.chunk_length = config.chunk_length + self.chunk_length = config.local_attn_chunk_length self.num_chunks_before = config.num_chunks_before self.num_chunks_after = config.num_chunks_after self.is_decoder = config.is_decoder @@ -720,36 +733,31 @@ def forward(self, hidden_states): class ReformerAttention(nn.Module): - layer_id_iter = itertools.count() - - def __init__(self, config): + def __init__(self, config, layer_id=0): super().__init__() - self.layer_id = next(self.layer_id_iter) + self.layer_id = layer_id self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn_layers = config.attn_layers - if config.attn_type == "lsh": + if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config) - elif config.attn_type == "local": + elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": self.self_attention = LocalSelfAttention(config) - elif config.attn_type == "mixed": - if self.layer_id % 2 == 0: - self.self_attention = LocalSelfAttention(config) - else: + elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]): + # get correct attn_layers + if self.attn_layers[self.layer_id] == "lsh": self.self_attention = LSHSelfAttention(config) + else: + self.self_attention = LocalSelfAttention(config) else: - raise NotImplementedError( - "config.attn_type: {} does not exist. Select one of ['lsh', 'local', 'mixed'].".format( - config.attn_type - ) - ) - + raise NotImplementedError("Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(self.attn_layers)) self.output = ReformerSelfOutput(config) - def forward(self, hidden_states, head_mask=None, do_output_attentions=False, buckets=None): + def forward(self, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False, buckets=None): norm_hidden_states = self.layer_norm(hidden_states) # use cached buckets for backprob if buckets not None for LSHSelfAttention - self_attention_outputs = self.self_attention(norm_hidden_states, head_mask, do_output_attentions, buckets=buckets) + self_attention_outputs = self.self_attention(hidden_states=norm_hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, buckets=buckets) attention_output = self.output(self_attention_outputs.hidden_states) @@ -800,7 +808,6 @@ def __init__(self, config): self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) - # TODO(PVP): Does this work with backpropagation? def forward(self, attention_output): return apply_chunking_to_forward( self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output) @@ -813,18 +820,17 @@ def forward_chunk(self, attention_output): class ReformerLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_id=0): super().__init__() - self.attention = ReformerAttention(config) + self.attention = ReformerAttention(config, layer_id) self.feed_forward = ChunkReformerFeedForward(config) def forward( - self, prev_attn_output, hidden_states, head_mask=None, do_output_attentions=False + self, prev_attn_output, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False ): with torch.no_grad(): - # is this _init_seed needed? in https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py - attn_outputs = self.attention(hidden_states, head_mask, do_output_attentions) + attn_outputs = self.attention(hidden_states=hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions) attn_output = attn_outputs.hidden_states # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) @@ -835,41 +841,52 @@ def forward( return ReformerOutput(attn_output=attn_output, hidden_states=hidden_states, attention_probs=attn_outputs.attention_probs, buckets=attn_outputs.buckets) - def backward_pass(self, next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states, head_mask=None, buckets=None): + def backward_pass(self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=None, buckets=None): + # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py with torch.enable_grad(): next_attn_output.requires_grad = True - # use cached buckets for backprob if buckets not None for LSHSelfAttention - attn_output = self.attention(next_attn_output, head_mask, buckets=buckets).hidden_states + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # g(Y_1) + attn_output = self.feed_forward(next_attn_output) torch.autograd.backward(attn_output, grad_hidden_states) +# attn_output = self.attention(hidden_states=next_attn_output, head_mask=head_mask, buckets=buckets).hidden_states with torch.no_grad(): + # X_2 = Y_2 - g(Y_1) hidden_states = hidden_states - attn_output del attn_output - grad_next_attn_output = grad_next_attn_output + next_attn_output.grad + grad_attn_output = grad_attn_output + next_attn_output.grad next_attn_output.grad = None with torch.enable_grad(): hidden_states.requires_grad = True - output = self.feed_forward(hidden_states) - torch.autograd.backward(output, grad_next_attn_output, retain_graph=True) + # f(X_2) + # use cached buckets for backprob if buckets not None for LSHSelfAttention + output = self.attention(hidden_states=hidden_states, head_mask=head_mask, buckets=buckets).hidden_states + torch.autograd.backward(output, grad_attn_output, retain_graph=True) with torch.no_grad(): - next_attn_output = next_attn_output - output - del output + # X_1 = Y_1 - f(X_2) + attn_output = next_attn_output - output + del output, next_attn_output grad_hidden_states = grad_hidden_states + hidden_states.grad hidden_states.grad = None hidden_states = hidden_states.detach() - return next_attn_output, hidden_states, grad_next_attn_output, grad_hidden_states + return attn_output, hidden_states, grad_attn_output, grad_hidden_states class _ReversibleFunction(Function): + """ + This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + """ + @staticmethod - def forward(ctx, hidden_states, layers, head_masks, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): + def forward(ctx, hidden_states, layers, head_masks, num_hashes, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): all_buckets = () hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) @@ -878,7 +895,7 @@ def forward(ctx, hidden_states, layers, head_masks, all_hidden_states, all_atten all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) - layer_outputs = layer(attn_output, hidden_states, head_mask, do_output_attentions) + layer_outputs = layer(prev_attn_output=attn_output, hidden_states=hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states all_buckets = all_buckets + (layer_outputs.buckets,) @@ -914,30 +931,31 @@ def backward(ctx, grad_hidden_states): all_buckets = ctx.all_buckets for layer, head_mask in zip(layers[::-1], head_masks[::-1]): - # pop last buckets + # pop last buckets from stack buckets = all_buckets[-1] all_buckets = all_buckets[:-1] - attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=head_mask, buckets=buckets) + # backprop + attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(next_attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, head_mask=head_mask, buckets=buckets) grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args - return grad_hidden_states, None, None, None, None, None, None + return grad_hidden_states, None, None, None, None, None, None, None class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() - self.layers = nn.ModuleList([ReformerLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) self.dropout = config.hidden_dropout_prob def forward( - self, hidden_states, head_masks=None, do_output_hidden_states=False, do_output_attentions=False, + self, hidden_states, head_masks=None, num_hashes=None, do_output_hidden_states=False, do_output_attentions=False, ): # hidden_states and attention lists to be filled if wished all_hidden_states = [] @@ -945,7 +963,7 @@ def forward( # Make this work hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) - hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_masks, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) + hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_masks, num_hashes, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) # Apply layer norm to concatenated hidden states hidden_states = self.layer_norm(hidden_states) @@ -1014,6 +1032,7 @@ class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config + assert self.config.num_hidden_layers > 0, "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']" self.embeddings = ReformerEmbeddings(config) self.encoder = ReformerEncoder(config) @@ -1035,7 +1054,7 @@ def _prune_heads(self, heads_to_prune): self.encoder.layer[layer].attention.prune_heads(heads) def forward( - self, input_ids=None, attention_mask=None, position_ids=None, head_masks=None, inputs_embeds=None, do_output_hidden_states=False, do_output_attentions=False + self, input_ids=None, attention_mask=None, position_ids=None, head_masks=None, inputs_embeds=None, num_hashes=None, do_output_hidden_states=False, do_output_attentions=False ): r""" Return: @@ -1071,24 +1090,25 @@ def forward( real_sequence_length = input_shape[-1] # if needs padding - if input_shape[-1] % self.config.chunk_length != 0: + least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) + if input_shape[-1] % least_common_mult_chunk_length != 0: # TODO: should also allow this when self.is_decoder is False? # TODO: need to improve attn mask input possibility here assert ( self.training is False and self.config.is_decoder is True - ), "Sequence Length {} has to be a multiple of config.chunk_length {} if {}. Please consider padding the input to a length of {}.".format( + ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if {}. Please consider padding the input to a length of {}.".format( input_shape[-2], - self.config.chunk_length, + least_common_mult_chunk_length, "training" if self.training is True else "config.is_decoder = True", - input_shape[-1] + (self.config.chunk_length - input_shape[-1] % self.config.chunk_length), + input_shape[-1] + (least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length), ) if input_ids is None: raise NotImplementedError("Currently only supported for `input_ids`") - padding_length = self.config.chunk_length - input_shape[-1] % self.config.chunk_length + padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length - # Extend `input_ids` with padding to match self.chunk_len + # Extend `input_ids` with padding to match least common multiple chunk_length input_ids = torch.cat( [ input_ids, @@ -1124,7 +1144,7 @@ def forward( embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) - encoder_outputs = self.encoder(embedding_output, head_masks, do_output_hidden_states, do_output_attentions) + encoder_outputs = self.encoder(hidden_states=embedding_output, head_masks=head_masks, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions) sequence_output = encoder_outputs[0] # if padding was applied @@ -1176,11 +1196,11 @@ def tie_weights(self): pass def forward( - self, input_ids=None, position_ids=None, head_masks=None, inputs_embeds=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False + self, input_ids=None, position_ids=None, head_masks=None, inputs_embeds=None, num_hashes=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False ): reformer_outputs = self.reformer( - input_ids, position_ids=position_ids, head_masks=head_masks, inputs_embeds=inputs_embeds, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions + input_ids, position_ids=position_ids, head_masks=head_masks, inputs_embeds=inputs_embeds, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions ) sequence_output = reformer_outputs[0] diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 85647f5ff05c..c35ba0df5513 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -48,7 +48,7 @@ def test_lsh_layer(self): trax_torch_output = torch.tensor(np.asarray(trax_output)) hf_input = torch.tensor(np_input, dtype=torch.float) - config.attn_type = "lsh" + config.attn_layers = ["lsh"] hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() @@ -69,7 +69,7 @@ def test_local_layer(self): trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) hf_input = torch.tensor(np_input, dtype=torch.float) - config.attn_type = "local" + config.attn_layers = ["local"] hf_layer = ReformerAttention(config) self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() @@ -110,6 +110,16 @@ def test_reformer_lm_model(self): self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) + def test_backprop_lm_model(self): + config = ReformerConfig() + + shape = (1, 192) # Batch x SeqLen x ModelDimPerHead + input_ids = torch.tensor(np.random.randint(0, config.vocab_size, size=shape), dtype=torch.long, device=torch_device) + + model = ReformerModelWithLMHead(config) + loss = model(input_ids, lm_labels=input_ids)[0] + loss.backward() + def test_pretrained_crime_and_punishment_lm_model(self): hf_model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish") config = hf_model.config @@ -155,7 +165,7 @@ def load_lsh_layer(self, config, mode="eval"): config.num_attention_heads, config.attention_head_size, config.attention_head_size, - config.chunk_length, + config.lsh_attn_chunk_length, config.num_chunks_before, config.num_chunks_after, config.num_hashes, @@ -189,7 +199,7 @@ def load_local_layer(self, config, mode="eval"): config.num_attention_heads, config.attention_head_size, config.attention_head_size, - config.chunk_length, + config.local_attn_chunk_length, config.num_chunks_before, config.num_chunks_after, config.attention_probs_dropout_prob, @@ -207,9 +217,10 @@ def load_reformer_lm_model(self, config, mode="eval"): hidden_act = "Relu" else: raise ValueError() - if config.attn_type == "lsh": + attn_type = config.attn_layers[0] + if attn_type == "lsh": attn_type = "LSHSelfAttention" - elif config.attn_type == "local": + elif attn_type == "local": attn_type = "SelfAttention" else: raise ValueError() @@ -266,16 +277,16 @@ def load_reformer_lm_model(self, config, mode="eval"): ReformerLM.share_qk = False ReformerLM.ff_use_sru = 0 """.format( - config.chunk_length, - config.chunk_length, - config.chunk_length // 2, + config.lsh_attn_chunk_length, + config.lsh_attn_chunk_length, + config.lsh_attn_chunk_length // 2, config.num_chunks_before, config.num_chunks_after, config.num_hashes, config.num_buckets, config.seed, config.is_decoder, - config.chunk_length, + config.local_attn_chunk_length, config.num_chunks_before, config.num_chunks_after, config.is_decoder, From 2648a940fc2a827b6bb091de2c42dda153cbedcb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 20 Apr 2020 00:56:52 +0200 Subject: [PATCH 070/162] make style --- ...ert_reformer_trax_checkpoint_to_pytorch.py | 16 +- src/transformers/modeling_reformer.py | 156 ++++++++++++++---- tests/test_modeling_reformer.py | 5 +- 3 files changed, 132 insertions(+), 45 deletions(-) diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py index bfea6935e9d6..6f51addf97aa 100755 --- a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -17,12 +17,11 @@ import argparse import logging +import pickle import numpy as np import torch - from tensorflow.compat.v1.io.gfile import GFile -import pickle from transformers import ReformerConfig, ReformerModelWithLMHead @@ -49,8 +48,7 @@ def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size): torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), ) set_param( - torch_layer.self_attention.value, - torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), ) set_param( torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), @@ -65,15 +63,13 @@ def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size): np_dense = np.asarray(weights[3]) set_param( - torch_layer.self_attention.query, - torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), + torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), ) set_param( torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), ) set_param( - torch_layer.self_attention.value, - torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), ) set_param( torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), @@ -184,8 +180,8 @@ def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch print("Building PyTorch model from configuration: {}".format(str(config))) model = ReformerModelWithLMHead(config) - with GFile(trax_model_pkl_path, 'rb') as f: - model_weights = pickle.load(f)['weights'] + with GFile(trax_model_pkl_path, "rb") as f: + model_weights = pickle.load(f)["weights"] set_model_weights_in_torch(model_weights, model, config.hidden_size) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 8f42dc3091a1..54960fc2b524 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -17,20 +17,17 @@ import inspect import logging +from collections import namedtuple +from functools import reduce +from operator import mul from typing import Callable # DELETE later import numpy as np - -from operator import mul -from functools import reduce - -from collections import namedtuple - import torch from torch import nn -from torch.nn import CrossEntropyLoss from torch.autograd.function import Function +from torch.nn import CrossEntropyLoss from .activations import gelu, gelu_new, swish from .configuration_reformer import ReformerConfig @@ -78,7 +75,11 @@ def _get_least_common_mult_chunk_len(config): elif len(set(attn_types)) == 2 and set(attn_types) == set(["lsh", "local"]): return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) else: - raise NotImplementedError("Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(config.attn_layers)) + raise NotImplementedError( + "Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format( + config.attn_layers + ) + ) def apply_chunking_to_forward( @@ -622,9 +623,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob - def forward( - self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs - ): + def forward(self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -732,7 +731,6 @@ def forward(self, hidden_states): class ReformerAttention(nn.Module): - def __init__(self, config, layer_id=0): super().__init__() self.layer_id = layer_id @@ -750,14 +748,24 @@ def __init__(self, config, layer_id=0): else: self.self_attention = LocalSelfAttention(config) else: - raise NotImplementedError("Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(self.attn_layers)) + raise NotImplementedError( + "Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format( + self.attn_layers + ) + ) self.output = ReformerSelfOutput(config) def forward(self, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False, buckets=None): norm_hidden_states = self.layer_norm(hidden_states) # use cached buckets for backprob if buckets not None for LSHSelfAttention - self_attention_outputs = self.self_attention(hidden_states=norm_hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, buckets=buckets) + self_attention_outputs = self.self_attention( + hidden_states=norm_hidden_states, + head_mask=head_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + buckets=buckets, + ) attention_output = self.output(self_attention_outputs.hidden_states) @@ -767,7 +775,9 @@ def forward(self, hidden_states, head_mask=None, num_hashes=None, do_output_atte else: buckets = None - return AttentionOutput(hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets) + return AttentionOutput( + hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets + ) class ReformerFeedForwardDense(nn.Module): @@ -810,7 +820,8 @@ def __init__(self, config): def forward(self, attention_output): return apply_chunking_to_forward( - self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output) + self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output + ) def forward_chunk(self, attention_output): norm_attention_output = self.layer_norm(attention_output) @@ -825,12 +836,15 @@ def __init__(self, config, layer_id=0): self.attention = ReformerAttention(config, layer_id) self.feed_forward = ChunkReformerFeedForward(config) - def forward( - self, prev_attn_output, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False - ): + def forward(self, prev_attn_output, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False): with torch.no_grad(): - attn_outputs = self.attention(hidden_states=hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions) + attn_outputs = self.attention( + hidden_states=hidden_states, + head_mask=head_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + ) attn_output = attn_outputs.hidden_states # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) @@ -839,9 +853,16 @@ def forward( # Y_2 = X_2 + g(Y_1) hidden_states = hidden_states + self.feed_forward(attn_output) - return ReformerOutput(attn_output=attn_output, hidden_states=hidden_states, attention_probs=attn_outputs.attention_probs, buckets=attn_outputs.buckets) + return ReformerOutput( + attn_output=attn_output, + hidden_states=hidden_states, + attention_probs=attn_outputs.attention_probs, + buckets=attn_outputs.buckets, + ) - def backward_pass(self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=None, buckets=None): + def backward_pass( + self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=None, buckets=None + ): # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py with torch.enable_grad(): @@ -851,7 +872,7 @@ def backward_pass(self, next_attn_output, hidden_states, grad_attn_output, grad_ # g(Y_1) attn_output = self.feed_forward(next_attn_output) torch.autograd.backward(attn_output, grad_hidden_states) -# attn_output = self.attention(hidden_states=next_attn_output, head_mask=head_mask, buckets=buckets).hidden_states + # attn_output = self.attention(hidden_states=next_attn_output, head_mask=head_mask, buckets=buckets).hidden_states with torch.no_grad(): # X_2 = Y_2 - g(Y_1) @@ -886,7 +907,17 @@ class _ReversibleFunction(Function): """ @staticmethod - def forward(ctx, hidden_states, layers, head_masks, num_hashes, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions): + def forward( + ctx, + hidden_states, + layers, + head_masks, + num_hashes, + all_hidden_states, + all_attentions, + do_output_hidden_states, + do_output_attentions, + ): all_buckets = () hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) @@ -895,7 +926,13 @@ def forward(ctx, hidden_states, layers, head_masks, num_hashes, all_hidden_state all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) - layer_outputs = layer(prev_attn_output=attn_output, hidden_states=hidden_states, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions) + layer_outputs = layer( + prev_attn_output=attn_output, + hidden_states=hidden_states, + head_mask=head_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + ) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states all_buckets = all_buckets + (layer_outputs.buckets,) @@ -936,7 +973,14 @@ def backward(ctx, grad_hidden_states): all_buckets = all_buckets[:-1] # backprop - attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass(next_attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, head_mask=head_mask, buckets=buckets) + attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass( + next_attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + head_mask=head_mask, + buckets=buckets, + ) grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) @@ -955,7 +999,12 @@ def __init__(self, config): self.dropout = config.hidden_dropout_prob def forward( - self, hidden_states, head_masks=None, num_hashes=None, do_output_hidden_states=False, do_output_attentions=False, + self, + hidden_states, + head_masks=None, + num_hashes=None, + do_output_hidden_states=False, + do_output_attentions=False, ): # hidden_states and attention lists to be filled if wished all_hidden_states = [] @@ -963,7 +1012,16 @@ def forward( # Make this work hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) - hidden_states = _ReversibleFunction.apply(hidden_states, self.layers, head_masks, num_hashes, all_hidden_states, all_attentions, do_output_hidden_states, do_output_attentions) + hidden_states = _ReversibleFunction.apply( + hidden_states, + self.layers, + head_masks, + num_hashes, + all_hidden_states, + all_attentions, + do_output_hidden_states, + do_output_attentions, + ) # Apply layer norm to concatenated hidden states hidden_states = self.layer_norm(hidden_states) @@ -1009,6 +1067,8 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.normal_(mean=0.0, std=1e-6) # TODO(PVP): discuss with Thom if necessary here to use different init + + # module.bias.data.zero_() @@ -1032,7 +1092,9 @@ class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config - assert self.config.num_hidden_layers > 0, "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']" + assert ( + self.config.num_hidden_layers > 0 + ), "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']" self.embeddings = ReformerEmbeddings(config) self.encoder = ReformerEncoder(config) @@ -1054,7 +1116,15 @@ def _prune_heads(self, heads_to_prune): self.encoder.layer[layer].attention.prune_heads(heads) def forward( - self, input_ids=None, attention_mask=None, position_ids=None, head_masks=None, inputs_embeds=None, num_hashes=None, do_output_hidden_states=False, do_output_attentions=False + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_masks=None, + inputs_embeds=None, + num_hashes=None, + do_output_hidden_states=False, + do_output_attentions=False, ): r""" Return: @@ -1144,7 +1214,13 @@ def forward( embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) - encoder_outputs = self.encoder(hidden_states=embedding_output, head_masks=head_masks, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions) + encoder_outputs = self.encoder( + hidden_states=embedding_output, + head_masks=head_masks, + num_hashes=num_hashes, + do_output_hidden_states=do_output_hidden_states, + do_output_attentions=do_output_attentions, + ) sequence_output = encoder_outputs[0] # if padding was applied @@ -1196,11 +1272,25 @@ def tie_weights(self): pass def forward( - self, input_ids=None, position_ids=None, head_masks=None, inputs_embeds=None, num_hashes=None, lm_labels=None, do_output_hidden_states=False, do_output_attentions=False + self, + input_ids=None, + position_ids=None, + head_masks=None, + inputs_embeds=None, + num_hashes=None, + lm_labels=None, + do_output_hidden_states=False, + do_output_attentions=False, ): reformer_outputs = self.reformer( - input_ids, position_ids=position_ids, head_masks=head_masks, inputs_embeds=inputs_embeds, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions + input_ids, + position_ids=position_ids, + head_masks=head_masks, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + do_output_hidden_states=do_output_hidden_states, + do_output_attentions=do_output_attentions, ) sequence_output = reformer_outputs[0] diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c35ba0df5513..24ddffbefa86 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -21,7 +21,6 @@ # trax imports - to be deleted later import trax from transformers import is_torch_available # noqa: F401 - from trax.shapes import ShapeDtype as trax_ShapeDtype from .utils import require_torch, torch_device # noqa: F401 @@ -114,7 +113,9 @@ def test_backprop_lm_model(self): config = ReformerConfig() shape = (1, 192) # Batch x SeqLen x ModelDimPerHead - input_ids = torch.tensor(np.random.randint(0, config.vocab_size, size=shape), dtype=torch.long, device=torch_device) + input_ids = torch.tensor( + np.random.randint(0, config.vocab_size, size=shape), dtype=torch.long, device=torch_device + ) model = ReformerModelWithLMHead(config) loss = model(input_ids, lm_labels=input_ids)[0] From 902408b8cd1a8dcd3012854a034ac37bd5b8ccf8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Apr 2020 21:50:34 +0200 Subject: [PATCH 071/162] remove set max length --- src/transformers/tokenization_reformer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index 2ac96b877541..3f4e0ee85e78 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -96,12 +96,6 @@ def __init__( additional_special_tokens=additional_special_tokens, **kwargs, ) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens try: import sentencepiece as spm From 8ed63ab064ebb731d64117f8e543d900a30213d4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Apr 2020 14:30:10 +0200 Subject: [PATCH 072/162] add attention masks --- src/transformers/modeling_reformer.py | 171 +++++++++++++++++++------- tests/test_modeling_reformer.py | 18 ++- 2 files changed, 137 insertions(+), 52 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 54960fc2b524..3c3cd69bc934 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -307,7 +307,7 @@ def _transpose_for_output(self, x, num_attn_heads, attn_head_size): x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size): + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): batch_size = vectors.shape[0] split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) @@ -354,7 +354,14 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False, buckets=None, **kwargs + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + buckets=None, + **kwargs ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -397,7 +404,8 @@ def forward( value_vectors = self._split_dim_by( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size ) - ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size) + + ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads) if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 @@ -405,7 +413,13 @@ def forward( key_vectors = self._len_and_dim_norm(query_key_vectors) out_vectors, logits, attention_probs = self._attend( - query_key_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask + query_vectors=query_key_vectors, + key_vectors=key_vectors, + value_vectors=value_vectors, + ticker=ticker, + undo_ticker=undo_ticker, + attention_mask=attention_mask, + head_mask=head_mask, ) if num_hashes > 1: @@ -523,7 +537,7 @@ def _set_num_buckets_on_the_fly(self, sequence_length): logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets)) self.num_buckets = num_buckets - def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, head_mask): + def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, attention_mask, head_mask): key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) @@ -534,19 +548,34 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker query_info = ticker key_value_info = self._look_adjacent(ticker, self.num_chunks_before, self.num_chunks_after) + mask = None # Causal mask if self.is_decoder: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.lt(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - mask * 1e9 + mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + + # Attention mask: chunk, look up correct mask value from key_value_info + # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. + if attention_mask is not None: + attn_mask = attention_mask.to(torch.uint8)[:, None, :] + attn_mask = self._split_dim_by(attn_mask, -1, self.chunk_length, 1).expand(ticker.shape) + attn_mask = torch.gather(attn_mask, -1, key_value_info) + # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk + attn_mask = attn_mask.unsqueeze(-2).expand(query_key_dots.shape) + # multiply by casaul mask if necessary + if mask is not None: + mask = mask * attn_mask + else: + mask = attn_mask - # Self mask - # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.eq(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).long().to(query_info.device) - query_key_dots = query_key_dots - mask * 1e5 + # if attention_mask and/or casaul mask apply here + if mask is not None: + query_key_dots = torch.where( + mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device) + ) + # Self mask is ALWAYS applied # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. - # Note: Self mask is used because Q and K projection weights are shared. # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): @@ -557,6 +586,10 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " + mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + query_key_dots = torch.where( + mask, query_key_dots, torch.tensor(-1e5, dtype=query_key_dots.dtype, device=query_key_dots.device) + ) logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (num_chunks_before + num_chunks_after)]` @@ -623,7 +656,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob - def forward(self, hidden_states, head_mask=None, do_output_attentions=False, **kwargs): + def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -667,28 +700,42 @@ def forward(self, hidden_states, head_mask=None, do_output_attentions=False, **k indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) - query_indices = self._split_dim_by( - indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size - ) - key_value_indices = self._split_dim_by( - indices, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size - ) + query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) + key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) # append chunks before and after key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) key_value_indices = self._look_adjacent(key_value_indices, self.num_chunks_before, self.num_chunks_after) + # chunk attention mask and look before and after + if attention_mask is not None: + attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attention_mask = self._split_dim_by(attention_mask, -1, self.chunk_length, 1) + attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + mask = None # Causal mask - if self.is_decoder: + if self.is_decoder is True: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = ( - torch.lt(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).long().to(query_indices.device) + mask = torch.ge(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).to(query_indices.device) + + # Attention mask + if attention_mask is not None: + attn_mask = attention_mask.unsqueeze(-2).expand(query_key_dots.shape) + # multiply by casaul mask if necessary + if mask is not None: + mask = mask * attn_mask + else: + mask = attn_mask + + if mask is not None: + query_key_dots = torch.where( + mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device) ) - query_key_dots = query_key_dots - mask * 1e9 logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) @@ -755,13 +802,22 @@ def __init__(self, config, layer_id=0): ) self.output = ReformerSelfOutput(config) - def forward(self, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False, buckets=None): + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + buckets=None, + ): norm_hidden_states = self.layer_norm(hidden_states) # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( hidden_states=norm_hidden_states, head_mask=head_mask, + attention_mask=attention_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, buckets=buckets, @@ -836,12 +892,21 @@ def __init__(self, config, layer_id=0): self.attention = ReformerAttention(config, layer_id) self.feed_forward = ChunkReformerFeedForward(config) - def forward(self, prev_attn_output, hidden_states, head_mask=None, num_hashes=None, do_output_attentions=False): + def forward( + self, + prev_attn_output, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + ): with torch.no_grad(): attn_outputs = self.attention( hidden_states=hidden_states, head_mask=head_mask, + attention_mask=attention_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, ) @@ -861,7 +926,14 @@ def forward(self, prev_attn_output, hidden_states, head_mask=None, num_hashes=No ) def backward_pass( - self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, head_mask=None, buckets=None + self, + next_attn_output, + hidden_states, + grad_attn_output, + grad_hidden_states, + attention_mask=None, + head_mask=None, + buckets=None, ): # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py @@ -886,7 +958,9 @@ def backward_pass( hidden_states.requires_grad = True # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention - output = self.attention(hidden_states=hidden_states, head_mask=head_mask, buckets=buckets).hidden_states + output = self.attention( + hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets + ).hidden_states torch.autograd.backward(output, grad_attn_output, retain_graph=True) with torch.no_grad(): @@ -911,6 +985,7 @@ def forward( ctx, hidden_states, layers, + attention_mask, head_masks, num_hashes, all_hidden_states, @@ -929,6 +1004,7 @@ def forward( layer_outputs = layer( prev_attn_output=attn_output, hidden_states=hidden_states, + attention_mask=attention_mask, head_mask=head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, @@ -946,6 +1022,7 @@ def forward( all_hidden_states.append(hidden_states) # attach params to ctx for backward + ctx.attention_mask = attention_mask ctx.head_masks = head_masks ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() @@ -963,6 +1040,7 @@ def backward(ctx, grad_hidden_states): # retrieve params from ctx for backward attn_output = ctx.attn_output hidden_states = ctx.hidden_states + attention_mask = ctx.attention_mask head_masks = ctx.head_masks layers = ctx.layers all_buckets = ctx.all_buckets @@ -979,6 +1057,7 @@ def backward(ctx, grad_hidden_states): grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, head_mask=head_mask, + attention_mask=attention_mask, buckets=buckets, ) @@ -986,7 +1065,7 @@ def backward(ctx, grad_hidden_states): # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args - return grad_hidden_states, None, None, None, None, None, None, None + return grad_hidden_states, None, None, None, None, None, None, None, None class ReformerEncoder(nn.Module): @@ -1001,6 +1080,7 @@ def __init__(self, config): def forward( self, hidden_states, + attention_mask=None, head_masks=None, num_hashes=None, do_output_hidden_states=False, @@ -1015,6 +1095,7 @@ def forward( hidden_states = _ReversibleFunction.apply( hidden_states, self.layers, + attention_mask, head_masks, num_hashes, all_hidden_states, @@ -1191,32 +1272,28 @@ def forward( ], dim=-1, ) + + # Extend `attention_mask` + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask, + torch.zeros(input_shape[0], padding_length, device=attention_mask.device, dtype=torch.long), + ], + dim=-1, + ) + input_shape = input_ids.size() - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_masks is not None: - if head_masks.dim() == 1: - head_masks = head_masks.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_masks = head_masks.expand(self.config.num_hidden_layers, -1, -1, -1, -1) - elif head_masks.dim() == 2: - head_masks = ( - head_masks.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - ) # We can specify head_masks for each layer - head_masks = head_masks.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility - else: - head_masks = [None] * self.config.num_hidden_layers + # prepare head mask + head_masks = self.get_head_mask(head_masks, self.config.num_hidden_layers) embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( hidden_states=embedding_output, head_masks=head_masks, + attention_mask=attention_mask, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions, @@ -1275,6 +1352,7 @@ def forward( self, input_ids=None, position_ids=None, + attention_mask=None, head_masks=None, inputs_embeds=None, num_hashes=None, @@ -1286,6 +1364,7 @@ def forward( reformer_outputs = self.reformer( input_ids, position_ids=position_ids, + attention_mask=attention_mask, head_masks=head_masks, inputs_embeds=inputs_embeds, num_hashes=num_hashes, diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 24ddffbefa86..debb65d04fe9 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -35,12 +35,13 @@ class ReformerIntegrationTests(unittest.TestCase): def test_lsh_layer(self): config = ReformerConfig() - shape = (2, 192, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_layer = self.load_lsh_layer(config) input_signature = trax_ShapeDtype(shape, np.float32) trax_weights, trax_state = trax_layer.init(input_signature) + mask = np.ones(shape[:-1], dtype=np.int32) trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) @@ -52,7 +53,7 @@ def test_lsh_layer(self): self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() - hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_attention_all_heads = hf_layer.self_attention(hf_input, attention_mask=torch.tensor(mask))[0] hf_output = hf_layer.output(hf_attention_all_heads) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) @@ -62,10 +63,13 @@ def test_local_layer(self): shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) - trax_layer = self.load_local_layer(config) + trax_layer = self.load_local_layer(config, mask=True) input_signature = trax_ShapeDtype(shape, np.float32) trax_weights, trax_state = trax_layer.init(input_signature) - trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) + mask = np.ones(shape[:-1], dtype=np.int32) + mask[:, -10:] = 0 + + trax_output = trax_layer(np_input, mask=mask, weights=trax_weights, state=trax_state) hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_layers = ["local"] @@ -73,7 +77,7 @@ def test_local_layer(self): self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() - hf_attention_all_heads = hf_layer.self_attention(hf_input)[0] + hf_attention_all_heads = hf_layer.self_attention(hf_input, attention_mask=torch.tensor(mask))[0] hf_output = hf_layer.output(hf_attention_all_heads) trax_torch_output = torch.tensor(np.asarray(trax_output)) @@ -180,7 +184,7 @@ def load_lsh_layer(self, config, mode="eval"): layer = trax.layers.LSHSelfAttention(mode=mode) return layer - def load_local_layer(self, config, mode="eval"): + def load_local_layer(self, config, mask=False, mode="eval"): gin_config = """ import trax.layers @@ -195,6 +199,7 @@ def load_local_layer(self, config, mode="eval"): SelfAttention.attention_dropout = {} SelfAttention.output_dropout = {} SelfAttention.causal = {} + SelfAttention.masked= {} SelfAttention.use_reference_code = True """.format( config.num_attention_heads, @@ -206,6 +211,7 @@ def load_local_layer(self, config, mode="eval"): config.attention_probs_dropout_prob, config.hidden_dropout_prob, config.is_decoder, + mask, ) gin.parse_config(gin_config) layer = trax.layers.SelfAttention(mode=mode) From 513bb4342b4be42bbcc6b7ff58b9fd026117ee5c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 22 Apr 2020 14:37:50 +0200 Subject: [PATCH 073/162] fix up tests --- tests/test_modeling_reformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index debb65d04fe9..419886b1a298 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -63,13 +63,12 @@ def test_local_layer(self): shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) - trax_layer = self.load_local_layer(config, mask=True) + trax_layer = self.load_local_layer(config) input_signature = trax_ShapeDtype(shape, np.float32) trax_weights, trax_state = trax_layer.init(input_signature) mask = np.ones(shape[:-1], dtype=np.int32) - mask[:, -10:] = 0 - trax_output = trax_layer(np_input, mask=mask, weights=trax_weights, state=trax_state) + trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_layers = ["local"] From db60c23d101777634925eabb9ee65ed607bef127 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 14:53:23 +0200 Subject: [PATCH 074/162] fix conflict --- src/transformers/configuration_auto.py | 3 +++ src/transformers/configuration_reformer.py | 2 +- src/transformers/modeling_auto.py | 4 ++++ src/transformers/modeling_reformer.py | 2 ++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index ece2289c60f8..59272a86cca7 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -29,6 +29,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig +from .configuration_reformer import ReformerConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig @@ -73,6 +74,7 @@ ("camembert", CamembertConfig,), ("xlm-roberta", XLMRobertaConfig,), ("bart", BartConfig,), + ("reformer", ReformerConfig,), ("roberta", RobertaConfig,), ("flaubert", FlaubertConfig,), ("bert", BertConfig,), @@ -130,6 +132,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model) + - contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model) - contains `bert`: :class:`~transformers.BertConfig` (Bert model) - contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model) - contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 734a7fb64c57..261895bc5956 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -93,7 +93,7 @@ class ReformerConfig(PretrainedConfig): A dictionary containing all the available pre-trained checkpoints. """ pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP - model_type = "bert" + model_type = "reformer" def __init__( self, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index c54b9c719840..143c9b2b8fbe 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -31,6 +31,7 @@ FlaubertConfig, GPT2Config, OpenAIGPTConfig, + ReformerConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -97,6 +98,7 @@ ) from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel +from .modeling_reformer import ReformerModel, ReformerModelWithLMHead from .modeling_roberta import ( ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, @@ -179,6 +181,7 @@ (XLMConfig, XLMModel), (CTRLConfig, CTRLModel), (ElectraConfig, ElectraModel), + (ReformerConfig, ReformerModel) ] ) @@ -222,6 +225,7 @@ (CTRLConfig, CTRLLMHeadModel), (ElectraConfig, ElectraForMaskedLM), (EncoderDecoderConfig, EncoderDecoderModel), + (ReformerConfig, ReformerModelWithLMHead), ] ) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3c3cd69bc934..1e31f3c04bcd 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1237,6 +1237,8 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") + assert len(input_shape) == 2, "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape) + device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 real_sequence_length = input_shape[-1] From 9f359af44e48f7fcd90e5c4c4df1f7ffd3ffc531 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 23 Apr 2020 13:46:47 +0200 Subject: [PATCH 075/162] fix lsh attention mask --- src/transformers/modeling_reformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1e31f3c04bcd..aea6850a9765 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -557,8 +557,9 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # Attention mask: chunk, look up correct mask value from key_value_info # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. if attention_mask is not None: - attn_mask = attention_mask.to(torch.uint8)[:, None, :] - attn_mask = self._split_dim_by(attn_mask, -1, self.chunk_length, 1).expand(ticker.shape) + attn_mask = attention_mask.to(torch.uint8)[:, None, None, :] + # expand attn_mask to fit with key_value_info shape + attn_mask = attn_mask.expand(ticker.shape[:-1] + (-1,)) attn_mask = torch.gather(attn_mask, -1, key_value_info) # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk attn_mask = attn_mask.unsqueeze(-2).expand(query_key_dots.shape) From 48097a074875af7407dbe9ac8d4dd94ebc050cea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 23 Apr 2020 14:05:52 +0200 Subject: [PATCH 076/162] make random seed optional for the moment --- src/transformers/configuration_reformer.py | 5 +++-- src/transformers/modeling_reformer.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 261895bc5956..6f9d57ac10ce 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -123,12 +123,13 @@ def __init__( axial_pos_embds_dim=[64, 64], attn_layers=["lsh", "lsh", "lsh", "lsh"], pad_token_id=0, + seed=0, # TO REMOVE LATER: **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) - # TO CHANGE LATER: - self.seed = 0 + # TO REMOVE LATER: + self.seed = seed self.is_decoder = True self.vocab_size = vocab_size diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index aea6850a9765..3ef17e85ba06 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -467,12 +467,15 @@ def _hash_vectors(self, vectors, num_hashes): rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) # TODO: delete later when integration tests are ok - # create a random self.attention_head_size x num_hashes x self.num_buckets/2 - # random_rotations = torch.randn(rotations_shape, device=vectors.device) - np.random.seed(self.hash_seed) - random_rotations = torch.tensor( - np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device - ) + if self.hash_seed is not None: + np.random.seed(self.hash_seed) + random_rotations = torch.tensor( + np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device + ) + else: + # create a random self.attention_head_size x num_hashes x self.num_buckets/2 + random_rotations = torch.randn(rotations_shape, device=vectors.device) + # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 # TODO: IMPORTANT: At the moment we use the same random rotation over all batches From 650e00c4bf632845ff36fea6cca77bc84eeeca66 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 23 Apr 2020 15:27:05 +0200 Subject: [PATCH 077/162] improve memory in reformer --- src/transformers/modeling_reformer.py | 531 ++++++++++++++++++++------ 1 file changed, 415 insertions(+), 116 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3ef17e85ba06..ba36e9687d5f 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -48,18 +48,37 @@ def mish(x): return x * torch.tanh(nn.functional.softplus(x)) -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} +ACT2FN = { + "gelu": gelu, + "relu": torch.nn.functional.relu, + "swish": swish, + "gelu_new": gelu_new, + "mish": mish, +} ReformerLayerNorm = torch.nn.LayerNorm -LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) -LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) -AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) -ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) +LSHSelfAttentionOutput = namedtuple( + "LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"] +) +LocalSelfAttentionOutput = namedtuple( + "LocalSelfAttentionOutput", ["hidden_states", "attention_probs"] +) +AttentionOutput = namedtuple( + "AttentionOutput", ["hidden_states", "attention_probs", "buckets"] +) +ReformerOutput = namedtuple( + "ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"] +) def create_sinusoidal_embeddings(n_pos, dim, out): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ] + ) out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() @@ -83,7 +102,10 @@ def _get_least_common_mult_chunk_len(config): def apply_chunking_to_forward( - chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors + chunk_size: int, + chunk_dim: int, + forward_fn: Callable[..., torch.Tensor], + *input_tensors ) -> torch.Tensor: """ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. @@ -100,7 +122,9 @@ def apply_chunking_to_forward( a Tensor with the same shape the foward_fn would have given if applied """ - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format( + input_tensors + ) tensor_shape = input_tensors[0].shape assert all( input_tensor.shape == tensor_shape for input_tensor in input_tensors @@ -124,9 +148,15 @@ def apply_chunking_to_forward( num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size # chunk input tensor into tuples - input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + input_tensors_chunks = tuple( + input_tensor.chunk(num_chunks, dim=chunk_dim) + for input_tensor in input_tensors + ) # apply forward fn to every tuple - output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + output_chunks = tuple( + forward_fn(*input_tensors_chunk) + for input_tensors_chunk in zip(*input_tensors_chunks) + ) # concatenate output at same dimension return torch.cat(output_chunks, dim=chunk_dim) @@ -168,7 +198,8 @@ def forward(self, position_ids): sequence_length = position_ids.shape[1] broadcasted_weights = [ - weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights + weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) + for weight in self.weights ] if self.training is True: @@ -182,12 +213,19 @@ def forward(self, position_ids): # permute weights so that 2D correctly drops dims 1 and 2 perm_weigthts = weights.permute(0, 3, 2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) - drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) + drop_perm_weights = nn.functional.dropout2d( + perm_weigthts, self.dropout, training=self.training + ) drop_weights = drop_perm_weights.permute(0, 3, 2, 1) - position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) + position_encodings = torch.reshape( + drop_weights, (batch_size, sequence_length, -1) + ) else: position_encodings = torch.cat( - [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], + [ + torch.reshape(weight, (batch_size, sequence_length, -1)) + for weight in broadcasted_weights + ], dim=-1, ) @@ -195,13 +233,15 @@ def forward(self, position_ids): assert ( reduce(mul, self.axial_pos_shape) >= sequence_length ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( - self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length + self.axial_pos_shape, + sequence_length, + self.least_common_mult_chunk_length, ) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) - position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[ - :, :sequence_length - ] + position_encodings = position_encodings.view( + batch_size, -1, position_encodings.shape[-1] + )[:, :sequence_length] return position_encodings @@ -214,16 +254,22 @@ def __init__(self, config): super().__init__() self.config = config self.dropout = config.hidden_dropout_prob - self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.embedding = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) if self.config.sinusoidal_pos_embds is True: create_sinusoidal_embeddings( - n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.embedding.weight + n_pos=config.max_position_embeddings, + dim=config.hidden_size, + out=self.embedding.weight, ) def forward(self, position_ids): position_embeddings = self.embedding(position_ids) if self.config.sinusoidal_pos_embds is False: - position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) + position_embeddings = nn.functional.dropout( + position_embeddings, self.dropout, self.training + ) return position_embeddings @@ -235,7 +281,9 @@ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = ( - AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) + AxialPositionEmbeddings(config) + if config.axial_pos_embds + else PositionEmbeddings(config) ) assert not ( config.sinusoidal_pos_embds and config.axial_pos_embds @@ -295,7 +343,9 @@ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): if i == 0: slices.append(vectors) else: - slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + slices.append( + torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2) + ) return torch.cat(slices, dim=3) def _transpose_for_scores(self, x, num_attn_heads, attn_head_size): @@ -307,7 +357,9 @@ def _transpose_for_output(self, x, num_attn_heads, attn_head_size): x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): + def _split_dim_by( + self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None + ): batch_size = vectors.shape[0] split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) @@ -316,7 +368,11 @@ def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, att elif len(vectors.shape) == 3: return torch.reshape(vectors, split_dim_shape) else: - raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + raise ValueError( + "Input vector rank should be one of [3, 4], but is: {}".format( + len(vectors.shape) + ) + ) def _merge_by_middle_dim(self, vectors, num_attn_heads): batch_size = vectors.shape[0] @@ -327,7 +383,11 @@ def _merge_by_middle_dim(self, vectors, num_attn_heads): elif len(vectors.shape) == 4: return torch.reshape(vectors, new_dim_shape) else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) + raise ValueError( + "Input vector rank should be one of [4, 5], but is: {}".format( + len(vectors.shape) + ) + ) class LSHSelfAttention(nn.Module, EfficientAttentionUtils): @@ -369,14 +429,17 @@ def forward( # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes - mixed_query_key_vectors = self.query_key(hidden_states) - mixed_value_vectors = self.value(hidden_states) + query_key_vectors = self.query_key(hidden_states) + value_vectors = self.value(hidden_states) + + # free memory + del hidden_states query_key_vectors = self._transpose_for_scores( - mixed_query_key_vectors, self.num_attention_heads, self.attention_head_size + query_key_vectors, self.num_attention_heads, self.attention_head_size ) value_vectors = self._transpose_for_scores( - mixed_value_vectors, self.num_attention_heads, self.attention_head_size + value_vectors, self.num_attention_heads, self.attention_head_size ) assert query_key_vectors.shape[-1] == self.attention_head_size @@ -393,19 +456,33 @@ def forward( assert int(buckets.shape[-1]) == num_hashes * sequence_length - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets, num_hashes) + ticker, undo_ticker = self._get_ticker_and_undo_ticker( + sequence_length, buckets, num_hashes + ) - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker, num_hashes) + query_key_vectors = self._gather_by_expansion( + query_key_vectors, ticker, num_hashes + ) value_vectors = self._gather_by_expansion(value_vectors, ticker, num_hashes) query_key_vectors = self._split_dim_by( - query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + query_key_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, ) value_vectors = self._split_dim_by( - value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + value_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, ) - ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads) + ticker = self._split_dim_by( + ticker, -1, self.chunk_length, self.num_attention_heads + ) if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 @@ -422,25 +499,53 @@ def forward( head_mask=head_mask, ) + # free memory + del query_key_vectors, key_vectors, value_vectors + if num_hashes > 1: out_vectors = self._split_dim_by( - out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + out_vectors, + num_hashes, + sequence_length, + self.num_attention_heads, + self.attention_head_size, ) logits = self._split_dim_by( - logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size + logits, + num_hashes, + sequence_length, + self.num_attention_heads, + self.attention_head_size, ).unsqueeze(-1) - probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + probs_vectors = torch.exp( + logits - torch.logsumexp(logits, dim=2, keepdim=True) + ) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + # free memory + del probs_vectors + + # free memory + del logits + + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ) - out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) + out_vectors = self._transpose_for_output( + out_vectors, self.num_attention_heads, self.attention_head_size + ) if do_output_attentions is False: attention_probs = () - return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) + return LSHSelfAttentionOutput( + hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets + ) def _hash_vectors(self, vectors, num_hashes): batch_size = vectors.shape[0] @@ -451,14 +556,18 @@ def _hash_vectors(self, vectors, num_hashes): if isinstance(self.num_buckets, int): assert ( self.num_buckets % 2 == 0 - ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format(self.num_buckets) + ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format( + self.num_buckets + ) rotation_size = self.num_buckets num_buckets = self.num_buckets else: # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( + assert ( + num_bucket % 2 == 0 + ), "The number of buckets should be even, but `num_bucket`: {}".format( num_bucket ) rotation_size += num_bucket @@ -470,7 +579,9 @@ def _hash_vectors(self, vectors, num_hashes): if self.hash_seed is not None: np.random.seed(self.hash_seed) random_rotations = torch.tensor( - np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device + np.random.normal(size=rotations_shape), + dtype=torch.float32, + device=vectors.device, ) else: # create a random self.attention_head_size x num_hashes x self.num_buckets/2 @@ -490,7 +601,9 @@ def _hash_vectors(self, vectors, num_hashes): # Get the buckets for them and combine. buckets, cur_sum, cur_product = None, 0, 1 for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum : cur_sum + (num_bucket // 2)] + rotated_vectors = rotated_vectors[ + ..., cur_sum : cur_sum + (num_bucket // 2) + ] cur_sum += num_bucket // 2 rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -508,7 +621,9 @@ def _hash_vectors(self, vectors, num_hashes): # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets, self.num_attention_heads) + offset_buckets = self._merge_by_middle_dim( + buckets + offsets, self.num_attention_heads + ) return offset_buckets @@ -533,29 +648,56 @@ def _set_num_buckets_on_the_fly(self, sequence_length): num_buckets = 2 * sequence_length // self.chunk_length # factorize `num_buckets` if `num_buckets` becomes too large - num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length) + num_buckets_limit = max( + int((self.max_position_embeddings // self.chunk_length) ** (0.5)), + self.chunk_length, + ) if num_buckets > 2 * num_buckets_limit: num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] - logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets)) + logger.warning( + "config.num_buckets is not set. Setting config.num_buckets to {}...".format( + num_buckets + ) + ) self.num_buckets = num_buckets - def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, attention_mask, head_mask): - key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) - value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + def _attend( + self, + query_vectors, + key_vectors, + value_vectors, + ticker, + undo_ticker, + attention_mask, + head_mask, + ): + key_vectors = self._look_adjacent( + key_vectors, self.num_chunks_before, self.num_chunks_after + ) + value_vectors = self._look_adjacent( + value_vectors, self.num_chunks_before, self.num_chunks_after + ) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + # free memory + del query_vectors, key_vectors + # TODO(PVP): better naming query_info = ticker - key_value_info = self._look_adjacent(ticker, self.num_chunks_before, self.num_chunks_after) + key_value_info = self._look_adjacent( + ticker, self.num_chunks_before, self.num_chunks_after + ) mask = None # Causal mask if self.is_decoder: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to( + query_info.device + ) # Attention mask: chunk, look up correct mask value from key_value_info # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. @@ -572,10 +714,17 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker else: mask = attn_mask + # free memory + del attn_mask + # if attention_mask and/or casaul mask apply here if mask is not None: query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device) + mask, + query_key_dots, + torch.tensor( + -1e9, dtype=query_key_dots.dtype, device=query_key_dots.device + ), ) # Self mask is ALWAYS applied @@ -590,15 +739,27 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " - mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to( + query_info.device + ) query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e5, dtype=query_key_dots.dtype, device=query_key_dots.device) + mask, + query_key_dots, + torch.tensor( + -1e5, dtype=query_key_dots.dtype, device=query_key_dots.device + ), ) + # free memory + del mask + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (num_chunks_before + num_chunks_after)]` dots = torch.exp(query_key_dots - logits) + # free memory + del query_key_dots + # TODO(PVP): discuss with thom. Trax uses special dropout here where same dropout mask is applied for all "num_hashes * seq_len // chunk_length" dim. # should be fine with normal dropout, no? # dropout @@ -611,11 +772,16 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker # attend values out_vectors = torch.matmul(dots, value_vectors) + # free memory + del value_vectors + # merge chunk length logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand( + -1, -1, -1, self.attention_head_size + ) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) logits = torch.gather(logits, 2, undo_ticker) @@ -624,7 +790,9 @@ def _attend(self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) vectors = vectors / torch.sqrt( - torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32) + torch.tensor( + self.attention_head_size, device=vectors.device, dtype=torch.float32 + ) ) return vectors @@ -660,22 +828,31 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob - def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + do_output_attentions=False, + **kwargs + ): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - mixed_query_vectors = self.query(hidden_states) - mixed_key_vectors = self.key(hidden_states) - mixed_value_vectors = self.value(hidden_states) + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) query_vectors = self._transpose_for_scores( - mixed_query_vectors, self.num_attention_heads, self.attention_head_size + query_vectors, self.num_attention_heads, self.attention_head_size + ) + key_vectors = self._transpose_for_scores( + key_vectors, self.num_attention_heads, self.attention_head_size ) - key_vectors = self._transpose_for_scores(mixed_key_vectors, self.num_attention_heads, self.attention_head_size) value_vectors = self._transpose_for_scores( - mixed_value_vectors, self.num_attention_heads, self.attention_head_size + value_vectors, self.num_attention_heads, self.attention_head_size ) assert query_vectors.shape[-1] == self.attention_head_size @@ -686,46 +863,79 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ assert self.num_chunks_before == 0 and self.num_chunks_after == 0 key_vectors = key_vectors / torch.sqrt( - torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32) + torch.tensor( + self.attention_head_size, device=key_vectors.device, dtype=torch.float32 + ) ) # chunk vectors query_vectors = self._split_dim_by( - query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + query_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, ) key_vectors = self._split_dim_by( - key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + key_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, ) value_vectors = self._split_dim_by( - value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size + value_vectors, + -1, + self.chunk_length, + self.num_attention_heads, + self.attention_head_size, ) # chunk indices indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) - query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) - key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) + query_indices = self._split_dim_by( + indices, -1, self.chunk_length, self.num_attention_heads + ) + key_value_indices = self._split_dim_by( + indices, -1, self.chunk_length, self.num_attention_heads + ) # append chunks before and after - key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) - value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) - key_value_indices = self._look_adjacent(key_value_indices, self.num_chunks_before, self.num_chunks_after) + key_vectors = self._look_adjacent( + key_vectors, self.num_chunks_before, self.num_chunks_after + ) + value_vectors = self._look_adjacent( + value_vectors, self.num_chunks_before, self.num_chunks_after + ) + key_value_indices = self._look_adjacent( + key_value_indices, self.num_chunks_before, self.num_chunks_after + ) # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attention_mask = self._split_dim_by(attention_mask, -1, self.chunk_length, 1) - attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + attention_mask = self._split_dim_by( + attention_mask, -1, self.chunk_length, 1 + ) + attention_mask = self._look_adjacent( + attention_mask, self.num_chunks_before, self.num_chunks_after + ) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + # free memory + del query_vectors, key_vectors + mask = None # Causal mask if self.is_decoder is True: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.ge(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).to(query_indices.device) + mask = torch.ge( + query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2) + ).to(query_indices.device) # Attention mask if attention_mask is not None: @@ -736,16 +946,31 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ else: mask = attn_mask + # free memory + del attn_mask + if mask is not None: query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device) + mask, + query_key_dots, + torch.tensor( + -1e9, dtype=query_key_dots.dtype, device=query_key_dots.device + ), ) + # free memory + del mask + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) + # free memory + del logits + # dropout - attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) + attention_probs = nn.functional.dropout( + attention_probs, self.dropout, self.training + ) # Mask heads if we want to if head_mask is not None: @@ -754,18 +979,29 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # attend values out_vectors = torch.matmul(attention_probs, value_vectors) + # free memory + del value_vectors + # merge chunk length - logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size) + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ) - out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) + out_vectors = self._transpose_for_output( + out_vectors, self.num_attention_heads, self.attention_head_size + ) if do_output_attentions is False: attention_probs = () - return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) + return LocalSelfAttentionOutput( + hidden_states=out_vectors, attention_probs=attention_probs + ) class ReformerSelfOutput(nn.Module): @@ -777,7 +1013,9 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout( + hidden_states, self.dropout, self.training + ) return hidden_states @@ -785,14 +1023,18 @@ class ReformerAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.layer_id = layer_id - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = ReformerLayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) self.attn_layers = config.attn_layers if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": self.self_attention = LocalSelfAttention(config) - elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]): + elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set( + ["lsh", "local"] + ): # get correct attn_layers if self.attn_layers[self.layer_id] == "lsh": self.self_attention = LSHSelfAttention(config) @@ -815,18 +1057,17 @@ def forward( do_output_attentions=False, buckets=None, ): - norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.layer_norm(hidden_states) # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( - hidden_states=norm_hidden_states, + hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, buckets=buckets, ) - attention_output = self.output(self_attention_outputs.hidden_states) # add buckets if necessary @@ -836,7 +1077,9 @@ def forward( buckets = None return AttentionOutput( - hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets + hidden_states=attention_output, + attention_probs=self_attention_outputs.attention_probs, + buckets=buckets, ) @@ -852,7 +1095,9 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout( + hidden_states, self.dropout, self.training + ) hidden_states = self.act_fn(hidden_states) return hidden_states @@ -865,7 +1110,9 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout( + hidden_states, self.dropout, self.training + ) return hidden_states @@ -874,20 +1121,24 @@ def __init__(self, config): super().__init__() self.seq_len_dim = 1 self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = ReformerLayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) def forward(self, attention_output): return apply_chunking_to_forward( - self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output + self.chunk_size_feed_forward, + self.seq_len_dim, + self.forward_chunk, + attention_output, ) - def forward_chunk(self, attention_output): - norm_attention_output = self.layer_norm(attention_output) - dense_output = self.dense(norm_attention_output) - output = self.output(dense_output) - return output + def forward_chunk(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dense(hidden_states) + return self.output(hidden_states) class ReformerLayer(nn.Module): @@ -919,6 +1170,10 @@ def forward( # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # Y_1 = X_1 + f(X_2) attn_output = prev_attn_output + attn_output + + # free memory + del prev_attn_output + # Y_2 = X_2 + g(Y_1) hidden_states = hidden_states + self.feed_forward(attn_output) @@ -963,7 +1218,10 @@ def backward_pass( # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention output = self.attention( - hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + buckets=buckets, ).hidden_states torch.autograd.backward(output, grad_attn_output, retain_graph=True) @@ -1034,12 +1292,13 @@ def forward( ctx.all_buckets = all_buckets # Concatenate 2 RevNet outputs - hidden_states = torch.cat([attn_output, hidden_states], dim=-1) - return hidden_states + return torch.cat([attn_output, hidden_states], dim=-1) @staticmethod def backward(ctx, grad_hidden_states): - grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) + grad_attn_output, grad_hidden_states = torch.chunk( + grad_hidden_states, 2, dim=-1 + ) # retrieve params from ctx for backward attn_output = ctx.attn_output @@ -1055,7 +1314,12 @@ def backward(ctx, grad_hidden_states): all_buckets = all_buckets[:-1] # backprop - attn_output, hidden_states, grad_attn_output, grad_hidden_states = layer.backward_pass( + ( + attn_output, + hidden_states, + grad_attn_output, + grad_hidden_states, + ) = layer.backward_pass( next_attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, @@ -1065,6 +1329,9 @@ def backward(ctx, grad_hidden_states): buckets=buckets, ) + # free memory + del attn_output, hidden_states + grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) # num of return vars has to match num of forward() args @@ -1075,10 +1342,14 @@ def backward(ctx, grad_hidden_states): class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() - self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ReformerLayer(config, i) for i in range(config.num_hidden_layers)] + ) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size - self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = ReformerLayerNorm( + 2 * config.hidden_size, eps=config.layer_norm_eps + ) self.dropout = config.hidden_dropout_prob def forward( @@ -1112,7 +1383,9 @@ def forward( hidden_states = self.layer_norm(hidden_states) # Apply dropout - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout( + hidden_states, self.dropout, self.training + ) outputs = (hidden_states,) if do_output_hidden_states: @@ -1233,7 +1506,9 @@ def forward( """ if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = input_ids.size() # noqa: F841 elif inputs_embeds is not None: @@ -1241,9 +1516,15 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - assert len(input_shape) == 2, "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape) + assert ( + len(input_shape) == 2 + ), "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format( + input_shape + ) - device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa: F841 + device = ( + input_ids.device if input_ids is not None else inputs_embeds.device + ) # noqa: F841 real_sequence_length = input_shape[-1] # if needs padding @@ -1257,13 +1538,20 @@ def forward( input_shape[-2], least_common_mult_chunk_length, "training" if self.training is True else "config.is_decoder = True", - input_shape[-1] + (least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length), + input_shape[-1] + + ( + least_common_mult_chunk_length + - input_shape[-1] % least_common_mult_chunk_length + ), ) if input_ids is None: raise NotImplementedError("Currently only supported for `input_ids`") - padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length + padding_length = ( + least_common_mult_chunk_length + - input_shape[-1] % least_common_mult_chunk_length + ) # Extend `input_ids` with padding to match least common multiple chunk_length input_ids = torch.cat( @@ -1284,7 +1572,12 @@ def forward( attention_mask = torch.cat( [ attention_mask, - torch.zeros(input_shape[0], padding_length, device=attention_mask.device, dtype=torch.long), + torch.zeros( + input_shape[0], + padding_length, + device=attention_mask.device, + dtype=torch.long, + ), ], dim=-1, ) @@ -1294,7 +1587,9 @@ def forward( # prepare head mask head_masks = self.get_head_mask(head_masks, self.config.num_hidden_layers) - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds + ) encoder_outputs = self.encoder( hidden_states=embedding_output, @@ -1329,7 +1624,9 @@ def __init__(self, config): self.decoder.bias = self.bias def forward(self, hidden_states): - return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + return apply_chunking_to_forward( + self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states + ) # TODO(PVP): Does this work with backpropagation? def forward_chunk(self, hidden_states): @@ -1388,7 +1685,9 @@ def forward( shift_labels = lm_labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) outputs = (loss,) + outputs return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) From ccba9acb408c644ec7962809f0f1e174f25fd7ff Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 01:01:47 +0200 Subject: [PATCH 078/162] add tests --- src/transformers/configuration_reformer.py | 12 +- src/transformers/modeling_reformer.py | 548 ++++++--------------- src/transformers/modeling_utils.py | 4 +- tests/test_modeling_common.py | 50 +- tests/test_modeling_reformer.py | 359 +++++++++++++- 5 files changed, 542 insertions(+), 431 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 6f9d57ac10ce..8b9439a5ce44 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -117,21 +117,19 @@ def __init__( initializer_range=0.02, axial_norm_std=1.0, layer_norm_eps=1e-12, - sinusoidal_pos_embds=True, + sinusoidal_pos_embds=False, axial_pos_embds=False, axial_pos_shape=[64, 3], axial_pos_embds_dim=[64, 64], attn_layers=["lsh", "lsh", "lsh", "lsh"], + is_decoder=False, pad_token_id=0, - seed=0, # TO REMOVE LATER: + hash_seed=None, **kwargs ): - super().__init__(pad_token_id=pad_token_id, **kwargs) + super().__init__(pad_token_id=pad_token_id, is_decoder=is_decoder, **kwargs) - # TO REMOVE LATER: - self.seed = seed - - self.is_decoder = True + self.hash_seed = hash_seed self.vocab_size = vocab_size self.attention_head_size = attention_head_size self.hidden_size = hidden_size diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ba36e9687d5f..2122e771ff2d 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -58,27 +58,14 @@ def mish(x): ReformerLayerNorm = torch.nn.LayerNorm -LSHSelfAttentionOutput = namedtuple( - "LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"] -) -LocalSelfAttentionOutput = namedtuple( - "LocalSelfAttentionOutput", ["hidden_states", "attention_probs"] -) -AttentionOutput = namedtuple( - "AttentionOutput", ["hidden_states", "attention_probs", "buckets"] -) -ReformerOutput = namedtuple( - "ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"] -) +LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) +AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) def create_sinusoidal_embeddings(n_pos, dim, out): - position_enc = np.array( - [ - [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] - for pos in range(n_pos) - ] - ) + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() @@ -102,10 +89,7 @@ def _get_least_common_mult_chunk_len(config): def apply_chunking_to_forward( - chunk_size: int, - chunk_dim: int, - forward_fn: Callable[..., torch.Tensor], - *input_tensors + chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors ) -> torch.Tensor: """ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. @@ -122,9 +106,7 @@ def apply_chunking_to_forward( a Tensor with the same shape the foward_fn would have given if applied """ - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format( - input_tensors - ) + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) tensor_shape = input_tensors[0].shape assert all( input_tensor.shape == tensor_shape for input_tensor in input_tensors @@ -148,15 +130,9 @@ def apply_chunking_to_forward( num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size # chunk input tensor into tuples - input_tensors_chunks = tuple( - input_tensor.chunk(num_chunks, dim=chunk_dim) - for input_tensor in input_tensors - ) + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) # apply forward fn to every tuple - output_chunks = tuple( - forward_fn(*input_tensors_chunk) - for input_tensors_chunk in zip(*input_tensors_chunks) - ) + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) # concatenate output at same dimension return torch.cat(output_chunks, dim=chunk_dim) @@ -198,8 +174,7 @@ def forward(self, position_ids): sequence_length = position_ids.shape[1] broadcasted_weights = [ - weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) - for weight in self.weights + weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights ] if self.training is True: @@ -213,19 +188,12 @@ def forward(self, position_ids): # permute weights so that 2D correctly drops dims 1 and 2 perm_weigthts = weights.permute(0, 3, 2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) - drop_perm_weights = nn.functional.dropout2d( - perm_weigthts, self.dropout, training=self.training - ) + drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) drop_weights = drop_perm_weights.permute(0, 3, 2, 1) - position_encodings = torch.reshape( - drop_weights, (batch_size, sequence_length, -1) - ) + position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) else: position_encodings = torch.cat( - [ - torch.reshape(weight, (batch_size, sequence_length, -1)) - for weight in broadcasted_weights - ], + [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], dim=-1, ) @@ -233,15 +201,13 @@ def forward(self, position_ids): assert ( reduce(mul, self.axial_pos_shape) >= sequence_length ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( - self.axial_pos_shape, - sequence_length, - self.least_common_mult_chunk_length, + self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length, ) # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) - position_encodings = position_encodings.view( - batch_size, -1, position_encodings.shape[-1] - )[:, :sequence_length] + position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[ + :, :sequence_length + ] return position_encodings @@ -254,22 +220,16 @@ def __init__(self, config): super().__init__() self.config = config self.dropout = config.hidden_dropout_prob - self.embedding = nn.Embedding( - config.max_position_embeddings, config.hidden_size - ) + self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) if self.config.sinusoidal_pos_embds is True: create_sinusoidal_embeddings( - n_pos=config.max_position_embeddings, - dim=config.hidden_size, - out=self.embedding.weight, + n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.embedding.weight, ) def forward(self, position_ids): position_embeddings = self.embedding(position_ids) if self.config.sinusoidal_pos_embds is False: - position_embeddings = nn.functional.dropout( - position_embeddings, self.dropout, self.training - ) + position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) return position_embeddings @@ -281,9 +241,7 @@ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = ( - AxialPositionEmbeddings(config) - if config.axial_pos_embds - else PositionEmbeddings(config) + AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) ) assert not ( config.sinusoidal_pos_embds and config.axial_pos_embds @@ -343,9 +301,7 @@ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): if i == 0: slices.append(vectors) else: - slices.append( - torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2) - ) + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) return torch.cat(slices, dim=3) def _transpose_for_scores(self, x, num_attn_heads, attn_head_size): @@ -357,9 +313,7 @@ def _transpose_for_output(self, x, num_attn_heads, attn_head_size): x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) - def _split_dim_by( - self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None - ): + def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): batch_size = vectors.shape[0] split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) @@ -368,11 +322,7 @@ def _split_dim_by( elif len(vectors.shape) == 3: return torch.reshape(vectors, split_dim_shape) else: - raise ValueError( - "Input vector rank should be one of [3, 4], but is: {}".format( - len(vectors.shape) - ) - ) + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) def _merge_by_middle_dim(self, vectors, num_attn_heads): batch_size = vectors.shape[0] @@ -383,11 +333,7 @@ def _merge_by_middle_dim(self, vectors, num_attn_heads): elif len(vectors.shape) == 4: return torch.reshape(vectors, new_dim_shape) else: - raise ValueError( - "Input vector rank should be one of [4, 5], but is: {}".format( - len(vectors.shape) - ) - ) + raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) class LSHSelfAttention(nn.Module, EfficientAttentionUtils): @@ -395,7 +341,7 @@ def __init__(self, config): super().__init__() self.num_attention_heads = config.num_attention_heads - self.hash_seed = config.seed + self.hash_seed = config.hash_seed self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets self.chunk_length = config.lsh_attn_chunk_length @@ -438,9 +384,7 @@ def forward( query_key_vectors = self._transpose_for_scores( query_key_vectors, self.num_attention_heads, self.attention_head_size ) - value_vectors = self._transpose_for_scores( - value_vectors, self.num_attention_heads, self.attention_head_size - ) + value_vectors = self._transpose_for_scores(value_vectors, self.num_attention_heads, self.attention_head_size) assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size @@ -456,33 +400,19 @@ def forward( assert int(buckets.shape[-1]) == num_hashes * sequence_length - ticker, undo_ticker = self._get_ticker_and_undo_ticker( - sequence_length, buckets, num_hashes - ) + ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets, num_hashes) - query_key_vectors = self._gather_by_expansion( - query_key_vectors, ticker, num_hashes - ) + query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, ticker, num_hashes) query_key_vectors = self._split_dim_by( - query_key_vectors, - -1, - self.chunk_length, - self.num_attention_heads, - self.attention_head_size, + query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) value_vectors = self._split_dim_by( - value_vectors, - -1, - self.chunk_length, - self.num_attention_heads, - self.attention_head_size, + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) - ticker = self._split_dim_by( - ticker, -1, self.chunk_length, self.num_attention_heads - ) + ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads) if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 @@ -504,23 +434,13 @@ def forward( if num_hashes > 1: out_vectors = self._split_dim_by( - out_vectors, - num_hashes, - sequence_length, - self.num_attention_heads, - self.attention_head_size, + out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ) logits = self._split_dim_by( - logits, - num_hashes, - sequence_length, - self.num_attention_heads, - self.attention_head_size, + logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ).unsqueeze(-1) - probs_vectors = torch.exp( - logits - torch.logsumexp(logits, dim=2, keepdim=True) - ) + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) # free memory @@ -529,23 +449,14 @@ def forward( # free memory del logits - assert out_vectors.shape == ( - batch_size, - self.num_attention_heads, - sequence_length, - self.attention_head_size, - ) + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) - out_vectors = self._transpose_for_output( - out_vectors, self.num_attention_heads, self.attention_head_size - ) + out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) if do_output_attentions is False: attention_probs = () - return LSHSelfAttentionOutput( - hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets - ) + return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) def _hash_vectors(self, vectors, num_hashes): batch_size = vectors.shape[0] @@ -556,18 +467,14 @@ def _hash_vectors(self, vectors, num_hashes): if isinstance(self.num_buckets, int): assert ( self.num_buckets % 2 == 0 - ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format( - self.num_buckets - ) + ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format(self.num_buckets) rotation_size = self.num_buckets num_buckets = self.num_buckets else: # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for num_bucket in self.num_buckets: - assert ( - num_bucket % 2 == 0 - ), "The number of buckets should be even, but `num_bucket`: {}".format( + assert num_bucket % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( num_bucket ) rotation_size += num_bucket @@ -579,9 +486,7 @@ def _hash_vectors(self, vectors, num_hashes): if self.hash_seed is not None: np.random.seed(self.hash_seed) random_rotations = torch.tensor( - np.random.normal(size=rotations_shape), - dtype=torch.float32, - device=vectors.device, + np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device, ) else: # create a random self.attention_head_size x num_hashes x self.num_buckets/2 @@ -601,9 +506,7 @@ def _hash_vectors(self, vectors, num_hashes): # Get the buckets for them and combine. buckets, cur_sum, cur_product = None, 0, 1 for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[ - ..., cur_sum : cur_sum + (num_bucket // 2) - ] + rotated_vectors = rotated_vectors[..., cur_sum : cur_sum + (num_bucket // 2)] cur_sum += num_bucket // 2 rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -621,9 +524,7 @@ def _hash_vectors(self, vectors, num_hashes): # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim( - buckets + offsets, self.num_attention_heads - ) + offset_buckets = self._merge_by_middle_dim(buckets + offsets, self.num_attention_heads) return offset_buckets @@ -648,36 +549,18 @@ def _set_num_buckets_on_the_fly(self, sequence_length): num_buckets = 2 * sequence_length // self.chunk_length # factorize `num_buckets` if `num_buckets` becomes too large - num_buckets_limit = max( - int((self.max_position_embeddings // self.chunk_length) ** (0.5)), - self.chunk_length, - ) + num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,) if num_buckets > 2 * num_buckets_limit: num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] - logger.warning( - "config.num_buckets is not set. Setting config.num_buckets to {}...".format( - num_buckets - ) - ) + logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets)) self.num_buckets = num_buckets def _attend( - self, - query_vectors, - key_vectors, - value_vectors, - ticker, - undo_ticker, - attention_mask, - head_mask, + self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, attention_mask, head_mask, ): - key_vectors = self._look_adjacent( - key_vectors, self.num_chunks_before, self.num_chunks_after - ) - value_vectors = self._look_adjacent( - value_vectors, self.num_chunks_before, self.num_chunks_after - ) + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) @@ -687,17 +570,13 @@ def _attend( # TODO(PVP): better naming query_info = ticker - key_value_info = self._look_adjacent( - ticker, self.num_chunks_before, self.num_chunks_after - ) + key_value_info = self._look_adjacent(ticker, self.num_chunks_before, self.num_chunks_after) mask = None # Causal mask if self.is_decoder: # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to( - query_info.device - ) + mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) # Attention mask: chunk, look up correct mask value from key_value_info # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. @@ -720,11 +599,7 @@ def _attend( # if attention_mask and/or casaul mask apply here if mask is not None: query_key_dots = torch.where( - mask, - query_key_dots, - torch.tensor( - -1e9, dtype=query_key_dots.dtype, device=query_key_dots.device - ), + mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device), ) # Self mask is ALWAYS applied @@ -739,22 +614,16 @@ def _attend( # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " - mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to( - query_info.device - ) + mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) query_key_dots = torch.where( - mask, - query_key_dots, - torch.tensor( - -1e5, dtype=query_key_dots.dtype, device=query_key_dots.device - ), + mask, query_key_dots, torch.tensor(-1e5, dtype=query_key_dots.dtype, device=query_key_dots.device), ) # free memory del mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) - # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (num_chunks_before + num_chunks_after)]` + # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` dots = torch.exp(query_key_dots - logits) # free memory @@ -779,9 +648,7 @@ def _attend( logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand( - -1, -1, -1, self.attention_head_size - ) + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) logits = torch.gather(logits, 2, undo_ticker) @@ -790,9 +657,7 @@ def _attend( def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) vectors = vectors / torch.sqrt( - torch.tensor( - self.attention_head_size, device=vectors.device, dtype=torch.float32 - ) + torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32) ) return vectors @@ -828,14 +693,7 @@ def __init__(self, config): self.dropout = config.attention_probs_dropout_prob - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - do_output_attentions=False, - **kwargs - ): + def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] @@ -845,15 +703,9 @@ def forward( key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) - query_vectors = self._transpose_for_scores( - query_vectors, self.num_attention_heads, self.attention_head_size - ) - key_vectors = self._transpose_for_scores( - key_vectors, self.num_attention_heads, self.attention_head_size - ) - value_vectors = self._transpose_for_scores( - value_vectors, self.num_attention_heads, self.attention_head_size - ) + query_vectors = self._transpose_for_scores(query_vectors, self.num_attention_heads, self.attention_head_size) + key_vectors = self._transpose_for_scores(key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._transpose_for_scores(value_vectors, self.num_attention_heads, self.attention_head_size) assert query_vectors.shape[-1] == self.attention_head_size assert key_vectors.shape[-1] == self.attention_head_size @@ -863,67 +715,40 @@ def forward( assert self.num_chunks_before == 0 and self.num_chunks_after == 0 key_vectors = key_vectors / torch.sqrt( - torch.tensor( - self.attention_head_size, device=key_vectors.device, dtype=torch.float32 - ) + torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32) ) # chunk vectors query_vectors = self._split_dim_by( - query_vectors, - -1, - self.chunk_length, - self.num_attention_heads, - self.attention_head_size, - ) + query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size key_vectors = self._split_dim_by( - key_vectors, - -1, - self.chunk_length, - self.num_attention_heads, - self.attention_head_size, - ) + key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size value_vectors = self._split_dim_by( - value_vectors, - -1, - self.chunk_length, - self.num_attention_heads, - self.attention_head_size, - ) + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size # chunk indices indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) - query_indices = self._split_dim_by( - indices, -1, self.chunk_length, self.num_attention_heads - ) - key_value_indices = self._split_dim_by( - indices, -1, self.chunk_length, self.num_attention_heads - ) + query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) + key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) # append chunks before and after - key_vectors = self._look_adjacent( - key_vectors, self.num_chunks_before, self.num_chunks_after - ) - value_vectors = self._look_adjacent( - value_vectors, self.num_chunks_before, self.num_chunks_after - ) - key_value_indices = self._look_adjacent( - key_value_indices, self.num_chunks_before, self.num_chunks_after - ) + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + key_value_indices = self._look_adjacent(key_value_indices, self.num_chunks_before, self.num_chunks_after) # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attention_mask = self._split_dim_by( - attention_mask, -1, self.chunk_length, 1 - ) - attention_mask = self._look_adjacent( - attention_mask, self.num_chunks_before, self.num_chunks_after - ) + attention_mask = self._split_dim_by(attention_mask, -1, self.chunk_length, 1) + attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) # get logits and dots + # query_key_dots shape is `[batch_size, num_attn_heads, seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory @@ -932,10 +757,7 @@ def forward( mask = None # Causal mask if self.is_decoder is True: - # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.ge( - query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2) - ).to(query_indices.device) + mask = torch.ge(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).to(query_indices.device) # Attention mask if attention_mask is not None: @@ -951,11 +773,7 @@ def forward( if mask is not None: query_key_dots = torch.where( - mask, - query_key_dots, - torch.tensor( - -1e9, dtype=query_key_dots.dtype, device=query_key_dots.device - ), + mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device), ) # free memory @@ -968,9 +786,7 @@ def forward( del logits # dropout - attention_probs = nn.functional.dropout( - attention_probs, self.dropout, self.training - ) + attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) # Mask heads if we want to if head_mask is not None: @@ -985,23 +801,14 @@ def forward( # merge chunk length out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - assert out_vectors.shape == ( - batch_size, - self.num_attention_heads, - sequence_length, - self.attention_head_size, - ) + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) - out_vectors = self._transpose_for_output( - out_vectors, self.num_attention_heads, self.attention_head_size - ) + out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) if do_output_attentions is False: attention_probs = () - return LocalSelfAttentionOutput( - hidden_states=out_vectors, attention_probs=attention_probs - ) + return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) class ReformerSelfOutput(nn.Module): @@ -1013,9 +820,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, self.dropout, self.training - ) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) return hidden_states @@ -1023,18 +828,14 @@ class ReformerAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.layer_id = layer_id - self.layer_norm = ReformerLayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attn_layers = config.attn_layers if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": self.self_attention = LocalSelfAttention(config) - elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set( - ["lsh", "local"] - ): + elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]): # get correct attn_layers if self.attn_layers[self.layer_id] == "lsh": self.self_attention = LSHSelfAttention(config) @@ -1077,9 +878,7 @@ def forward( buckets = None return AttentionOutput( - hidden_states=attention_output, - attention_probs=self_attention_outputs.attention_probs, - buckets=buckets, + hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, ) @@ -1095,9 +894,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, self.dropout, self.training - ) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) hidden_states = self.act_fn(hidden_states) return hidden_states @@ -1110,9 +907,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, self.dropout, self.training - ) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) return hidden_states @@ -1121,18 +916,13 @@ def __init__(self, config): super().__init__() self.seq_len_dim = 1 self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.layer_norm = ReformerLayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) def forward(self, attention_output): return apply_chunking_to_forward( - self.chunk_size_feed_forward, - self.seq_len_dim, - self.forward_chunk, - attention_output, + self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, ) def forward_chunk(self, hidden_states): @@ -1218,10 +1008,7 @@ def backward_pass( # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention output = self.attention( - hidden_states=hidden_states, - head_mask=head_mask, - attention_mask=attention_mask, - buckets=buckets, + hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, ).hidden_states torch.autograd.backward(output, grad_attn_output, retain_graph=True) @@ -1248,7 +1035,7 @@ def forward( hidden_states, layers, attention_mask, - head_masks, + head_mask, num_hashes, all_hidden_states, all_attentions, @@ -1258,16 +1045,15 @@ def forward( all_buckets = () hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) - for layer, head_mask in zip(layers, head_masks): + for layer, layer_head_mask in zip(layers, head_mask): if do_output_hidden_states is True: - all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) layer_outputs = layer( prev_attn_output=attn_output, hidden_states=hidden_states, attention_mask=attention_mask, - head_mask=head_mask, + head_mask=layer_head_mask, num_hashes=num_hashes, do_output_attentions=do_output_attentions, ) @@ -1280,12 +1066,11 @@ def forward( # Add last layer if do_output_hidden_states is True: - all_hidden_states.append(attn_output) all_hidden_states.append(hidden_states) # attach params to ctx for backward ctx.attention_mask = attention_mask - ctx.head_masks = head_masks + ctx.head_mask = head_mask ctx.attn_output = attn_output.detach() ctx.hidden_states = hidden_states.detach() ctx.layers = layers @@ -1296,35 +1081,28 @@ def forward( @staticmethod def backward(ctx, grad_hidden_states): - grad_attn_output, grad_hidden_states = torch.chunk( - grad_hidden_states, 2, dim=-1 - ) + grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) # retrieve params from ctx for backward attn_output = ctx.attn_output hidden_states = ctx.hidden_states attention_mask = ctx.attention_mask - head_masks = ctx.head_masks + head_mask = ctx.head_mask layers = ctx.layers all_buckets = ctx.all_buckets - for layer, head_mask in zip(layers[::-1], head_masks[::-1]): + for idx, layer in enumerate(layers[::-1]): # pop last buckets from stack buckets = all_buckets[-1] all_buckets = all_buckets[:-1] # backprop - ( - attn_output, - hidden_states, - grad_attn_output, - grad_hidden_states, - ) = layer.backward_pass( + (attn_output, hidden_states, grad_attn_output, grad_hidden_states,) = layer.backward_pass( next_attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, - head_mask=head_mask, + head_mask=head_mask[len(layers) - idx - 1], attention_mask=attention_mask, buckets=buckets, ) @@ -1342,21 +1120,17 @@ def backward(ctx, grad_hidden_states): class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() - self.layers = nn.ModuleList( - [ReformerLayer(config, i) for i in range(config.num_hidden_layers)] - ) + self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size - self.layer_norm = ReformerLayerNorm( - 2 * config.hidden_size, eps=config.layer_norm_eps - ) + self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) self.dropout = config.hidden_dropout_prob def forward( self, hidden_states, attention_mask=None, - head_masks=None, + head_mask=None, num_hashes=None, do_output_hidden_states=False, do_output_attentions=False, @@ -1371,7 +1145,7 @@ def forward( hidden_states, self.layers, attention_mask, - head_masks, + head_mask, num_hashes, all_hidden_states, all_attentions, @@ -1383,9 +1157,7 @@ def forward( hidden_states = self.layer_norm(hidden_states) # Apply dropout - hidden_states = nn.functional.dropout( - hidden_states, self.dropout, self.training - ) + hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) outputs = (hidden_states,) if do_output_hidden_states: @@ -1416,18 +1188,18 @@ def _init_weights(self, module): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: - torch.nn.init.xavier_uniform_(module.weight) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # TODO(PVP): discuss with Thom if necessary here to use different init - # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + # torch.nn.init.xavier_uniform_(module.weight) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.normal_(mean=0.0, std=1e-6) + module.bias.data.zero_() # TODO(PVP): discuss with Thom if necessary here to use different init -# module.bias.data.zero_() +# module.bias.data.normal_(mean=0.0, std=1e-6) class ReformerModel(ReformerPreTrainedModel): @@ -1478,7 +1250,7 @@ def forward( input_ids=None, attention_mask=None, position_ids=None, - head_masks=None, + head_mask=None, inputs_embeds=None, num_hashes=None, do_output_hidden_states=False, @@ -1504,96 +1276,74 @@ def forward( Examples:: """ + # TODO(PVP): delete when PR to change output_attentions is made + do_output_attentions = self.config.output_attentions + do_output_hidden_states = self.config.output_hidden_states if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() # noqa: F841 + device = input_ids.device elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] # noqa: F841 + device = inputs_embeds.device else: raise ValueError("You have to specify either input_ids or inputs_embeds") assert ( len(input_shape) == 2 - ), "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format( - input_shape - ) + ), "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape) + + # prepare head mask + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True) - device = ( - input_ids.device if input_ids is not None else inputs_embeds.device - ) # noqa: F841 real_sequence_length = input_shape[-1] # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) if input_shape[-1] % least_common_mult_chunk_length != 0: - # TODO: should also allow this when self.is_decoder is False? - # TODO: need to improve attn mask input possibility here + + padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length assert ( - self.training is False and self.config.is_decoder is True - ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if {}. Please consider padding the input to a length of {}.".format( - input_shape[-2], - least_common_mult_chunk_length, - "training" if self.training is True else "config.is_decoder = True", - input_shape[-1] - + ( - least_common_mult_chunk_length - - input_shape[-1] % least_common_mult_chunk_length - ), + self.training is False + ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if training. Please consider padding the input to a length of {}.".format( + input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length ) - if input_ids is None: - raise NotImplementedError("Currently only supported for `input_ids`") + logger.warning( + "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( + input_shape[-1], input_shape[-1] + padding_length, least_common_mult_chunk_length + ) + ) - padding_length = ( - least_common_mult_chunk_length - - input_shape[-1] % least_common_mult_chunk_length + padded_input_ids = torch.full( + (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, ) # Extend `input_ids` with padding to match least common multiple chunk_length - input_ids = torch.cat( - [ - input_ids, - torch.full( - (input_shape[0], padding_length), - self.config.pad_token_id, - device=input_ids.device, - dtype=torch.long, - ), - ], - dim=-1, - ) + if input_ids is not None: + input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) + input_shape = input_ids.size() + + # Extend `input_embeds` with padding to match least common multiple chunk_length + if inputs_embeds is not None: + padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) + inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) + input_shape = inputs_embeds.size() # Extend `attention_mask` if attention_mask is not None: attention_mask = torch.cat( - [ - attention_mask, - torch.zeros( - input_shape[0], - padding_length, - device=attention_mask.device, - dtype=torch.long, - ), - ], + [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,),], dim=-1, ) - input_shape = input_ids.size() - - # prepare head mask - head_masks = self.get_head_mask(head_masks, self.config.num_hidden_layers) - - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds - ) + embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( hidden_states=embedding_output, - head_masks=head_masks, + head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, @@ -1624,9 +1374,7 @@ def __init__(self, config): self.decoder.bias = self.bias def forward(self, hidden_states): - return apply_chunking_to_forward( - self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states - ) + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) # TODO(PVP): Does this work with backpropagation? def forward_chunk(self, hidden_states): @@ -1656,7 +1404,7 @@ def forward( input_ids=None, position_ids=None, attention_mask=None, - head_masks=None, + head_mask=None, inputs_embeds=None, num_hashes=None, lm_labels=None, @@ -1668,7 +1416,7 @@ def forward( input_ids, position_ids=position_ids, attention_mask=attention_mask, - head_masks=head_masks, + head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, do_output_hidden_states=do_output_hidden_states, @@ -1685,9 +1433,7 @@ def forward( shift_labels = lm_labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1) - ) + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) outputs = (loss,) + outputs return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 68a6c4de1c67..dd582042a353 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -175,7 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask - def get_head_mask(self, head_mask, num_hidden_layers): + def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False): """ # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -189,6 +189,8 @@ def get_head_mask(self, head_mask, num_hidden_layers): """ if head_mask is not None: head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) else: head_mask = [None] * num_hidden_layers diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2909e17e9e49..c78e8af460d9 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -125,6 +125,9 @@ def test_attention_outputs(self): encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length) encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + chunk_length = self.model_tester.chunk_length * config.num_hashes for model_class in self.all_model_classes: config.output_attentions = True @@ -138,10 +141,16 @@ def test_attention_outputs(self): self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) out_len = len(outputs) if self.is_encoder_decoder: @@ -175,10 +184,16 @@ def test_attention_outputs(self): self_attentions = outputs[-1] self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) def test_torchscript(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -465,14 +480,16 @@ def test_hidden_states_output(self): self.assertEqual(model.config.output_attentions, False) self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [ - self.model_tester.encoder_seq_length - if hasattr(self.model_tester, "encoder_seq_length") - else self.model_tester.seq_length, - self.model_tester.hidden_size, - ], + list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size,], ) def test_resize_tokens_embeddings(self): @@ -485,6 +502,9 @@ def test_resize_tokens_embeddings(self): model = model_class(config) model.to(torch_device) + if self.model_tester.is_training is False: + model.eval() + model_vocab_size = config.vocab_size # Retrieve the embeddings and clone theme model_embed = model.resize_token_embeddings(model_vocab_size) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 419886b1a298..b71a9e589297 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -20,15 +20,356 @@ # trax imports - to be deleted later import trax -from transformers import is_torch_available # noqa: F401 +from transformers import is_torch_available from trax.shapes import ShapeDtype as trax_ShapeDtype -from .utils import require_torch, torch_device # noqa: F401 +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): + from transformers import ReformerConfig, ReformerAttention, ReformerModel, ReformerModelWithLMHead + + # from transformers.modeling_reformer import REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP import torch - from transformers import ReformerAttention, ReformerConfig, ReformerModelWithLMHead + + +@require_torch +class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + class ReformerLocalAttnModelTester(object): + def __init__( + self, + parent, + batch_size=13, + seq_length=16, + is_training=True, + is_decoder=False, + use_input_mask=True, + vocab_size=32, + attention_head_size=16, + hidden_size=32, + num_attention_heads=2, + local_attn_chunk_length=4, + num_chunks_before=1, + num_chunks_after=0, + chunk_size_lm_head=0, + chunk_size_feed_forward=0, + feed_forward_size=32, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + axial_norm_std=1.0, + layer_norm_eps=1e-12, + axial_pos_embds=True, + axial_pos_shape=[4, 4], + axial_pos_embds_dim=[16, 16], + attn_layers=["local", "local", "local", "local"], + pad_token_id=0, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.is_decoder = is_decoder + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = len(attn_layers) + self.local_attn_chunk_length = local_attn_chunk_length + self.num_chunks_after = num_chunks_after + self.num_chunks_before = num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.scope = scope + self.attn_layers = attn_layers + + self.encoder_seq_length = seq_length // local_attn_chunk_length + ( + self.seq_length % local_attn_chunk_length != 0 + ) + self.key_length = (self.num_chunks_before + self.num_chunks_after + 1) * local_attn_chunk_length + self.chunk_length = local_attn_chunk_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = ReformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + feed_forward_size=self.feed_forward_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + is_decoder=self.is_decoder, + axial_pos_embds=self.axial_pos_embds, + axial_pos_shape=self.axial_pos_shape, + axial_pos_embds_dim=self.axial_pos_embds_dim, + local_attn_chunk_length=self.local_attn_chunk_length, + num_chunks_after=self.num_chunks_after, + num_chunks_before=self.num_chunks_before, + attn_layers=self.attn_layers, + ) + + return ( + config, + input_ids, + input_mask, + ) + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) + + def create_and_check_reformer_model( + self, config, input_ids, input_mask, + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + (sequence_output,) = model(input_ids, attention_mask=input_mask) + (sequence_output,) = model(input_ids) + + result = { + "sequence_output": sequence_output, + } + # 2 * hidden_size because we use reversible resnet layers + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size] + ) + + def create_and_check_reformer_with_lm( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, lm_labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + ) + self.check_loss_output(result) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask,) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester(self) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_reformer_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model(*config_and_inputs) + + def test_for_reformer_with_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + + +@require_torch +class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + class ReformerLSHAttnModelTester(object): + def __init__( + self, + parent, + batch_size=13, + seq_length=13, + use_input_mask=True, + is_training=False, + is_decoder=False, + vocab_size=32, + attention_head_size=16, + hidden_size=64, + num_attention_heads=2, + num_buckets=2, + num_hashes=4, + lsh_attn_chunk_length=4, + num_chunks_before=2, + num_chunks_after=3, + chunk_size_lm_head=5, + chunk_size_feed_forward=6, + feed_forward_size=32, + hidden_act="relu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + sinusoidal_pos_embds=True, + attn_layers=["lsh", "lsh", "lsh", "lsh"], + pad_token_id=0, + eos_token_id=2, + scope=None, + hash_seed=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.is_decoder = is_decoder + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hashes = num_hashes + self.num_hidden_layers = len(attn_layers) + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.num_chunks_after = num_chunks_after + self.num_chunks_before = num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.scope = scope + self.attn_layers = attn_layers + self.hash_seed = hash_seed + + self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) + self.key_length = (self.num_chunks_before + self.num_chunks_after + 1) * lsh_attn_chunk_length + self.chunk_length = lsh_attn_chunk_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = ReformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + feed_forward_size=self.feed_forward_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + is_decoder=self.is_decoder, + sinusoidal_pos_embds=self.sinusoidal_pos_embds, + num_hashes=self.num_hashes, + num_buckets=self.num_buckets, + lsh_attn_chunk_length=self.lsh_attn_chunk_length, + num_chunks_after=self.num_chunks_after, + num_chunks_before=self.num_chunks_before, + attn_layers=self.attn_layers, + hash_seed=self.hash_seed, + ) + + return ( + config, + input_ids, + input_mask, + ) + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) + + def create_and_check_reformer_model( + self, config, input_ids, input_mask, + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + (sequence_output,) = model(input_ids, attention_mask=input_mask) + (sequence_output,) = model(input_ids) + + result = { + "sequence_output": sequence_output, + } + # 2 * hidden_size because we use reversible resnet layers + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size] + ) + + def create_and_check_reformer_with_lm( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, lm_labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + ) + self.check_loss_output(result) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask,) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = ReformerLSHAttnModelTest.ReformerLSHAttnModelTester(self) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_reformer_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model(*config_and_inputs) + + def test_for_reformer_with_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) @require_torch @@ -41,6 +382,7 @@ def test_lsh_layer(self): trax_layer = self.load_lsh_layer(config) input_signature = trax_ShapeDtype(shape, np.float32) trax_weights, trax_state = trax_layer.init(input_signature) + mask = np.ones(shape[:-1], dtype=np.int32) trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) @@ -83,15 +425,18 @@ def test_local_layer(self): self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_reformer_lm_model(self): - config = ReformerConfig() + config = ReformerConfig(sinusoidal_pos_embds=True, hash_seed=0, is_decoder=True) shape = (1, 192) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_zeros = np.zeros((shape[0], 1), dtype=np.int) - mode = "predict" + mode = "eval" trax_model = self.load_reformer_lm_model(config, mode=mode) + assert ( + config.is_decoder == True + ), "trax can only test casaul mask for ReformerLM. Use tests for layers to test non-casaul mask" input_signature = trax_ShapeDtype(shape, np.int32) trax_weights, trax_state = trax_model.init(input_signature) trax_output = trax_model(np_input, weights=trax_weights, state=trax_state) @@ -176,7 +521,7 @@ def load_lsh_layer(self, config, mode="eval"): config.num_buckets, config.attention_probs_dropout_prob, config.hidden_dropout_prob, - config.seed, + config.hash_seed, config.is_decoder, ) gin.parse_config(gin_config) @@ -290,7 +635,7 @@ def load_reformer_lm_model(self, config, mode="eval"): config.num_chunks_after, config.num_hashes, config.num_buckets, - config.seed, + config.hash_seed, config.is_decoder, config.local_attn_chunk_length, config.num_chunks_before, From f83721e9c9143ed9ac02ef564619f23e095abee4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 01:02:21 +0200 Subject: [PATCH 079/162] make style --- src/transformers/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 143c9b2b8fbe..82c1e2532315 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -181,7 +181,7 @@ (XLMConfig, XLMModel), (CTRLConfig, CTRLModel), (ElectraConfig, ElectraModel), - (ReformerConfig, ReformerModel) + (ReformerConfig, ReformerModel), ] ) From 125c86da52e2285ca62146fc95aecf1419d38d71 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 18:31:11 +0200 Subject: [PATCH 080/162] make sure masks work correctly --- src/transformers/configuration_reformer.py | 3 +- src/transformers/modeling_reformer.py | 22 +++-- tests/test_modeling_reformer.py | 93 +++++++++++++++++++++- 3 files changed, 108 insertions(+), 10 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 8b9439a5ce44..0984ff7287ce 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -124,10 +124,11 @@ def __init__( attn_layers=["lsh", "lsh", "lsh", "lsh"], is_decoder=False, pad_token_id=0, + eos_token_id=2, hash_seed=None, **kwargs ): - super().__init__(pad_token_id=pad_token_id, is_decoder=is_decoder, **kwargs) + super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs) self.hash_seed = hash_seed self.vocab_size = vocab_size diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 2122e771ff2d..e1c6e4967282 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -581,12 +581,17 @@ def _attend( # Attention mask: chunk, look up correct mask value from key_value_info # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. if attention_mask is not None: - attn_mask = attention_mask.to(torch.uint8)[:, None, None, :] + attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] # expand attn_mask to fit with key_value_info shape - attn_mask = attn_mask.expand(ticker.shape[:-1] + (-1,)) - attn_mask = torch.gather(attn_mask, -1, key_value_info) + attention_mask = attention_mask.expand(ticker.shape[:-1] + (-1,)) + key_attn_mask = torch.gather(attention_mask, -1, key_value_info) + query_attn_mask = torch.gather(attention_mask, -1, query_info) # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk - attn_mask = attn_mask.unsqueeze(-2).expand(query_key_dots.shape) + attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) + + # free memory + del query_attn_mask, key_attn_mask, attention_mask + # multiply by casaul mask if necessary if mask is not None: mask = mask * attn_mask @@ -745,7 +750,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] attention_mask = self._split_dim_by(attention_mask, -1, self.chunk_length, 1) - attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) # get logits and dots # query_key_dots shape is `[batch_size, num_attn_heads, seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` @@ -761,7 +766,8 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # Attention mask if attention_mask is not None: - attn_mask = attention_mask.unsqueeze(-2).expand(query_key_dots.shape) + # create attn_mask + attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots.shape) # multiply by casaul mask if necessary if mask is not None: mask = mask * attn_mask @@ -769,7 +775,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ mask = attn_mask # free memory - del attn_mask + del attn_mask, attention_mask, attention_mask_key if mask is not None: query_key_dots = torch.where( @@ -1335,7 +1341,7 @@ def forward( # Extend `attention_mask` if attention_mask is not None: attention_mask = torch.cat( - [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,),], + [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,)], dim=-1, ) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index b71a9e589297..fd6b39705653 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -106,6 +106,7 @@ def __init__( self.chunk_size_feed_forward = chunk_size_feed_forward self.scope = scope self.attn_layers = attn_layers + self.pad_token_id = pad_token_id self.encoder_seq_length = seq_length // local_attn_chunk_length + ( self.seq_length % local_attn_chunk_length != 0 @@ -138,6 +139,7 @@ def prepare_config_and_inputs(self): num_chunks_after=self.num_chunks_after, num_chunks_before=self.num_chunks_before, attn_layers=self.attn_layers, + pad_token_id=self.pad_token_id, ) return ( @@ -182,6 +184,43 @@ def create_and_check_reformer_with_lm( ) self.check_loss_output(result) + def create_and_check_reformer_model_with_attn_mask( + self, config, input_ids, input_mask, is_decoder + ): + # no special position embeddings + config.axial_pos_embds = False + config.is_decoder = is_decoder + + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + # set all position encodings to zero so that postions don't matter + with torch.no_grad(): + embedding = model.embeddings.position_embeddings.embedding + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape)) + embedding.weight.requires_grad = False + + half_seq_len = self.seq_length // 2 + roll = self.local_attn_chunk_length + roll = self.local_attn_chunk_length + half_input_ids = input_ids[:, :half_seq_len] + + # normal padded + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) + input_ids_padded = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + + # shifted padded + input_ids_roll = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) + attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) + +# input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) + + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll: half_seq_len + roll] + + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -199,6 +238,11 @@ def test_reformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def test_reformer_model_attn_masking(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + def test_for_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) @@ -275,6 +319,7 @@ def __init__( self.scope = scope self.attn_layers = attn_layers self.hash_seed = hash_seed + self.pad_token_id = pad_token_id self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) self.key_length = (self.num_chunks_before + self.num_chunks_after + 1) * lsh_attn_chunk_length @@ -306,6 +351,7 @@ def prepare_config_and_inputs(self): num_chunks_before=self.num_chunks_before, attn_layers=self.attn_layers, hash_seed=self.hash_seed, + pad_token_id=self.pad_token_id, ) return ( @@ -350,6 +396,46 @@ def create_and_check_reformer_with_lm( ) self.check_loss_output(result) + def create_and_check_reformer_model_with_attn_mask( + self, config, input_ids, input_mask, is_decoder + ): + # no special position embeddings + config.axial_pos_embds = False + config.is_decoder = is_decoder + + # need to set chunk length equal sequence length to be certain that chunking works + config.lsh_attn_chunk_length = self.seq_length + + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + # set all position encodings to zero so that postions don't matter + with torch.no_grad(): + embedding = model.embeddings.position_embeddings.embedding + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape)) + embedding.weight.requires_grad = False + + half_seq_len = self.seq_length // 2 + roll = self.lsh_attn_chunk_length + roll = half_seq_len + half_input_ids = input_ids[:, :half_seq_len] + + # normal padded + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) + input_ids_padded = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + + # shifted padded + input_ids_roll = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) + attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) + +# input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) + + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll: half_seq_len + roll] + + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -367,6 +453,11 @@ def test_reformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def test_reformer_model_attn_masking(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + def test_for_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) @@ -435,7 +526,7 @@ def test_reformer_lm_model(self): trax_model = self.load_reformer_lm_model(config, mode=mode) assert ( - config.is_decoder == True + config.is_decoder is True ), "trax can only test casaul mask for ReformerLM. Use tests for layers to test non-casaul mask" input_signature = trax_ShapeDtype(shape, np.int32) trax_weights, trax_state = trax_model.init(input_signature) From 2beda9ccae9293f6ea4be8cdb5702e42e12d7c7a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 19:25:50 +0200 Subject: [PATCH 081/162] detach gradients --- src/transformers/modeling_reformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e1c6e4967282..63fdc6072c71 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -401,7 +401,6 @@ def forward( assert int(buckets.shape[-1]) == num_hashes * sequence_length ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets, num_hashes) - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, ticker, num_hashes) @@ -537,10 +536,17 @@ def _get_ticker_and_undo_ticker(self, sequence_length, buckets, num_hashes): buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + # remove gradient + buckets_and_t.detach() + # Hash-based sort sorted_ticker = torch.argsort(buckets_and_t, dim=-1) undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + # remove gradient + sorted_ticker.detach() + undo_sorted_ticker.detach() + sorted_ticker = sorted_ticker % sequence_length return sorted_ticker, undo_sorted_ticker @@ -999,7 +1005,6 @@ def backward_pass( # g(Y_1) attn_output = self.feed_forward(next_attn_output) torch.autograd.backward(attn_output, grad_hidden_states) - # attn_output = self.attention(hidden_states=next_attn_output, head_mask=head_mask, buckets=buckets).hidden_states with torch.no_grad(): # X_2 = Y_2 - g(Y_1) From 12e35e165eefff2d9135326d516699f75d9127d4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 19:43:29 +0200 Subject: [PATCH 082/162] save intermediate --- src/transformers/modeling_reformer.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 63fdc6072c71..1aca9907e634 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -655,13 +655,33 @@ def _attend( # free memory del value_vectors + class UnsortLogits(Function): + @staticmethod + def forward(ctx, out_vectors, logits): + out_vectors = out_vectors.detach() + logits = logits.detach() + expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_ticker) + return out_vectors, logits + + @staticmethod + def backward(ctx, grad_out_vectors, grad_logits): + import ipdb + ipdb.set_trace() + expanded_sort_indices = ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) + grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) + grad_logits = torch.gather(grad_logits, 2, ticker) + return grad_out_vectors, grad_logits + # merge chunk length logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_ticker) + out_vectors, logits = UnsortLogits.apply(out_vectors, logits) + +# out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) +# logits = torch.gather(logits, 2, undo_ticker) return out_vectors, logits, dots From 8b058e27b731efe3ccc3102c870f023bd6e76072 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 23:22:12 +0200 Subject: [PATCH 083/162] correct backprob through gather --- src/transformers/modeling_reformer.py | 161 ++++++++++++++------------ 1 file changed, 87 insertions(+), 74 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1aca9907e634..ce62a9b0fd89 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -377,7 +377,6 @@ def forward( query_key_vectors = self.query_key(hidden_states) value_vectors = self.value(hidden_states) - # free memory del hidden_states @@ -400,9 +399,11 @@ def forward( assert int(buckets.shape[-1]) == num_hashes * sequence_length - ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets, num_hashes) - query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker, num_hashes) - value_vectors = self._gather_by_expansion(value_vectors, ticker, num_hashes) + sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( + sequence_length, buckets, num_hashes + ) + query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) + value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) query_key_vectors = self._split_dim_by( query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, @@ -411,8 +412,6 @@ def forward( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) - ticker = self._split_dim_by(ticker, -1, self.chunk_length, self.num_attention_heads) - if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0 @@ -422,12 +421,11 @@ def forward( query_vectors=query_key_vectors, key_vectors=key_vectors, value_vectors=value_vectors, - ticker=ticker, - undo_ticker=undo_ticker, + sorted_bucket_idx=sorted_bucket_idx, + undo_sorted_bucket_idx=undo_sorted_bucket_idx, attention_mask=attention_mask, head_mask=head_mask, ) - # free memory del query_key_vectors, key_vectors, value_vectors @@ -441,7 +439,6 @@ def forward( probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) - # free memory del probs_vectors @@ -527,28 +524,29 @@ def _hash_vectors(self, vectors, num_hashes): return offset_buckets - def _get_ticker_and_undo_ticker(self, sequence_length, buckets, num_hashes): + def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): batch_size = buckets.shape[0] - # TODO: what is ticker? Is ticker something like indices?? Ask authors - ticker = torch.arange(num_hashes * sequence_length, device=buckets.device) - ticker = ticker.repeat(batch_size, self.num_attention_heads, 1) + # TODO: what is sorted_bucket_idx? Is sorted_bucket_idx something like indices?? Ask authors + orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device) + orig_indices = orig_indices.repeat(batch_size, self.num_attention_heads, 1) - buckets_and_t = sequence_length * buckets + (ticker % sequence_length) + # scale buckets + scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) # remove gradient - buckets_and_t.detach() + scaled_buckets.detach() # Hash-based sort - sorted_ticker = torch.argsort(buckets_and_t, dim=-1) - undo_sorted_ticker = torch.argsort(sorted_ticker, dim=-1) + sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) + undo_sorted_bucket_idx = torch.argsort(sorted_bucket_idx, dim=-1) # remove gradient - sorted_ticker.detach() - undo_sorted_ticker.detach() + sorted_bucket_idx.detach() + undo_sorted_bucket_idx.detach() - sorted_ticker = sorted_ticker % sequence_length - return sorted_ticker, undo_sorted_ticker + sorted_bucket_idx = sorted_bucket_idx % sequence_length + return sorted_bucket_idx, undo_sorted_bucket_idx def _set_num_buckets_on_the_fly(self, sequence_length): # recommended `num_buckets` from paper @@ -563,38 +561,43 @@ def _set_num_buckets_on_the_fly(self, sequence_length): self.num_buckets = num_buckets def _attend( - self, query_vectors, key_vectors, value_vectors, ticker, undo_ticker, attention_mask, head_mask, + self, + query_vectors, + key_vectors, + value_vectors, + sorted_bucket_idx, + undo_sorted_bucket_idx, + attention_mask, + head_mask, ): key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - # free memory del query_vectors, key_vectors - # TODO(PVP): better naming - query_info = ticker - key_value_info = self._look_adjacent(ticker, self.num_chunks_before, self.num_chunks_after) + query_bucket_idx = self._split_dim_by(sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads) + key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) mask = None # Causal mask if self.is_decoder: - # TODO (PVP): This line can be improved in terms of memory. Mask should be inserted and does not have to be created each time layer is called - mask = torch.ge(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + mask = torch.ge(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( + query_bucket_idx.device + ) - # Attention mask: chunk, look up correct mask value from key_value_info + # Attention mask: chunk, look up correct mask value from key_value_bucket_idx # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] - # expand attn_mask to fit with key_value_info shape - attention_mask = attention_mask.expand(ticker.shape[:-1] + (-1,)) - key_attn_mask = torch.gather(attention_mask, -1, key_value_info) - query_attn_mask = torch.gather(attention_mask, -1, query_info) + # expand attn_mask to fit with key_value_bucket_idx shape + attention_mask = attention_mask.expand(query_bucket_idx.shape[:-1] + (-1,)) + key_attn_mask = torch.gather(attention_mask, -1, key_value_bucket_idx) + query_attn_mask = torch.gather(attention_mask, -1, query_bucket_idx) # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) - # free memory del query_attn_mask, key_attn_mask, attention_mask @@ -603,7 +606,6 @@ def _attend( mask = mask * attn_mask else: mask = attn_mask - # free memory del attn_mask @@ -625,18 +627,16 @@ def _attend( # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " - mask = torch.ne(query_info.unsqueeze(-1), key_value_info.unsqueeze(-2)).to(query_info.device) + mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) query_key_dots = torch.where( mask, query_key_dots, torch.tensor(-1e5, dtype=query_key_dots.dtype, device=query_key_dots.device), ) - # free memory del mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` dots = torch.exp(query_key_dots - logits) - # free memory del query_key_dots @@ -651,37 +651,17 @@ def _attend( # attend values out_vectors = torch.matmul(dots, value_vectors) - # free memory del value_vectors - class UnsortLogits(Function): - @staticmethod - def forward(ctx, out_vectors, logits): - out_vectors = out_vectors.detach() - logits = logits.detach() - expanded_undo_sort_indices = undo_ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_ticker) - return out_vectors, logits - - @staticmethod - def backward(ctx, grad_out_vectors, grad_logits): - import ipdb - ipdb.set_trace() - expanded_sort_indices = ticker.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) - grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) - grad_logits = torch.gather(grad_logits, 2, ticker) - return grad_out_vectors, grad_logits - # merge chunk length logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - out_vectors, logits = UnsortLogits.apply(out_vectors, logits) + out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) -# out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) -# logits = torch.gather(logits, 2, undo_ticker) + # out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + # logits = torch.gather(logits, 2, undo_sorted_bucket_idx) return out_vectors, logits, dots @@ -703,6 +683,35 @@ def _gather_by_expansion(self, vectors, idxs, num_hashes): return torch.gather(vectors, 2, expanded_idxs) +class UndoSort(Function): + @staticmethod + def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx): + # save sorted_bucket_idx for backprop + ctx.sorted_bucket_idx = sorted_bucket_idx + + out_vectors = out_vectors.detach() + logits = logits.detach() + + # undo sort to have correct order for next layer + expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_sorted_bucket_idx) + return out_vectors, logits + + @staticmethod + def backward(ctx, grad_out_vectors, grad_logits): + # get sorted_bucket_idx + sorted_bucket_idx = ctx.sorted_bucket_idx + + # reverse sort of forward + expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) + grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) + grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx) + + # return grad and `None` fillers for last 2 forward args + return grad_out_vectors, grad_logits, None, None + + class LocalSelfAttention(nn.Module, EfficientAttentionUtils): def __init__(self, config): super().__init__() @@ -781,7 +790,6 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # get logits and dots # query_key_dots shape is `[batch_size, num_attn_heads, seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) - # free memory del query_vectors, key_vectors @@ -799,7 +807,6 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ mask = mask * attn_mask else: mask = attn_mask - # free memory del attn_mask, attention_mask, attention_mask_key @@ -813,7 +820,6 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) - # free memory del logits @@ -826,7 +832,6 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # attend values out_vectors = torch.matmul(attention_probs, value_vectors) - # free memory del value_vectors @@ -1342,7 +1347,7 @@ def forward( input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length ) - logger.warning( + logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( input_shape[-1], input_shape[-1] + padding_length, least_common_mult_chunk_length ) @@ -1352,6 +1357,21 @@ def forward( (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, ) + # Extend `attention_mask` + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,)], + dim=-1, + ) + else: + attention_mask = torch.cat( + [ + torch.ones(input_shape, device=device, dtype=torch.uint8), + torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8), + ], + dim=-1, + ) + # Extend `input_ids` with padding to match least common multiple chunk_length if input_ids is not None: input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) @@ -1363,13 +1383,6 @@ def forward( inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) input_shape = inputs_embeds.size() - # Extend `attention_mask` - if attention_mask is not None: - attention_mask = torch.cat( - [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,)], - dim=-1, - ) - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( From 69258b80896308f3dc77dd2424a5627869e68bb7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Apr 2020 23:22:33 +0200 Subject: [PATCH 084/162] make style --- tests/test_modeling_reformer.py | 36 ++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index fd6b39705653..3341acf986e4 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -184,9 +184,7 @@ def create_and_check_reformer_with_lm( ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -207,17 +205,23 @@ def create_and_check_reformer_model_with_attn_mask( # normal padded attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) - input_ids_padded = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_padded = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + ) # shifted padded - input_ids_roll = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_roll = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) -# input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) + # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll: half_seq_len + roll] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ + :, roll : half_seq_len + roll + ] self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) @@ -396,9 +400,7 @@ def create_and_check_reformer_with_lm( ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -422,17 +424,23 @@ def create_and_check_reformer_model_with_attn_mask( # normal padded attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) - input_ids_padded = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_padded = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + ) # shifted padded - input_ids_roll = torch.cat([half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1) + input_ids_roll = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) -# input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) + # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll: half_seq_len + roll] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ + :, roll : half_seq_len + roll + ] self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) From 44c3a7cf91eb373e98c6b8953bb67ae268d389a8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 10:41:51 +0200 Subject: [PATCH 085/162] change back num hashes --- src/transformers/modeling_reformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ce62a9b0fd89..e77e44e32e91 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1233,8 +1233,6 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() # TODO(PVP): discuss with Thom if necessary here to use different init - - # module.bias.data.normal_(mean=0.0, std=1e-6) From 48fff07316385f207c194e86f31d756a303378b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 14:13:32 +0200 Subject: [PATCH 086/162] rename to labels --- src/transformers/modeling_reformer.py | 12 ++++++------ tests/test_modeling_reformer.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e77e44e32e91..052c9267946f 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1449,7 +1449,7 @@ def forward( head_mask=None, inputs_embeds=None, num_hashes=None, - lm_labels=None, + labels=None, do_output_hidden_states=False, do_output_attentions=False, ): @@ -1466,13 +1466,13 @@ def forward( ) sequence_output = reformer_outputs[0] - lm_logits = self.lm_head(sequence_output) - outputs = (lm_logits,) + reformer_outputs[1:] + logits = self.lm_head(sequence_output) + outputs = (logits,) + reformer_outputs[1:] - if lm_labels is not None: + if labels is not None: # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = lm_labels[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 3341acf986e4..5decb648e9f4 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -174,7 +174,7 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, lm_labels=input_ids) + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) result = { "loss": loss, "prediction_scores": prediction_scores, @@ -390,7 +390,7 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, lm_labels=input_ids) + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) result = { "loss": loss, "prediction_scores": prediction_scores, @@ -565,7 +565,7 @@ def test_backprop_lm_model(self): ) model = ReformerModelWithLMHead(config) - loss = model(input_ids, lm_labels=input_ids)[0] + loss = model(input_ids, labels=input_ids)[0] loss.backward() def test_pretrained_crime_and_punishment_lm_model(self): From 55842be621399ef97ed40470d32e810d17799548 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 15:34:45 +0200 Subject: [PATCH 087/162] fix rotation shape --- src/transformers/modeling_reformer.py | 17 +++++++++-------- tests/test_modeling_reformer.py | 6 +++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 052c9267946f..b2b070717e2e 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -476,24 +476,25 @@ def _hash_vectors(self, vectors, num_hashes): rotation_size += num_bucket num_buckets *= num_bucket - rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) - # TODO: delete later when integration tests are ok if self.hash_seed is not None: + rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) np.random.seed(self.hash_seed) random_rotations = torch.tensor( np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device, ) + rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) else: + rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x self.num_buckets/2 random_rotations = torch.randn(rotations_shape, device=vectors.device) - # rotated_vectors has dim: - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 - # TODO: IMPORTANT: At the moment we use the same random rotation over all batches - # and heads -> is that bad? It seems like in original reformer a different random - # rotation is used over heads and batches - rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) + # rotated_vectors has dim: + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + # TODO: IMPORTANT: At the moment we use the same random rotation over all batches + # -> is that bad? It seems like in original reformer a different random + # rotation is used batches + rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 5decb648e9f4..75655436666d 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -474,8 +474,8 @@ def test_for_reformer_with_lm(self): @require_torch class ReformerIntegrationTests(unittest.TestCase): def test_lsh_layer(self): - config = ReformerConfig() - shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize + config = ReformerConfig(hash_seed=0) + shape = (3, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) trax_layer = self.load_lsh_layer(config) @@ -500,7 +500,7 @@ def test_lsh_layer(self): self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_local_layer(self): - config = ReformerConfig() + config = ReformerConfig(hash_seed=0) shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize np_input = np.random.rand(*shape) From 71426c096273e776f56cd1a0296501622c66f66a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 15:49:44 +0200 Subject: [PATCH 088/162] fix detach --- src/transformers/modeling_reformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index b2b070717e2e..6af5f7840eac 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -536,15 +536,15 @@ def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buc scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) # remove gradient - scaled_buckets.detach() + scaled_buckets = scaled_buckets.detach() # Hash-based sort sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) undo_sorted_bucket_idx = torch.argsort(sorted_bucket_idx, dim=-1) # remove gradient - sorted_bucket_idx.detach() - undo_sorted_bucket_idx.detach() + sorted_bucket_idx = sorted_bucket_idx.detach() + undo_sorted_bucket_idx = undo_sorted_bucket_idx.detach() sorted_bucket_idx = sorted_bucket_idx % sequence_length return sorted_bucket_idx, undo_sorted_bucket_idx From dfbcf8fe7a58f5be0a4ed2b85e6709ca2c72e04e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 22:21:30 +0200 Subject: [PATCH 089/162] update --- src/transformers/modeling_reformer.py | 47 +++++++++++++-------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 6af5f7840eac..ba197b491216 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -402,6 +402,10 @@ def forward( sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sequence_length, buckets, num_hashes ) + + # make sure bucket idx is not longer then sequence length + sorted_bucket_idx = sorted_bucket_idx % sequence_length + query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) @@ -422,13 +426,15 @@ def forward( key_vectors=key_vectors, value_vectors=value_vectors, sorted_bucket_idx=sorted_bucket_idx, - undo_sorted_bucket_idx=undo_sorted_bucket_idx, attention_mask=attention_mask, head_mask=head_mask, ) # free memory del query_key_vectors, key_vectors, value_vectors + # gather correct out vectors + out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) + if num_hashes > 1: out_vectors = self._split_dim_by( out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, @@ -546,7 +552,6 @@ def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buc sorted_bucket_idx = sorted_bucket_idx.detach() undo_sorted_bucket_idx = undo_sorted_bucket_idx.detach() - sorted_bucket_idx = sorted_bucket_idx % sequence_length return sorted_bucket_idx, undo_sorted_bucket_idx def _set_num_buckets_on_the_fly(self, sequence_length): @@ -567,7 +572,6 @@ def _attend( key_vectors, value_vectors, sorted_bucket_idx, - undo_sorted_bucket_idx, attention_mask, head_mask, ): @@ -659,11 +663,6 @@ def _attend( logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) - - # out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - # logits = torch.gather(logits, 2, undo_sorted_bucket_idx) - return out_vectors, logits, dots def _len_and_dim_norm(self, vectors): @@ -1029,13 +1028,13 @@ def backward_pass( # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # g(Y_1) - attn_output = self.feed_forward(next_attn_output) - torch.autograd.backward(attn_output, grad_hidden_states) + res_hidden_states = self.feed_forward(next_attn_output) + torch.autograd.backward(res_hidden_states, grad_hidden_states) with torch.no_grad(): # X_2 = Y_2 - g(Y_1) - hidden_states = hidden_states - attn_output - del attn_output + hidden_states = hidden_states - res_hidden_states + del res_hidden_states grad_attn_output = grad_attn_output + next_attn_output.grad next_attn_output.grad = None @@ -1106,12 +1105,11 @@ def forward( all_hidden_states.append(hidden_states) # attach params to ctx for backward - ctx.attention_mask = attention_mask - ctx.head_mask = head_mask - ctx.attn_output = attn_output.detach() - ctx.hidden_states = hidden_states.detach() + ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) ctx.layers = layers ctx.all_buckets = all_buckets + ctx.head_mask = head_mask + ctx.attention_mask = attention_mask # Concatenate 2 RevNet outputs return torch.cat([attn_output, hidden_states], dim=-1) @@ -1121,12 +1119,11 @@ def backward(ctx, grad_hidden_states): grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) # retrieve params from ctx for backward - attn_output = ctx.attn_output - hidden_states = ctx.hidden_states - attention_mask = ctx.attention_mask - head_mask = ctx.head_mask + attn_output, hidden_states, = ctx.saved_tensors layers = ctx.layers all_buckets = ctx.all_buckets + head_mask = ctx.head_mask + attention_mask = ctx.attention_mask for idx, layer in enumerate(layers[::-1]): # pop last buckets from stack @@ -1144,6 +1141,8 @@ def backward(ctx, grad_hidden_states): buckets=buckets, ) + assert all_buckets == (), "buckets have to be empty after backpropagation" + # free memory del attn_output, hidden_states @@ -1225,16 +1224,16 @@ def _init_weights(self, module): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + torch.nn.init.xavier_uniform_(module.weight) # TODO(PVP): discuss with Thom if necessary here to use different init - # torch.nn.init.xavier_uniform_(module.weight) +# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.data.normal_(mean=0.0, std=1e-6) # TODO(PVP): discuss with Thom if necessary here to use different init -# module.bias.data.normal_(mean=0.0, std=1e-6) +# module.bias.data.zero_() class ReformerModel(ReformerPreTrainedModel): From 0ea564ca50ff0f8e9276c7f51c4c5b218d64d756 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Apr 2020 22:43:25 +0200 Subject: [PATCH 090/162] fix trainer --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2fcfe4531d7b..fe11bbc111a7 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -226,7 +226,7 @@ def train(self, model_path: Optional[str] = None): if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 + self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps + 1) ) else: t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) From af3456ce22daf22d53e9a2a0942578a47f90e203 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Apr 2020 12:09:52 +0200 Subject: [PATCH 091/162] fix backward dropout --- src/transformers/configuration_reformer.py | 8 +++-- src/transformers/modeling_reformer.py | 35 ++++++++++++++++------ tests/test_modeling_reformer.py | 16 +++++----- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 0984ff7287ce..1eed026af3cd 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -111,8 +111,9 @@ def __init__( chunk_size_feed_forward=0, feed_forward_size=128, hidden_act="gelu", - hidden_dropout_prob=0.3, - attention_probs_dropout_prob=0.3, + hidden_dropout_prob=0.0, + lsh_attention_probs_dropout_prob=0.0, + local_attention_probs_dropout_prob=0.0, max_position_embeddings=512, initializer_range=0.02, axial_norm_std=1.0, @@ -145,7 +146,8 @@ def __init__( self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ba197b491216..fc5356cc11b4 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -15,6 +15,8 @@ # limitations under the License. """PyTorch REFORMER model. """ +import sys + import inspect import logging from collections import namedtuple @@ -357,7 +359,7 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dropout = config.attention_probs_dropout_prob + self.dropout = config.lsh_attention_probs_dropout_prob def forward( self, @@ -534,7 +536,6 @@ def _hash_vectors(self, vectors, num_hashes): def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): batch_size = buckets.shape[0] - # TODO: what is sorted_bucket_idx? Is sorted_bucket_idx something like indices?? Ask authors orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device) orig_indices = orig_indices.repeat(batch_size, self.num_attention_heads, 1) @@ -731,7 +732,7 @@ def __init__(self, config): self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dropout = config.attention_probs_dropout_prob + self.dropout = config.local_attention_probs_dropout_prob def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): @@ -973,6 +974,18 @@ def __init__(self, config, layer_id=0): super().__init__() self.attention = ReformerAttention(config, layer_id) self.feed_forward = ChunkReformerFeedForward(config) + # dropout requires to have the same + # seed for forward and backward pass + self.attention_seed = None + self.feed_forward_seed = None + + def _init_attention_seed(self): + self.attention_seed = torch.randint(sys.maxsize, size=(1, 1)).item() + torch.manual_seed(self.attention_seed) + + def _init_feed_forward_seed(self): + self.feed_forward_seed = torch.randint(sys.maxsize, size=(1, 1)).item() + torch.manual_seed(self.feed_forward_seed) def forward( self, @@ -985,6 +998,7 @@ def forward( ): with torch.no_grad(): + self._init_attention_seed() attn_outputs = self.attention( hidden_states=hidden_states, head_mask=head_mask, @@ -1002,6 +1016,7 @@ def forward( del prev_attn_output # Y_2 = X_2 + g(Y_1) + self._init_feed_forward_seed() hidden_states = hidden_states + self.feed_forward(attn_output) return ReformerOutput( @@ -1028,8 +1043,9 @@ def backward_pass( # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # g(Y_1) + torch.manual_seed(self.feed_forward_seed) res_hidden_states = self.feed_forward(next_attn_output) - torch.autograd.backward(res_hidden_states, grad_hidden_states) + res_hidden_states.backward(grad_hidden_states, retain_graph=True) with torch.no_grad(): # X_2 = Y_2 - g(Y_1) @@ -1043,10 +1059,11 @@ def backward_pass( hidden_states.requires_grad = True # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention + torch.manual_seed(self.attention_seed) output = self.attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, ).hidden_states - torch.autograd.backward(output, grad_attn_output, retain_graph=True) + output.backward(grad_attn_output, retain_graph=True) with torch.no_grad(): # X_1 = Y_1 - f(X_2) @@ -1224,16 +1241,16 @@ def _init_weights(self, module): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: - torch.nn.init.xavier_uniform_(module.weight) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # TODO(PVP): discuss with Thom if necessary here to use different init -# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) +# torch.nn.init.xavier_uniform_(module.weight) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.normal_(mean=0.0, std=1e-6) + module.bias.data.zero_() # TODO(PVP): discuss with Thom if necessary here to use different init -# module.bias.data.zero_() +# module.bias.data.normal_(mean=0.0, std=1e-6) class ReformerModel(ReformerPreTrainedModel): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 75655436666d..47d12bd5e88b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -64,7 +64,7 @@ def __init__( feed_forward_size=32, hidden_act="gelu", hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, + local_attention_probs_dropout_prob=0.1, max_position_embeddings=512, initializer_range=0.02, axial_norm_std=1.0, @@ -94,7 +94,7 @@ def __init__( self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps @@ -129,7 +129,7 @@ def prepare_config_and_inputs(self): feed_forward_size=self.feed_forward_size, hidden_act=self.hidden_act, hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, + local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, is_decoder=self.is_decoder, axial_pos_embds=self.axial_pos_embds, @@ -283,7 +283,7 @@ def __init__( feed_forward_size=32, hidden_act="relu", hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, + lsh_attention_probs_dropout_prob=0.1, max_position_embeddings=512, initializer_range=0.02, layer_norm_eps=1e-12, @@ -313,7 +313,7 @@ def __init__( self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps @@ -344,7 +344,7 @@ def prepare_config_and_inputs(self): feed_forward_size=self.feed_forward_size, hidden_act=self.hidden_act, hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, + lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, is_decoder=self.is_decoder, sinusoidal_pos_embds=self.sinusoidal_pos_embds, @@ -618,7 +618,7 @@ def load_lsh_layer(self, config, mode="eval"): config.num_chunks_after, config.num_hashes, config.num_buckets, - config.attention_probs_dropout_prob, + config.lsh_attention_probs_dropout_prob, config.hidden_dropout_prob, config.hash_seed, config.is_decoder, @@ -651,7 +651,7 @@ def load_local_layer(self, config, mask=False, mode="eval"): config.local_attn_chunk_length, config.num_chunks_before, config.num_chunks_after, - config.attention_probs_dropout_prob, + config.local_attention_probs_dropout_prob, config.hidden_dropout_prob, config.is_decoder, mask, From 002f19c1b5e3cfee72fddc845d28a90eb2f60462 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Apr 2020 12:53:44 +0200 Subject: [PATCH 092/162] make reformer more flexible --- src/transformers/configuration_reformer.py | 12 ++++-- src/transformers/modeling_reformer.py | 8 ++-- tests/test_modeling_reformer.py | 44 +++++++++++----------- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 1eed026af3cd..af3e4876381d 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -105,8 +105,10 @@ def __init__( num_hashes=4, lsh_attn_chunk_length=64, local_attn_chunk_length=64, - num_chunks_before=1, - num_chunks_after=0, + lsh_num_chunks_before=1, + lsh_num_chunks_after=0, + local_num_chunks_before=1, + local_num_chunks_after=0, chunk_size_lm_head=0, chunk_size_feed_forward=0, feed_forward_size=128, @@ -141,8 +143,10 @@ def __init__( self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets self.lsh_attn_chunk_length = lsh_attn_chunk_length self.local_attn_chunk_length = local_attn_chunk_length - self.num_chunks_after = num_chunks_after - self.num_chunks_before = num_chunks_before + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index fc5356cc11b4..7efae7b94084 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -347,8 +347,8 @@ def __init__(self, config): self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets self.chunk_length = config.lsh_attn_chunk_length - self.num_chunks_before = config.num_chunks_before - self.num_chunks_after = config.num_chunks_after + self.num_chunks_before = config.lsh_num_chunks_before + self.num_chunks_after = config.lsh_num_chunks_after self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings @@ -719,8 +719,8 @@ def __init__(self, config): self.num_attention_heads = config.num_attention_heads self.chunk_length = config.local_attn_chunk_length - self.num_chunks_before = config.num_chunks_before - self.num_chunks_after = config.num_chunks_after + self.num_chunks_before = config.local_num_chunks_before + self.num_chunks_after = config.local_num_chunks_after self.is_decoder = config.is_decoder self.pad_token_id = config.pad_token_id diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 47d12bd5e88b..da67f0b66ccd 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -57,8 +57,8 @@ def __init__( hidden_size=32, num_attention_heads=2, local_attn_chunk_length=4, - num_chunks_before=1, - num_chunks_after=0, + local_num_chunks_before=1, + local_num_chunks_after=0, chunk_size_lm_head=0, chunk_size_feed_forward=0, feed_forward_size=32, @@ -89,8 +89,8 @@ def __init__( self.num_attention_heads = num_attention_heads self.num_hidden_layers = len(attn_layers) self.local_attn_chunk_length = local_attn_chunk_length - self.num_chunks_after = num_chunks_after - self.num_chunks_before = num_chunks_before + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob @@ -111,7 +111,7 @@ def __init__( self.encoder_seq_length = seq_length // local_attn_chunk_length + ( self.seq_length % local_attn_chunk_length != 0 ) - self.key_length = (self.num_chunks_before + self.num_chunks_after + 1) * local_attn_chunk_length + self.key_length = (self.local_num_chunks_before + self.local_num_chunks_after + 1) * local_attn_chunk_length self.chunk_length = local_attn_chunk_length def prepare_config_and_inputs(self): @@ -136,8 +136,8 @@ def prepare_config_and_inputs(self): axial_pos_shape=self.axial_pos_shape, axial_pos_embds_dim=self.axial_pos_embds_dim, local_attn_chunk_length=self.local_attn_chunk_length, - num_chunks_after=self.num_chunks_after, - num_chunks_before=self.num_chunks_before, + local_num_chunks_after=self.local_num_chunks_after, + local_num_chunks_before=self.local_num_chunks_before, attn_layers=self.attn_layers, pad_token_id=self.pad_token_id, ) @@ -276,8 +276,8 @@ def __init__( num_buckets=2, num_hashes=4, lsh_attn_chunk_length=4, - num_chunks_before=2, - num_chunks_after=3, + lsh_num_chunks_before=2, + lsh_num_chunks_after=3, chunk_size_lm_head=5, chunk_size_feed_forward=6, feed_forward_size=32, @@ -308,8 +308,8 @@ def __init__( self.num_hidden_layers = len(attn_layers) self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets self.lsh_attn_chunk_length = lsh_attn_chunk_length - self.num_chunks_after = num_chunks_after - self.num_chunks_before = num_chunks_before + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before self.hidden_act = hidden_act self.feed_forward_size = feed_forward_size self.hidden_dropout_prob = hidden_dropout_prob @@ -326,7 +326,7 @@ def __init__( self.pad_token_id = pad_token_id self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) - self.key_length = (self.num_chunks_before + self.num_chunks_after + 1) * lsh_attn_chunk_length + self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length self.chunk_length = lsh_attn_chunk_length def prepare_config_and_inputs(self): @@ -351,8 +351,8 @@ def prepare_config_and_inputs(self): num_hashes=self.num_hashes, num_buckets=self.num_buckets, lsh_attn_chunk_length=self.lsh_attn_chunk_length, - num_chunks_after=self.num_chunks_after, - num_chunks_before=self.num_chunks_before, + lsh_num_chunks_after=self.lsh_num_chunks_after, + lsh_num_chunks_before=self.lsh_num_chunks_before, attn_layers=self.attn_layers, hash_seed=self.hash_seed, pad_token_id=self.pad_token_id, @@ -614,8 +614,8 @@ def load_lsh_layer(self, config, mode="eval"): config.attention_head_size, config.attention_head_size, config.lsh_attn_chunk_length, - config.num_chunks_before, - config.num_chunks_after, + config.lsh_num_chunks_before, + config.lsh_num_chunks_after, config.num_hashes, config.num_buckets, config.lsh_attention_probs_dropout_prob, @@ -649,8 +649,8 @@ def load_local_layer(self, config, mask=False, mode="eval"): config.attention_head_size, config.attention_head_size, config.local_attn_chunk_length, - config.num_chunks_before, - config.num_chunks_after, + config.local_num_chunks_before, + config.local_num_chunks_after, config.local_attention_probs_dropout_prob, config.hidden_dropout_prob, config.is_decoder, @@ -730,15 +730,15 @@ def load_reformer_lm_model(self, config, mode="eval"): config.lsh_attn_chunk_length, config.lsh_attn_chunk_length, config.lsh_attn_chunk_length // 2, - config.num_chunks_before, - config.num_chunks_after, + config.lsh_num_chunks_before, + config.lsh_num_chunks_after, config.num_hashes, config.num_buckets, config.hash_seed, config.is_decoder, config.local_attn_chunk_length, - config.num_chunks_before, - config.num_chunks_after, + config.local_num_chunks_before, + config.local_num_chunks_after, config.is_decoder, config.vocab_size, config.hidden_size, From 7de3f4f43e066f1d993dce81a61d579db01f29f6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 May 2020 10:01:45 +0200 Subject: [PATCH 093/162] fix --- src/transformers/trainer.py | 25 +++++++++++++++++++++---- src/transformers/training_args.py | 4 ++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fe11bbc111a7..8f9e0f850002 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -196,10 +196,27 @@ def get_optimizers( "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps - ) + optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon, betas=(self.args.adam_beta_1, self.args.adam_beta_2)) + + if self.args.scheduler == "linear": + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps + ) + elif self.args.scheduler == "cosine_decay": + scheduler = get_cosine_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine + ) + elif self.args.scheduler == "cosine_decay_hard_restarts": + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine + ) + elif self.args.scheduler == "constant": + scheduler = get_constant_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps + ) + else: + raise NotImplementedError("The scheduler {} does not exist. Please choose one of the following schedulers ['linear', 'cosine_decay', 'cosine_decay_hard_restarts', 'constant']".format(self.args.scheduler)) + return optimizer, scheduler def _setup_wandb(self): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5bffd44b0d3c..bfcff7bcba47 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -55,6 +55,8 @@ class TrainingArguments: learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) + adam_beta_1: float = field(default=0.9, metadata={"helf": "Beta 1 for Adam optimizer."}) + adam_beta_2: float = field(default=0.999, metadata={"helf": "Beta 2 for Adam optimizer."}) max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) @@ -63,6 +65,8 @@ class TrainingArguments: metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + scheduler: str = field(default="linear", metadata={"help": "Name of learning rate scheduler to use. Choose between ['linear', 'cosine_decay', 'cosine_decay_hard_restarts', 'constant']"}) + num_cycles_cosine: float = field(default=1.0, metadata={"help": "Number of cosine cycles when using cosine decay schedules"}) logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"}) From 6111bd55dce6cc80308be6ac37c2437bb4a72778 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 May 2020 10:02:28 +0200 Subject: [PATCH 094/162] fix --- src/transformers/trainer.py | 4 ++-- src/transformers/training_args.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8f9e0f850002..f1e10c99a517 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -204,11 +204,11 @@ def get_optimizers( ) elif self.args.scheduler == "cosine_decay": scheduler = get_cosine_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine_decay ) elif self.args.scheduler == "cosine_decay_hard_restarts": scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine_decay ) elif self.args.scheduler == "constant": scheduler = get_constant_schedule_with_warmup( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index bfcff7bcba47..0e1e1167c49c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -66,7 +66,7 @@ class TrainingArguments: ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) scheduler: str = field(default="linear", metadata={"help": "Name of learning rate scheduler to use. Choose between ['linear', 'cosine_decay', 'cosine_decay_hard_restarts', 'constant']"}) - num_cycles_cosine: float = field(default=1.0, metadata={"help": "Number of cosine cycles when using cosine decay schedules"}) + num_cycles_cosine_decay: float = field(default=1.0, metadata={"help": "Number of cosine cycles when using cosine decay schedules"}) logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"}) From 0c75149cccd08f64c36f92ac27f9a121373787fe Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Apr 2020 14:15:39 +0200 Subject: [PATCH 095/162] add tests for fixed seed in reformer layer --- tests/test_modeling_common.py | 2 +- tests/test_modeling_reformer.py | 68 +++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c78e8af460d9..0342d22119ca 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -770,7 +770,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None): def floats_tensor(shape, scale=1.0, rng=None, name=None): - """Creates a random float32 tensor of the shape within the vocab size.""" + """Creates a random float32 tensor""" if rng is None: rng = global_rng diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index da67f0b66ccd..e19cc38f76dd 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -29,7 +29,7 @@ if is_torch_available(): - from transformers import ReformerConfig, ReformerAttention, ReformerModel, ReformerModelWithLMHead + from transformers import ReformerConfig, ReformerAttention, ReformerModel, ReformerModelWithLMHead, ReformerLayer # from transformers.modeling_reformer import REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP import torch @@ -225,6 +225,32 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + config.is_decoder = is_decoder + layer = ReformerLayer(config) + layer.train() + shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize + + # get random tensors + hidden_states = floats_tensor(shape) + prev_attn_output = floats_tensor(shape) + + # now the random seeds for attention and feed forward is initialized + # forward tensors with dropout + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + + next_attn_output = layer_outputs.attn_output + next_hidden_states = layer_outputs.hidden_states + + torch.manual_seed(layer.attention_seed) + attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) + self.parent.assertTrue(torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3)) + + torch.manual_seed(layer.feed_forward_seed) + feed_forward_hidden_states = layer.feed_forward(next_attn_output) + self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) + # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -247,10 +273,15 @@ def test_reformer_model_attn_masking(self): self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) - def test_for_reformer_with_lm(self): + def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + def test_reformer_layer_training_dropout(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): @@ -444,6 +475,32 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + config.is_decoder = is_decoder + layer = ReformerLayer(config) + layer.train() + shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize + + # get random tensors + hidden_states = floats_tensor(shape) + prev_attn_output = floats_tensor(shape) + + # now the random seeds for attention and feed forward is initialized + # forward tensors with dropout + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + + next_attn_output = layer_outputs.attn_output + next_hidden_states = layer_outputs.hidden_states + + torch.manual_seed(layer.attention_seed) + attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) + self.parent.assertTrue(torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3)) + + torch.manual_seed(layer.feed_forward_seed) + feed_forward_hidden_states = layer.feed_forward(next_attn_output) + self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) + # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -466,10 +523,15 @@ def test_reformer_model_attn_masking(self): self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) - def test_for_reformer_with_lm(self): + def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + def test_reformer_layer_training_dropout(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + @require_torch class ReformerIntegrationTests(unittest.TestCase): From 7a03bc753291ae31211e31a520d3125107723e1b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Apr 2020 16:57:42 +0200 Subject: [PATCH 096/162] fix trainer typo --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f1e10c99a517..68822dcb1861 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -243,7 +243,7 @@ def train(self, model_path: Optional[str] = None): if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps + 1) + self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 ) else: t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) From 37943f30639ecfdcd39a6e3229fecd48a94d8787 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Apr 2020 18:33:08 +0200 Subject: [PATCH 097/162] fix typo in activations --- src/transformers/activations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 5429f2fe36ab..60e8e8d05e38 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -34,7 +34,6 @@ def gelu_new(x): else: gelu = F.gelu - ACT2FN = { "relu": F.relu, "swish": swish, From 0f751f537c1cae8fbded44e46ffd7afae3544b4d Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Tue, 28 Apr 2020 15:01:40 +0200 Subject: [PATCH 098/162] add fp16 tests --- tests/test_modeling_reformer.py | 48 +++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index e19cc38f76dd..095479aa6602 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -251,6 +251,20 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, is_decoder): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + model(input_ids, attention_mask=input_mask) + + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, is_decoder): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + model.generate(input_ids, attention_mask=input_mask, do_sample=False) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -282,6 +296,16 @@ def test_reformer_layer_training_dropout(self): self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): @@ -501,6 +525,20 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, is_decoder): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + model(input_ids, attention_mask=input_mask) + + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, is_decoder): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + model.generate(input_ids, attention_mask=input_mask, do_sample=False) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs @@ -532,6 +570,16 @@ def test_reformer_layer_training_dropout(self): self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + @require_torch class ReformerIntegrationTests(unittest.TestCase): From 8df5dcdeea13c69789d0cf5953ee04744f8a668b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Apr 2020 18:45:26 +0000 Subject: [PATCH 099/162] add fp16 training --- src/transformers/modeling_reformer.py | 16 ++++++---------- tests/test_modeling_reformer.py | 24 +++++++++--------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 7efae7b94084..ddc2f0a6e73c 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -187,12 +187,8 @@ def forward(self, position_ids): ) if self.dropout > 0: weights = torch.cat(broadcasted_weights, dim=-1) - # permute weights so that 2D correctly drops dims 1 and 2 - perm_weigthts = weights.permute(0, 3, 2, 1) - # drop entire matrix of last two dims (prev dims 1 and 2) - drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) - drop_weights = drop_perm_weights.permute(0, 3, 2, 1) - position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) + position_encodings = torch.reshape(weights, (batch_size, sequence_length, -1)) + else: position_encodings = torch.cat( [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], @@ -489,7 +485,7 @@ def _hash_vectors(self, vectors, num_hashes): rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) np.random.seed(self.hash_seed) random_rotations = torch.tensor( - np.random.normal(size=rotations_shape), dtype=torch.float32, device=vectors.device, + np.random.normal(size=rotations_shape), dtype=vectors.dtype, device=vectors.device, ) rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) else: @@ -669,7 +665,7 @@ def _attend( def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) vectors = vectors / torch.sqrt( - torch.tensor(self.attention_head_size, device=vectors.device, dtype=torch.float32) + torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype) ) return vectors @@ -756,7 +752,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ assert self.num_chunks_before == 0 and self.num_chunks_after == 0 key_vectors = key_vectors / torch.sqrt( - torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=torch.float32) + torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype) ) # chunk vectors @@ -1375,7 +1371,7 @@ def forward( # Extend `attention_mask` if attention_mask is not None: attention_mask = torch.cat( - [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=torch.long,)], + [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,)], dim=-1, ) else: diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 095479aa6602..8d9853742607 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -195,7 +195,7 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu # set all position encodings to zero so that postions don't matter with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape)) + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -227,7 +227,7 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): config.is_decoder = is_decoder - layer = ReformerLayer(config) + layer = ReformerLayer(config).to(torch_device) layer.train() shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize @@ -251,14 +251,14 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): model = ReformerModel(config=config) model.to(torch_device) model.half() model.eval() model(input_ids, attention_mask=input_mask) - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() @@ -301,11 +301,6 @@ def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_reformer_model_fp16_generate(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) - @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): @@ -467,9 +462,10 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu model.to(torch_device) model.eval() # set all position encodings to zero so that postions don't matter + with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape)) + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -490,8 +486,6 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ :, roll : half_seq_len + roll @@ -501,7 +495,7 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): config.is_decoder = is_decoder - layer = ReformerLayer(config) + layer = ReformerLayer(config).to(torch_device) layer.train() shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize @@ -525,14 +519,14 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): model = ReformerModel(config=config) model.to(torch_device) model.half() model.eval() model(input_ids, attention_mask=input_mask) - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() From 51426b520ab687ccbc2d4cc2180dba5a50df30b5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Apr 2020 20:57:52 +0200 Subject: [PATCH 100/162] support fp16 --- src/transformers/modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index ddc2f0a6e73c..87e1ba740a09 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -491,7 +491,7 @@ def _hash_vectors(self, vectors, num_hashes): else: rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x self.num_buckets/2 - random_rotations = torch.randn(rotations_shape, device=vectors.device) + random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) # rotated_vectors has dim: # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 From b37fd3b5e77a9bbcf733a676c440b4b714ea0ca8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 15:34:01 +0200 Subject: [PATCH 101/162] correct gradient bug in reformer --- src/transformers/configuration_reformer.py | 11 +- src/transformers/modeling_reformer.py | 43 ++- tests/test_modeling_reformer.py | 291 ++++++++++++++------- 3 files changed, 232 insertions(+), 113 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index af3e4876381d..559b811db620 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -97,12 +97,12 @@ class ReformerConfig(PretrainedConfig): def __init__( self, - vocab_size=200, + vocab_size=10, attention_head_size=32, - hidden_size=128, + hidden_size=64, num_attention_heads=2, num_buckets=2, - num_hashes=4, + num_hashes=2, lsh_attn_chunk_length=64, local_attn_chunk_length=64, lsh_num_chunks_before=1, @@ -122,9 +122,10 @@ def __init__( layer_norm_eps=1e-12, sinusoidal_pos_embds=False, axial_pos_embds=False, - axial_pos_shape=[64, 3], - axial_pos_embds_dim=[64, 64], + axial_pos_shape=[8, 8], + axial_pos_embds_dim=[32, 32], attn_layers=["lsh", "lsh", "lsh", "lsh"], +# attn_layers=["local", "local", "local", "local"], is_decoder=False, pad_token_id=0, eos_token_id=2, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 87e1ba740a09..b895ace7b54e 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -430,8 +430,9 @@ def forward( # free memory del query_key_vectors, key_vectors, value_vectors - # gather correct out vectors - out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) + # calculate total concatenad chunks to split gradients correctly + num_neighboored_concatenated_chunks = self.num_chunks_after + self.num_chunks_before + 1 + out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_neighboored_concatenated_chunks) if num_hashes > 1: out_vectors = self._split_dim_by( @@ -682,13 +683,13 @@ def _gather_by_expansion(self, vectors, idxs, num_hashes): class UndoSort(Function): @staticmethod - def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx): + def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_neighboored_concatenated_chunks): # save sorted_bucket_idx for backprop ctx.sorted_bucket_idx = sorted_bucket_idx + ctx.num_neighboored_concatenated_chunks = num_neighboored_concatenated_chunks out_vectors = out_vectors.detach() logits = logits.detach() - # undo sort to have correct order for next layer expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) @@ -697,16 +698,35 @@ def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) @staticmethod def backward(ctx, grad_out_vectors, grad_logits): - # get sorted_bucket_idx + # get parameters saved in ctx sorted_bucket_idx = ctx.sorted_bucket_idx + num_neighboored_concatenated_chunks = ctx.num_neighboored_concatenated_chunks + + # get real gradient shape + # shape is BatchSize x NumAttnHeads x ChunkLen*(num_chunk_before + num_chunk_after + 1) + grad_logits_shape = grad_logits.shape + # shape is BatchSize x NumAttnHeads x ChunkLen*(num_chunk_before + num_chunk_after + 1) x ChunkLen + grad_out_vectors_shape = grad_out_vectors.shape + + # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices + # shape is BatchSize x NumAttnHeads x (num_chunk_before + num_chunk_after + 1) x ChunkLen + grad_logits = torch.reshape(grad_logits, (grad_logits_shape[:2] + (num_neighboored_concatenated_chunks, -1))) + # shape is BatchSize x NumAttnHeads x (num_chunk_before + num_chunk_after + 1) x ChunkLen x ChunkLen + grad_out_vectors = torch.reshape(grad_out_vectors, (grad_out_vectors_shape[:2] + (num_neighboored_concatenated_chunks, -1) + grad_out_vectors_shape[-1:])) + + sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_neighboored_concatenated_chunks, -1))) # reverse sort of forward expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) - grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) - grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx) + grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices) + grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx) + + # reshape into correct shape + grad_logits = torch.reshape(grad_logits, grad_logits_shape) + grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape) - # return grad and `None` fillers for last 2 forward args - return grad_out_vectors, grad_logits, None, None + # return grad and `None` fillers for last 3 forward args + return grad_out_vectors, grad_logits, None, None, None class LocalSelfAttention(nn.Module, EfficientAttentionUtils): @@ -1484,7 +1504,10 @@ def forward( if labels is not None: # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() + # Uncomment this line for integration test with Trax + shift_logits = logits.contiguous() +# shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 8d9853742607..3f918b9928cd 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -20,6 +20,7 @@ # trax imports - to be deleted later import trax +import jax from transformers import is_torch_available from trax.shapes import ShapeDtype as trax_ShapeDtype @@ -628,37 +629,69 @@ def test_local_layer(self): self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) def test_reformer_lm_model(self): - config = ReformerConfig(sinusoidal_pos_embds=True, hash_seed=0, is_decoder=True) + config = ReformerConfig(axial_pos_embds=True, hash_seed=0, is_decoder=True) + + shape = (1, 64) # Batch x SeqLen x ModelDimPerHead - shape = (1, 192) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) + np_input_2 = np.asarray(np_input, np.float32) + mask = np.ones_like(np_input, dtype=np.float32) np_zeros = np.zeros((shape[0], 1), dtype=np.int) - mode = "eval" + mode = "train" + trax_model = self.load_reformer_lm_model(config, mode=mode) assert ( config.is_decoder is True ), "trax can only test casaul mask for ReformerLM. Use tests for layers to test non-casaul mask" - input_signature = trax_ShapeDtype(shape, np.int32) - trax_weights, trax_state = trax_model.init(input_signature) - trax_output = trax_model(np_input, weights=trax_weights, state=trax_state) - trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + if mode == "train": + trax_model = trax.layers.Serial(trax_model, trax.layers.CrossEntropyLoss()) + input_signature = (trax_ShapeDtype(shape, np.int32), trax_ShapeDtype(shape, np.float32), trax_ShapeDtype(shape, np.float32)) + trax_weights, trax_state = trax_model.init(input_signature) + trax_input = (np_input, np_input_2, mask) + torch_trax_weights = trax_weights[0] + + else: + input_signature = trax_ShapeDtype(shape, np.int32) + trax_weights, trax_state = trax_model.init(input_signature) + trax_input = np_input + trax_output = trax_model(trax_input, weights=trax_weights, state=trax_state) + trax_torch_output = torch.tensor(np.asarray(trax_output[0])) + torch_trax_weights = trax_weights if mode != "predict": hf_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) + hf_labels = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input)], dim=-1) else: hf_input = torch.tensor(np_input) hf_model = ReformerModelWithLMHead(config) - self._set_model_weights_in_torch(trax_weights, hf_model, config.hidden_size) - hf_model.eval() + self._set_model_weights_in_torch(torch_trax_weights, hf_model, config.hidden_size) - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + if mode == "train": + # uncomment line to fix hf_input_shifting in ReformerWithLMHead + hf_model.train() + loss = hf_model(hf_input, labels=hf_labels)[0] + else: + hf_output = hf_model(hf_input) + hf_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) + self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-4)) - self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) + if mode == "train": + hf_model.zero_grad() + loss.backward() + + def model_and_loss_call(weights, batch, state): + res = trax_model(batch, weights=weights, state=state) + return res, trax_model.state + + grad_fn = jax.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(trax_weights, trax_input, trax_state) + + all_test_correct = self._set_model_weights_in_torch(grads[0], hf_model, config.hidden_size, set_params=False) + self.assertTrue(all_test_correct) def test_backprop_lm_model(self): config = ReformerConfig() @@ -672,7 +705,9 @@ def test_backprop_lm_model(self): loss = model(input_ids, labels=input_ids)[0] loss.backward() - def test_pretrained_crime_and_punishment_lm_model(self): + # use github old branch to make this test work. Pretrained weights + # cannot be loaded into new code anymore + def no_test_pretrained_crime_and_punishment_lm_model(self): hf_model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish") config = hf_model.config @@ -809,6 +844,7 @@ def load_reformer_lm_model(self, config, mode="eval"): SelfAttention.n_chunks_after = {} SelfAttention.causal= {} SelfAttention.use_reference_code = True + SelfAttention.share_qk = False # Parameters for ReformerLM: # ============================================================================== @@ -825,10 +861,9 @@ def load_reformer_lm_model(self, config, mode="eval"): ReformerLM.ff_chunk_size = {} ReformerLM.ff_activation = @trax.layers.{} ReformerLM.attention_type = @trax.layers.{} + ReformerLM.dropout = 0.0 ReformerLM.n_chunks = 0 - ReformerLM.n_attention_chunks = None - ReformerLM.share_qk = False ReformerLM.ff_use_sru = 0 """.format( config.lsh_attn_chunk_length, @@ -885,7 +920,6 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode @SelfAttention, @LSHSelfAttention, ] - share_qk = False # LSH attention ignores this flag and always shares q & k n_heads = 2 attn_kv = 64 dropout = 0.05 @@ -896,6 +930,7 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode SelfAttention.chunk_len = 64 SelfAttention.n_chunks_before = 1 SelfAttention.n_parallel_heads = 1 + SelfAttention.share_qk = False # Parameters for LSHSelfAttention: # ============================================================================== @@ -923,7 +958,6 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode ReformerLM.n_heads = %n_heads ReformerLM.n_layers = %n_layers ReformerLM.vocab_size = 320 - ReformerLM.share_qk = %share_qk ReformerLM.axial_pos_shape = (512, 1024) ReformerLM.d_axial_pos_embs= (64, 192) """ @@ -933,146 +967,207 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode trax_model.init_from_file(trax_model_path, weights_only=True) return trax_model - def _set_param(self, torch_layer, weight, bias=None): + def _set_param(self, torch_layer, weight, bias=None, name=None): with torch.no_grad(): assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) torch_layer.bias = torch.nn.Parameter(bias) + return True + + def _test_param(self, torch_layer, grad, bias_grad=None, name=""): + assert torch_layer.weight.grad.shape == grad.shape, "{} layer.grad does not match".format(torch_layer) + if torch.allclose(torch_layer.weight.grad, grad, atol=1e-3): + print("{}-{} layer.grad is good!".format(name, torch_layer)) + else: + print("ERROR {}-{} layer.grad is not good!".format(name, torch_layer)) + return False + if bias_grad is not None: + assert torch_layer.bias.grad.shape == bias_grad.shape, "{} layer.bias does not match".format(torch_layer) + if torch.allclose(torch_layer.bias.grad, bias_grad, atol=1e-3): + print("{}-{} layer.grad bias is good!".format(name, torch_layer)) + else: + print("ERROR {}-{} layer.grad bias is not good!".format(name, torch_layer)) + return False + return True + + def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size, exec_fn=None): + all_test_true = True + if exec_fn is None: + exec_fn = self._set_param - def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison np_query_key = np.asarray(weights[0]) np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) - self._set_param( + all_test_true = exec_fn( torch_layer.self_attention.query_key, torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - self._set_param( + name="attn_query_key" + ) and all_test_true + + all_test_true = exec_fn( torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - self._set_param( + name="attn_value" + ) and all_test_true + + all_test_true = exec_fn( torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), - ) + name="attn_dense" + ) and all_test_true + return all_test_true + + def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size, exec_fn=None): + all_test_true = True + + if exec_fn is None: + exec_fn = self._set_param - def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size): # set torch weights for 1-to-1 comparison np_query = np.asarray(weights[0]) np_key = np.asarray(weights[1]) np_value = np.asarray(weights[2]) np_dense = np.asarray(weights[3]) - self._set_param( + all_test_true = exec_fn( torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - self._set_param( + ) and all_test_true + all_test_true = exec_fn( torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - self._set_param( + ) and all_test_true + all_test_true = exec_fn( torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - self._set_param( + ) and all_test_true + all_test_true = exec_fn( torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), - ) + ) and all_test_true + return all_test_true - def _set_block_weights_in_torch(self, weights, torch_block, hidden_size): - # layernorm 1 - layer_norm_1 = weights[0][0][0] - layer_norm_1_weight = np.asarray(layer_norm_1[0]) - layer_norm_1_bias = np.asarray(layer_norm_1[1]) - self._set_param( - torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), - ) + def _set_block_weights_in_torch(self, weights, torch_block, hidden_size, exec_fn=None): + all_test_true = True - # lsh weights + output - attn_weights = weights[0][1] - if len(attn_weights) < 4: - self._set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size) - else: - self._set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size) + if exec_fn is None: + exec_fn = self._set_param # intermediate weighs - intermediate_weights = weights[2][0][2][2] +# intermediate_weights = weights[2][0][2][2] + intermediate_weights = weights[2][0][1][2] # Chunked Feed Forward if len(intermediate_weights) == 4: intermediate_weights = intermediate_weights[2] - # layernorm 2 - layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) - layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - self._set_param( - torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), - ) + # intermediate out + out_dense_weight = np.asarray(intermediate_weights[4][0]) + out_dense_bias = np.asarray(intermediate_weights[4][1]) + all_test_true = exec_fn( + torch_block.feed_forward.output.dense, + torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(out_dense_bias), + name="res_feed_forward_2" + ) and all_test_true # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - self._set_param( + all_test_true = exec_fn( torch_block.feed_forward.dense.dense, torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), torch.tensor(inter_dense_bias), - ) + name="res_feed_forward_1" + ) and all_test_true - # intermediate out - out_dense_weight = np.asarray(intermediate_weights[4][0]) - out_dense_bias = np.asarray(intermediate_weights[4][1]) - self._set_param( - torch_block.feed_forward.output.dense, - torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), - torch.tensor(out_dense_bias), - ) + # layernorm 2 + layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) + layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) + all_test_true = exec_fn( + torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), name="layer_norm_2" + ) and all_test_true - def _set_model_weights_in_torch(self, weights, torch_model, hidden_size): - # reformer model - torch_model_reformer = torch_model.reformer + # lsh weights + output + attn_weights = weights[0][1] + if len(attn_weights) < 4: + all_test_true = self._set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn) and all_test_true + else: + all_test_true = self._set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn) and all_test_true - # word embeds - word_embeddings = np.asarray(weights[1]) - self._set_param( - torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), - ) + # layernorm 1 + layer_norm_1 = weights[0][0][0] + layer_norm_1_weight = np.asarray(layer_norm_1[0]) + layer_norm_1_bias = np.asarray(layer_norm_1[1]) + all_test_true = exec_fn( + torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), + name="layer_norm" + ) and all_test_true - if isinstance(weights[3], tuple): - position_embeddings = torch_model_reformer.embeddings.position_embeddings - for emb_idx in range(len(position_embeddings.weights)): - emb_weights = np.asarray(weights[3][emb_idx][0]) - assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( - position_embeddings[emb_idx] - ) - position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + return all_test_true - trax_layer_weights = weights[5] - assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( - trax_layer_weights - ), "HF and trax model do not have the same number of layers" - for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers): - block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] - self._set_block_weights_in_torch(block_weights, layer, hidden_size) + def _set_model_weights_in_torch(self, weights, torch_model, hidden_size, set_params=True): + # reformer model + torch_model_reformer = torch_model.reformer + + all_test_true = True + if set_params is True: + exec_fn = self._set_param + else: + exec_fn = self._test_param # output weights out_weights = weights[6] + # output embeddings + output_embed_weights = np.asarray(out_weights[2][0]) + output_embed_bias = np.asarray(out_weights[2][1]) + all_test_true = exec_fn( + torch_model.lm_head.decoder, + torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), + torch.tensor(output_embed_bias), + name="lm_head" + ) and all_test_true + # output layer norm layer_norm_out_weight = np.asarray(out_weights[0][0]) layer_norm_out_bias = np.asarray(out_weights[0][1]) - self._set_param( + all_test_true = exec_fn( torch_model_reformer.encoder.layer_norm, torch.tensor(layer_norm_out_weight), torch.tensor(layer_norm_out_bias), - ) + name="last layer norm" + ) and all_test_true - # output embeddings - output_embed_weights = np.asarray(out_weights[2][0]) - output_embed_bias = np.asarray(out_weights[2][1]) - self._set_param( - torch_model.lm_head.decoder, - torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), - torch.tensor(output_embed_bias), - ) + trax_layer_weights = weights[5] + assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( + trax_layer_weights + ), "HF and trax model do not have the same number of layers" + for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers[::-1]): + block_weights = trax_layer_weights[:-1][::-1][4 * layer_idx : 4 * (layer_idx + 1)][::-1] + all_test_true = self._set_block_weights_in_torch(block_weights, layer, hidden_size, exec_fn=exec_fn) and all_test_true + + if isinstance(weights[3], tuple): + position_embeddings = torch_model_reformer.embeddings.position_embeddings + for emb_idx in range(len(position_embeddings.weights)): + emb_weights = np.asarray(weights[3][emb_idx][0]) + assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( + position_embeddings[emb_idx] + ) + if set_params is True: + position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + else: + if torch.allclose(position_embeddings.weights[emb_idx].grad, torch.tensor(emb_weights), atol=1e-3): + print("{} layer.grad is good!".format(position_embeddings)) + else: + print("ERROR: {}-{} layer.grad is not good".format(position_embeddings, "axs_pos_embeds")) + + # word embeds + word_embeddings = np.asarray(weights[1]) + all_test_true = exec_fn( + torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), + name="word_embed" + ) and all_test_true + + return all_test_true From e3e05efba15bbd6c086d6fbfa39af79b8cca4799 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 16:06:02 +0200 Subject: [PATCH 102/162] add fast gelu --- src/transformers/activations.py | 6 ++++++ src/transformers/modeling_reformer.py | 7 ++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 60e8e8d05e38..f79b4055a16f 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -34,12 +34,18 @@ def gelu_new(x): else: gelu = F.gelu + +def gelu_fast(x): + return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + + ACT2FN = { "relu": F.relu, "swish": swish, "gelu": gelu, "tanh": torch.tanh, "gelu_new": gelu_new, + "gelu_fast": gelu_fast } diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index b895ace7b54e..a021bcb84d61 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -31,7 +31,7 @@ from torch.autograd.function import Function from torch.nn import CrossEntropyLoss -from .activations import gelu, gelu_new, swish +from .activations import gelu, gelu_new, swish, gelu_fast from .configuration_reformer import ReformerConfig from .modeling_utils import PreTrainedModel @@ -55,6 +55,7 @@ def mish(x): "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, + "gelu_fast": gelu_fast, "mish": mish, } ReformerLayerNorm = torch.nn.LayerNorm @@ -1505,8 +1506,8 @@ def forward( if labels is not None: # Shift so that tokens < n predict n # Uncomment this line for integration test with Trax - shift_logits = logits.contiguous() -# shift_logits = logits[..., :-1, :].contiguous() +# shift_logits = logits.contiguous() + shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens From c3e32b4d870b548d6b2eef6badfbfa8f8a579d73 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 16:14:50 +0200 Subject: [PATCH 103/162] re-add dropout for embedding dropout --- src/transformers/modeling_reformer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index a021bcb84d61..42c989a2fd91 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -188,7 +188,12 @@ def forward(self, position_ids): ) if self.dropout > 0: weights = torch.cat(broadcasted_weights, dim=-1) - position_encodings = torch.reshape(weights, (batch_size, sequence_length, -1)) + # permute weights so that 2D correctly drops dims 1 and 2 + perm_weigthts = weights.permute(0, 3, 2, 1) + # drop entire matrix of last two dims (prev dims 1 and 2) + drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) + drop_weights = drop_perm_weights.permute(0, 3, 2, 1) + position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) else: position_encodings = torch.cat( From 52ee5ed7107957eceb19cd62a54fdc940405f8ee Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 17:52:19 +0200 Subject: [PATCH 104/162] better naming --- src/transformers/configuration_reformer.py | 6 ++-- src/transformers/modeling_reformer.py | 34 ++++++++++++---------- tests/test_modeling_reformer.py | 2 +- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 559b811db620..4bb3535b8dd6 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -100,8 +100,8 @@ def __init__( vocab_size=10, attention_head_size=32, hidden_size=64, - num_attention_heads=2, - num_buckets=2, + num_attention_heads=1, + num_buckets=[2, 4], num_hashes=2, lsh_attn_chunk_length=64, local_attn_chunk_length=64, @@ -122,7 +122,7 @@ def __init__( layer_norm_eps=1e-12, sinusoidal_pos_embds=False, axial_pos_embds=False, - axial_pos_shape=[8, 8], + axial_pos_shape=[32, 16], axial_pos_embds_dim=[32, 32], attn_layers=["lsh", "lsh", "lsh", "lsh"], # attn_layers=["local", "local", "local", "local"], diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 42c989a2fd91..71ea60c938dd 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -471,6 +471,7 @@ def _hash_vectors(self, vectors, num_hashes): # See https://arxiv.org/pdf/1509.02897.pdf # We sample a different random rotation for each round of hashing to # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): assert ( self.num_buckets % 2 == 0 @@ -480,12 +481,13 @@ def _hash_vectors(self, vectors, num_hashes): else: # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 - for num_bucket in self.num_buckets: - assert num_bucket % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( - num_bucket - ) - rotation_size += num_bucket - num_buckets *= num_bucket + for bucket_factor in self.num_buckets: + assert bucket_factor % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format(bucket_factor) + rotation_size = rotation_size + bucket_factor + num_buckets = num_buckets * bucket_factor + + # remove gradient + vectors = vectors.detach() # TODO: delete later when integration tests are ok if self.hash_seed is not None: @@ -497,7 +499,7 @@ def _hash_vectors(self, vectors, num_hashes): rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) else: rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) - # create a random self.attention_head_size x num_hashes x self.num_buckets/2 + # create a random self.attention_head_size x num_hashes x num_buckets/2 random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) # rotated_vectors has dim: @@ -513,17 +515,17 @@ def _hash_vectors(self, vectors, num_hashes): else: # Get the buckets for them and combine. buckets, cur_sum, cur_product = None, 0, 1 - for num_bucket in self.num_buckets: - rotated_vectors = rotated_vectors[..., cur_sum : cur_sum + (num_bucket // 2)] - cur_sum += num_bucket // 2 - rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + for bucket_factor in self.num_buckets: + rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)] + cur_sum = cur_sum + bucket_factor // 2 + rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1) if buckets is None: - buckets = torch.argmax(rotated_vectors, dim=-1) + buckets = torch.argmax(rotated_vectors_factor, dim=-1) else: - buckets += cur_product * torch.argmax(rotated_vectors, dim=-1) + buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1)) - cur_product *= num_bucket + cur_product = cur_product * bucket_factor # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. @@ -1511,8 +1513,8 @@ def forward( if labels is not None: # Shift so that tokens < n predict n # Uncomment this line for integration test with Trax -# shift_logits = logits.contiguous() - shift_logits = logits[..., :-1, :].contiguous() + shift_logits = logits.contiguous() +# shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 3f918b9928cd..6a0974bf1f78 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -631,7 +631,7 @@ def test_local_layer(self): def test_reformer_lm_model(self): config = ReformerConfig(axial_pos_embds=True, hash_seed=0, is_decoder=True) - shape = (1, 64) # Batch x SeqLen x ModelDimPerHead + shape = (1, 512) # Batch x SeqLen x ModelDimPerHead np_input = np.random.randint(0, config.vocab_size, size=shape) np_input_2 = np.asarray(np_input, np.float32) From ece19eee141adddf598941c09aa64de7292284b9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 17:53:59 +0200 Subject: [PATCH 105/162] better naming --- src/transformers/modeling_reformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 71ea60c938dd..941fb5b03b22 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1513,8 +1513,8 @@ def forward( if labels is not None: # Shift so that tokens < n predict n # Uncomment this line for integration test with Trax - shift_logits = logits.contiguous() -# shift_logits = logits[..., :-1, :].contiguous() +# shift_logits = logits.contiguous() + shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens From e661832ede56c3fbc62119ecec7eed2dfc6f64a5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 18:13:28 +0200 Subject: [PATCH 106/162] renaming --- src/transformers/configuration_reformer.py | 3 +-- src/transformers/modeling_reformer.py | 23 +++++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 4bb3535b8dd6..0598ad6e18c6 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -102,7 +102,7 @@ def __init__( hidden_size=64, num_attention_heads=1, num_buckets=[2, 4], - num_hashes=2, + num_hashes=8, lsh_attn_chunk_length=64, local_attn_chunk_length=64, lsh_num_chunks_before=1, @@ -125,7 +125,6 @@ def __init__( axial_pos_shape=[32, 16], axial_pos_embds_dim=[32, 32], attn_layers=["lsh", "lsh", "lsh", "lsh"], -# attn_layers=["local", "local", "local", "local"], is_decoder=False, pad_token_id=0, eos_token_id=2, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 941fb5b03b22..6bc81c1d5d06 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -437,8 +437,7 @@ def forward( del query_key_vectors, key_vectors, value_vectors # calculate total concatenad chunks to split gradients correctly - num_neighboored_concatenated_chunks = self.num_chunks_after + self.num_chunks_before + 1 - out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_neighboored_concatenated_chunks) + out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes) if num_hashes > 1: out_vectors = self._split_dim_by( @@ -691,10 +690,10 @@ def _gather_by_expansion(self, vectors, idxs, num_hashes): class UndoSort(Function): @staticmethod - def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_neighboored_concatenated_chunks): + def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): # save sorted_bucket_idx for backprop ctx.sorted_bucket_idx = sorted_bucket_idx - ctx.num_neighboored_concatenated_chunks = num_neighboored_concatenated_chunks + ctx.num_hashes = num_hashes out_vectors = out_vectors.detach() logits = logits.detach() @@ -708,21 +707,21 @@ def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, def backward(ctx, grad_out_vectors, grad_logits): # get parameters saved in ctx sorted_bucket_idx = ctx.sorted_bucket_idx - num_neighboored_concatenated_chunks = ctx.num_neighboored_concatenated_chunks + num_hashes = ctx.num_hashes # get real gradient shape - # shape is BatchSize x NumAttnHeads x ChunkLen*(num_chunk_before + num_chunk_after + 1) + # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes grad_logits_shape = grad_logits.shape - # shape is BatchSize x NumAttnHeads x ChunkLen*(num_chunk_before + num_chunk_after + 1) x ChunkLen + # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen grad_out_vectors_shape = grad_out_vectors.shape # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices - # shape is BatchSize x NumAttnHeads x (num_chunk_before + num_chunk_after + 1) x ChunkLen - grad_logits = torch.reshape(grad_logits, (grad_logits_shape[:2] + (num_neighboored_concatenated_chunks, -1))) - # shape is BatchSize x NumAttnHeads x (num_chunk_before + num_chunk_after + 1) x ChunkLen x ChunkLen - grad_out_vectors = torch.reshape(grad_out_vectors, (grad_out_vectors_shape[:2] + (num_neighboored_concatenated_chunks, -1) + grad_out_vectors_shape[-1:])) + # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen + grad_logits = torch.reshape(grad_logits, (grad_logits_shape[:2] + (num_hashes, -1))) + # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen + grad_out_vectors = torch.reshape(grad_out_vectors, (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])) - sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_neighboored_concatenated_chunks, -1))) + sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) # reverse sort of forward expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) From f1a6355df4cf2a4d847f725d99e8bf849628b818 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Apr 2020 23:18:09 +0200 Subject: [PATCH 107/162] finalize test branch --- src/transformers/configuration_reformer.py | 2 +- src/transformers/modeling_reformer.py | 5 +- tests/test_modeling_reformer.py | 800 ++++++++++++++++----- 3 files changed, 640 insertions(+), 167 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 0598ad6e18c6..3034a158c4de 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -102,7 +102,7 @@ def __init__( hidden_size=64, num_attention_heads=1, num_buckets=[2, 4], - num_hashes=8, + num_hashes=4, lsh_attn_chunk_length=64, local_attn_chunk_length=64, lsh_num_chunks_before=1, diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 6bc81c1d5d06..9321d8236001 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -491,7 +491,7 @@ def _hash_vectors(self, vectors, num_hashes): # TODO: delete later when integration tests are ok if self.hash_seed is not None: rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) - np.random.seed(self.hash_seed) + torch.manual_seed(self.hash_seed) random_rotations = torch.tensor( np.random.normal(size=rotations_shape), dtype=vectors.dtype, device=vectors.device, ) @@ -1511,10 +1511,7 @@ def forward( if labels is not None: # Shift so that tokens < n predict n - # Uncomment this line for integration test with Trax -# shift_logits = logits.contiguous() shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 6a0974bf1f78..70f809f6a479 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -30,7 +30,13 @@ if is_torch_available(): - from transformers import ReformerConfig, ReformerAttention, ReformerModel, ReformerModelWithLMHead, ReformerLayer + from transformers import ( + ReformerConfig, + ReformerAttention, + ReformerModel, + ReformerModelWithLMHead, + ReformerLayer, + ) # from transformers.modeling_reformer import REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP import torch @@ -39,7 +45,9 @@ @require_torch class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_model_classes = ( + (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + ) test_pruning = False test_headmasking = False test_torchscript = False @@ -112,7 +120,9 @@ def __init__( self.encoder_seq_length = seq_length // local_attn_chunk_length + ( self.seq_length % local_attn_chunk_length != 0 ) - self.key_length = (self.local_num_chunks_before + self.local_num_chunks_after + 1) * local_attn_chunk_length + self.key_length = ( + self.local_num_chunks_before + self.local_num_chunks_after + 1 + ) * local_attn_chunk_length self.chunk_length = local_attn_chunk_length def prepare_config_and_inputs(self): @@ -120,7 +130,9 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + input_mask = ids_tensor( + [self.batch_size, self.seq_length], vocab_size=2 + ) config = ReformerConfig( vocab_size=self.vocab_size, @@ -166,7 +178,8 @@ def create_and_check_reformer_model( } # 2 * hidden_size because we use reversible resnet layers self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size] + list(result["sequence_output"].size()), + [self.batch_size, self.seq_length, 2 * self.hidden_size], ) def create_and_check_reformer_with_lm( @@ -175,17 +188,22 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + loss, prediction_scores = model( + input_ids, attention_mask=input_mask, labels=input_ids + ) result = { "loss": loss, "prediction_scores": prediction_scores, } self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + list(result["prediction_scores"].size()), + [self.batch_size, self.seq_length, self.vocab_size], ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_with_attn_mask( + self, config, input_ids, input_mask, is_decoder + ): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -196,7 +214,9 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu # set all position encodings to zero so that postions don't matter with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) + embedding.weight = torch.nn.Parameter( + torch.zeros(embedding.weight.shape).to(torch_device) + ) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -205,32 +225,53 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu half_input_ids = input_ids[:, :half_seq_len] # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) + attn_mask = torch.cat( + [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], + dim=-1, + ) input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + [ + half_input_ids, + ids_tensor((self.batch_size, half_seq_len), self.vocab_size), + ], + dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + [ + half_input_ids, + ids_tensor((self.batch_size, half_seq_len), self.vocab_size), + ], + dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][ + :, :half_seq_len ] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[ + 0 + ][:, roll : half_seq_len + roll] - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(output_padded, output_padded_rolled, atol=1e-3) + ) - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_layer_dropout_seed( + self, config, input_ids, input_mask, is_decoder + ): config.is_decoder = is_decoder layer = ReformerLayer(config).to(torch_device) layer.train() - shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize # get random tensors hidden_states = floats_tensor(shape) @@ -238,28 +279,46 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ # now the random seeds for attention and feed forward is initialized # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + layer_outputs = layer( + prev_attn_output, hidden_states, attention_mask=input_mask + ) next_attn_output = layer_outputs.attn_output next_hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.attention_seed) attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue(torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3)) + self.parent.assertTrue( + torch.allclose( + prev_attn_output + attn_outputs.hidden_states, + next_attn_output, + atol=1e-3, + ) + ) torch.manual_seed(layer.feed_forward_seed) feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) + self.parent.assertTrue( + torch.allclose( + next_hidden_states, + hidden_states + feed_forward_hidden_states, + atol=1e-3, + ) + ) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): + def create_and_check_reformer_model_fp16_forward( + self, config, input_ids, input_mask + ): model = ReformerModel(config=config) model.to(torch_device) model.half() model.eval() model(input_ids, attention_mask=input_mask) - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): + def create_and_check_reformer_model_fp16_generate( + self, config, input_ids, input_mask + ): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() @@ -273,8 +332,12 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict def setUp(self): - self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester( + self + ) + self.config_tester = ConfigTester( + self, config_class=ReformerConfig, hidden_size=37 + ) def test_config(self): self.config_tester.run_common_tests() @@ -285,8 +348,12 @@ def test_reformer_model(self): def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_model_with_attn_mask( + *config_and_inputs, True + ) + self.model_tester.create_and_check_reformer_model_with_attn_mask( + *config_and_inputs, False + ) def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -294,19 +361,27 @@ def test_reformer_with_lm(self): def test_reformer_layer_training_dropout(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_layer_dropout_seed( + *config_and_inputs, True + ) + self.model_tester.create_and_check_reformer_layer_dropout_seed( + *config_and_inputs, False + ) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + self.model_tester.create_and_check_reformer_model_fp16_forward( + *config_and_inputs + ) @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_model_classes = ( + (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + ) test_pruning = False test_headmasking = False test_torchscript = False @@ -357,7 +432,9 @@ def __init__( self.num_attention_heads = num_attention_heads self.num_hashes = num_hashes self.num_hidden_layers = len(attn_layers) - self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.num_buckets = ( + tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + ) self.lsh_attn_chunk_length = lsh_attn_chunk_length self.lsh_num_chunks_after = lsh_num_chunks_after self.lsh_num_chunks_before = lsh_num_chunks_before @@ -376,8 +453,12 @@ def __init__( self.hash_seed = hash_seed self.pad_token_id = pad_token_id - self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) - self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length + self.encoder_seq_length = seq_length // lsh_attn_chunk_length + ( + seq_length % lsh_attn_chunk_length != 0 + ) + self.key_length = ( + self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1 + ) * lsh_attn_chunk_length self.chunk_length = lsh_attn_chunk_length def prepare_config_and_inputs(self): @@ -385,7 +466,9 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + input_mask = ids_tensor( + [self.batch_size, self.seq_length], vocab_size=2 + ) config = ReformerConfig( vocab_size=self.vocab_size, @@ -432,7 +515,8 @@ def create_and_check_reformer_model( } # 2 * hidden_size because we use reversible resnet layers self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size] + list(result["sequence_output"].size()), + [self.batch_size, self.seq_length, 2 * self.hidden_size], ) def create_and_check_reformer_with_lm( @@ -441,17 +525,22 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + loss, prediction_scores = model( + input_ids, attention_mask=input_mask, labels=input_ids + ) result = { "loss": loss, "prediction_scores": prediction_scores, } self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + list(result["prediction_scores"].size()), + [self.batch_size, self.seq_length, self.vocab_size], ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_model_with_attn_mask( + self, config, input_ids, input_mask, is_decoder + ): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -466,7 +555,9 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) + embedding.weight = torch.nn.Parameter( + torch.zeros(embedding.weight.shape).to(torch_device) + ) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -475,30 +566,51 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu half_input_ids = input_ids[:, :half_seq_len] # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1) + attn_mask = torch.cat( + [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], + dim=-1, + ) input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + [ + half_input_ids, + ids_tensor((self.batch_size, half_seq_len), self.vocab_size), + ], + dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1 + [ + half_input_ids, + ids_tensor((self.batch_size, half_seq_len), self.vocab_size), + ], + dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][ + :, :half_seq_len ] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[ + 0 + ][:, roll : half_seq_len + roll] - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(output_padded, output_padded_rolled, atol=1e-3) + ) - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + def create_and_check_reformer_layer_dropout_seed( + self, config, input_ids, input_mask, is_decoder + ): config.is_decoder = is_decoder layer = ReformerLayer(config).to(torch_device) layer.train() - shape = (self.batch_size, self.seq_length, config.hidden_size) # Batch x SeqLen x hiddenSize + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize # get random tensors hidden_states = floats_tensor(shape) @@ -506,28 +618,46 @@ def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_ # now the random seeds for attention and feed forward is initialized # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + layer_outputs = layer( + prev_attn_output, hidden_states, attention_mask=input_mask + ) next_attn_output = layer_outputs.attn_output next_hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.attention_seed) attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue(torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3)) + self.parent.assertTrue( + torch.allclose( + prev_attn_output + attn_outputs.hidden_states, + next_attn_output, + atol=1e-3, + ) + ) torch.manual_seed(layer.feed_forward_seed) feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue(torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3)) + self.parent.assertTrue( + torch.allclose( + next_hidden_states, + hidden_states + feed_forward_hidden_states, + atol=1e-3, + ) + ) # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): + def create_and_check_reformer_model_fp16_forward( + self, config, input_ids, input_mask + ): model = ReformerModel(config=config) model.to(torch_device) model.half() model.eval() model(input_ids, attention_mask=input_mask) - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): + def create_and_check_reformer_model_fp16_generate( + self, config, input_ids, input_mask + ): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() @@ -542,7 +672,9 @@ def prepare_config_and_inputs_for_common(self): def setUp(self): self.model_tester = ReformerLSHAttnModelTest.ReformerLSHAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + self.config_tester = ConfigTester( + self, config_class=ReformerConfig, hidden_size=37 + ) def test_config(self): self.config_tester.run_common_tests() @@ -553,8 +685,12 @@ def test_reformer_model(self): def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_model_with_attn_mask( + *config_and_inputs, True + ) + self.model_tester.create_and_check_reformer_model_with_attn_mask( + *config_and_inputs, False + ) def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -562,22 +698,212 @@ def test_reformer_with_lm(self): def test_reformer_layer_training_dropout(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + self.model_tester.create_and_check_reformer_layer_dropout_seed( + *config_and_inputs, True + ) + self.model_tester.create_and_check_reformer_layer_dropout_seed( + *config_and_inputs, False + ) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + self.model_tester.create_and_check_reformer_model_fp16_forward( + *config_and_inputs + ) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_generate(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + self.model_tester.create_and_check_reformer_model_fp16_generate( + *config_and_inputs + ) @require_torch class ReformerIntegrationTests(unittest.TestCase): + """ + These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_add_model`. + """ + + def _get_basic_config_and_input(self): + config = { + "vocab_size": 320, + "attention_head_size": 8, + "hidden_size": 16, + "num_attention_heads": 2, + "num_buckets": 2, + "num_hashes": 4, + "lsh_attn_chunk_length": 4, + "local_attn_chunk_length": 4, + "lsh_num_chunks_before": 1, + "lsh_num_chunks_after": 0, + "local_num_chunks_before": 1, + "local_num_chunks_after": 0, + "chunk_size_lm_head": 0, + "chunk_size_feed_forward": 0, + "feed_forward_size": 32, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "lsh_attention_probs_dropout_prob": 0.0, + "local_attention_probs_dropout_prob": 0.0, + "max_position_embeddings": 16, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "sinusoidal_pos_embds": False, + "axial_pos_embds": True, + "axial_pos_shape": [4, 4], + "axial_pos_embds_dim": [8, 8], + "hash_seed": 0, + "is_decoder": True + } + return config + + def _get_hidden_states(self): + return torch.tensor( + [ + [ + [ + 1.90826353e00, + -1.45999730e00, + -6.20405462e-01, + 1.52503433e00, + -3.64464232e-01, + -8.27359235e-01, + 8.39670803e-01, + 2.44492178e-01, + 4.98332758e-01, + 2.69175139e00, + -7.08081422e-03, + 1.04915401e00, + -1.83476661e00, + 7.67220476e-01, + 2.98580543e-01, + 2.84803992e-02, + ], + [ + -2.66374286e-02, + 4.33497576e-01, + 3.10386309e-01, + 5.46039944e-01, + -2.47292666e-04, + -7.52305019e-01, + 2.39162103e-01, + 7.25216186e-01, + -7.58357372e-01, + 4.20635998e-01, + -4.04739919e-02, + 1.59924145e-01, + 2.05135748e00, + -1.15997978e00, + 5.37166397e-01, + 2.62873606e-01, + ], + [ + 1.85247482e-01, + 7.07046037e-01, + -6.77089715e-01, + -2.24209655e00, + -3.75307980e-02, + -8.59380874e-01, + -2.81027884e00, + 1.01276376e00, + -1.69438001e00, + 4.17574660e-01, + -1.49196962e00, + -1.76483717e00, + -1.94566312e-01, + -1.71183858e00, + 7.72903565e-01, + -1.11557056e00, + ], + [ + 9.46069193e-01, + 1.53417623e-01, + -9.58686996e-01, + 1.18126669e-01, + 1.75967724e00, + 1.62194590e00, + -5.74108159e-01, + 6.79920443e-01, + 5.44028163e-01, + 2.05466114e-01, + -3.63045868e-01, + 2.41865062e-01, + 3.20348382e-01, + -9.05611176e-01, + -1.92690727e-01, + -1.19917547e00, + ], + ] + ], + dtype=torch.float32, + device=torch_device, + ) + + def _get_attn_mask(self): + return torch.tensor([[0, 1, 0, 0]], dtype=torch.long, device=torch_device) + + def _get_input_ids(self): + pass + + def test_lsh_layer_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh"] + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + hf_layer = ReformerLayer(ReformerConfig(**config)) + hf_layer.eval() + reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor([ 1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + pass + + def test_lsh_layer_forward_complex(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh"] + config["num_buckets"] = [2, 4] + attn_mask = self._get_attn_mask() + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + hf_layer = ReformerLayer(ReformerConfig(**config)) + hf_layer.eval() + reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor([1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_layer_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local"] + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + hf_layer = ReformerLayer(ReformerConfig(**config)) + hf_layer.eval() + reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor([ 1.3856, -2.1086, -0.8903, 1.3960, -0.0667], dtype=torch.float, device=torch_device) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_layer_forward_complex(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local"] + attn_mask = self._get_attn_mask() + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + hf_layer = ReformerLayer(ReformerConfig(**config)) + hf_layer.eval() + reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor([ 1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +@require_torch +class ReformerIntegrationTestsDynamic(unittest.TestCase): + # This code has to be used with patrickvonplaten's fork of trax to work def test_lsh_layer(self): config = ReformerConfig(hash_seed=0) shape = (3, 64, config.hidden_size) # Batch x SeqLen x hiddenSize @@ -599,7 +925,9 @@ def test_lsh_layer(self): self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) hf_layer.eval() - hf_attention_all_heads = hf_layer.self_attention(hf_input, attention_mask=torch.tensor(mask))[0] + hf_attention_all_heads = hf_layer.self_attention( + hf_input, attention_mask=torch.tensor(mask) + )[0] hf_output = hf_layer.output(hf_attention_all_heads) self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) @@ -619,10 +947,14 @@ def test_local_layer(self): hf_input = torch.tensor(np_input, dtype=torch.float) config.attn_layers = ["local"] hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_local(trax_weights, hf_layer, config.hidden_size) + self._set_layer_weights_in_torch_local( + trax_weights, hf_layer, config.hidden_size + ) hf_layer.eval() - hf_attention_all_heads = hf_layer.self_attention(hf_input, attention_mask=torch.tensor(mask))[0] + hf_attention_all_heads = hf_layer.self_attention( + hf_input, attention_mask=torch.tensor(mask) + )[0] hf_output = hf_layer.output(hf_attention_all_heads) trax_torch_output = torch.tensor(np.asarray(trax_output)) @@ -636,6 +968,7 @@ def test_reformer_lm_model(self): np_input = np.random.randint(0, config.vocab_size, size=shape) np_input_2 = np.asarray(np_input, np.float32) mask = np.ones_like(np_input, dtype=np.float32) + mask[0, 10:20] = 0 np_zeros = np.zeros((shape[0], 1), dtype=np.int) mode = "train" @@ -648,11 +981,14 @@ def test_reformer_lm_model(self): if mode == "train": trax_model = trax.layers.Serial(trax_model, trax.layers.CrossEntropyLoss()) - input_signature = (trax_ShapeDtype(shape, np.int32), trax_ShapeDtype(shape, np.float32), trax_ShapeDtype(shape, np.float32)) + input_signature = ( + trax_ShapeDtype(shape, np.int32), + trax_ShapeDtype(shape, np.float32), + trax_ShapeDtype(shape, np.float32), + ) trax_weights, trax_state = trax_model.init(input_signature) trax_input = (np_input, np_input_2, mask) torch_trax_weights = trax_weights[0] - else: input_signature = trax_ShapeDtype(shape, np.int32) trax_weights, trax_state = trax_model.init(input_signature) @@ -662,17 +998,29 @@ def test_reformer_lm_model(self): torch_trax_weights = trax_weights if mode != "predict": - hf_input = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1) - hf_labels = torch.cat([torch.tensor(np_zeros), torch.tensor(np_input)], dim=-1) + hf_input = torch.cat( + [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 + ) + attention_mask = torch.tensor(mask) + hf_labels = ( + -100 * (1 - attention_mask) + torch.tensor(np_input) * attention_mask + ).to(dtype=torch.long) else: hf_input = torch.tensor(np_input) hf_model = ReformerModelWithLMHead(config) - self._set_model_weights_in_torch(torch_trax_weights, hf_model, config.hidden_size) + self._set_model_weights_in_torch( + torch_trax_weights, hf_model, config.hidden_size + ) if mode == "train": # uncomment line to fix hf_input_shifting in ReformerWithLMHead hf_model.train() + loss = hf_model(hf_input, attention_mask=attention_mask, labels=hf_labels)[ + 0 + ] + # Trax does not really use attention masks in their layers in this setup. The just + # mask the final loss loss = hf_model(hf_input, labels=hf_labels)[0] else: hf_output = hf_model(hf_input) @@ -690,7 +1038,9 @@ def model_and_loss_call(weights, batch, state): grad_fn = jax.grad(model_and_loss_call, has_aux=True) grads, state = grad_fn(trax_weights, trax_input, trax_state) - all_test_correct = self._set_model_weights_in_torch(grads[0], hf_model, config.hidden_size, set_params=False) + all_test_correct = self._set_model_weights_in_torch( + grads[0], hf_model, config.hidden_size, set_params=False + ) self.assertTrue(all_test_correct) def test_backprop_lm_model(self): @@ -698,7 +1048,9 @@ def test_backprop_lm_model(self): shape = (1, 192) # Batch x SeqLen x ModelDimPerHead input_ids = torch.tensor( - np.random.randint(0, config.vocab_size, size=shape), dtype=torch.long, device=torch_device + np.random.randint(0, config.vocab_size, size=shape), + dtype=torch.long, + device=torch_device, ) model = ReformerModelWithLMHead(config) @@ -708,18 +1060,24 @@ def test_backprop_lm_model(self): # use github old branch to make this test work. Pretrained weights # cannot be loaded into new code anymore def no_test_pretrained_crime_and_punishment_lm_model(self): - hf_model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish") + hf_model = ReformerModelWithLMHead.from_pretrained( + "patrickvonplaten/reformer-crime-and-punish" + ) config = hf_model.config - trax_model_path = "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" + trax_model_path = ( + "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" + ) - shape = (1, 128) + shape = (1, 512) np_input = np.random.randint(0, config.vocab_size, size=shape) hf_input = torch.tensor(np_input) input_signature = trax_ShapeDtype(shape, np.int32) - trax_model = self.load_crime_and_punishment_model(trax_model_path, input_signature) + trax_model = self.load_crime_and_punishment_model( + trax_model_path, input_signature + ) hf_output = hf_model(hf_input) log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) @@ -727,7 +1085,9 @@ def no_test_pretrained_crime_and_punishment_lm_model(self): trax_output = trax_model(np_input) trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - self.assertTrue(torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3)) + self.assertTrue( + torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) + ) def load_lsh_layer(self, config, mode="eval"): gin_config = """ @@ -897,7 +1257,9 @@ def load_reformer_lm_model(self, config, mode="eval"): model = trax.models.ReformerLM(mode=mode) return model - def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode="predict"): + def load_crime_and_punishment_model( + self, trax_model_path, input_signature, mode="predict" + ): gin.parse_config( """ import trax.layers @@ -969,30 +1331,42 @@ def load_crime_and_punishment_model(self, trax_model_path, input_signature, mode def _set_param(self, torch_layer, weight, bias=None, name=None): with torch.no_grad(): - assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) + assert ( + torch_layer.weight.shape == weight.shape + ), "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: - assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) + assert ( + torch_layer.bias.shape == bias.shape + ), "{} layer.bias does not match".format(torch_layer) torch_layer.bias = torch.nn.Parameter(bias) return True def _test_param(self, torch_layer, grad, bias_grad=None, name=""): - assert torch_layer.weight.grad.shape == grad.shape, "{} layer.grad does not match".format(torch_layer) + assert ( + torch_layer.weight.grad.shape == grad.shape + ), "{} layer.grad does not match".format(torch_layer) if torch.allclose(torch_layer.weight.grad, grad, atol=1e-3): print("{}-{} layer.grad is good!".format(name, torch_layer)) else: print("ERROR {}-{} layer.grad is not good!".format(name, torch_layer)) return False if bias_grad is not None: - assert torch_layer.bias.grad.shape == bias_grad.shape, "{} layer.bias does not match".format(torch_layer) + assert ( + torch_layer.bias.grad.shape == bias_grad.shape + ), "{} layer.bias does not match".format(torch_layer) if torch.allclose(torch_layer.bias.grad, bias_grad, atol=1e-3): print("{}-{} layer.grad bias is good!".format(name, torch_layer)) else: - print("ERROR {}-{} layer.grad bias is not good!".format(name, torch_layer)) + print( + "ERROR {}-{} layer.grad bias is not good!".format(name, torch_layer) + ) return False return True - def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size, exec_fn=None): + def _set_layer_weights_in_torch_lsh( + self, weights, torch_layer, hidden_size, exec_fn=None + ): all_test_true = True if exec_fn is None: exec_fn = self._set_param @@ -1002,25 +1376,46 @@ def _set_layer_weights_in_torch_lsh(self, weights, torch_layer, hidden_size, exe np_value = np.asarray(weights[1]) np_dense = np.asarray(weights[2]) - all_test_true = exec_fn( - torch_layer.self_attention.query_key, - torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), - name="attn_query_key" - ) and all_test_true - - all_test_true = exec_fn( - torch_layer.self_attention.value, - torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), - name="attn_value" - ) and all_test_true - - all_test_true = exec_fn( - torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), - name="attn_dense" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_layer.self_attention.query_key, + torch.tensor(np_query_key) + .transpose(1, 2) + .contiguous() + .view(-1, hidden_size), + name="attn_query_key", + ) + and all_test_true + ) + + all_test_true = ( + exec_fn( + torch_layer.self_attention.value, + torch.tensor(np_value) + .transpose(1, 2) + .contiguous() + .view(-1, hidden_size), + name="attn_value", + ) + and all_test_true + ) + + all_test_true = ( + exec_fn( + torch_layer.output.dense, + torch.tensor(np_dense) + .view(-1, hidden_size) + .contiguous() + .transpose(0, 1), + name="attn_dense", + ) + and all_test_true + ) return all_test_true - def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size, exec_fn=None): + def _set_layer_weights_in_torch_local( + self, weights, torch_layer, hidden_size, exec_fn=None + ): all_test_true = True if exec_fn is None: @@ -1032,30 +1427,55 @@ def _set_layer_weights_in_torch_local(self, weights, torch_layer, hidden_size, e np_value = np.asarray(weights[2]) np_dense = np.asarray(weights[3]) - all_test_true = exec_fn( - torch_layer.self_attention.query, - torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), - ) and all_test_true - all_test_true = exec_fn( - torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), - ) and all_test_true - all_test_true = exec_fn( - torch_layer.self_attention.value, - torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), - ) and all_test_true - all_test_true = exec_fn( - torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), - ) and all_test_true + all_test_true = ( + exec_fn( + torch_layer.self_attention.query, + torch.tensor(np_query) + .transpose(1, 2) + .contiguous() + .view(-1, hidden_size), + ) + and all_test_true + ) + all_test_true = ( + exec_fn( + torch_layer.self_attention.key, + torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + and all_test_true + ) + all_test_true = ( + exec_fn( + torch_layer.self_attention.value, + torch.tensor(np_value) + .transpose(1, 2) + .contiguous() + .view(-1, hidden_size), + ) + and all_test_true + ) + all_test_true = ( + exec_fn( + torch_layer.output.dense, + torch.tensor(np_dense) + .view(-1, hidden_size) + .contiguous() + .transpose(0, 1), + ) + and all_test_true + ) return all_test_true - def _set_block_weights_in_torch(self, weights, torch_block, hidden_size, exec_fn=None): + def _set_block_weights_in_torch( + self, weights, torch_block, hidden_size, exec_fn=None + ): all_test_true = True if exec_fn is None: exec_fn = self._set_param # intermediate weighs -# intermediate_weights = weights[2][0][2][2] + # intermediate_weights = weights[2][0][2][2] intermediate_weights = weights[2][0][1][2] # Chunked Feed Forward @@ -1065,49 +1485,78 @@ def _set_block_weights_in_torch(self, weights, torch_block, hidden_size, exec_fn # intermediate out out_dense_weight = np.asarray(intermediate_weights[4][0]) out_dense_bias = np.asarray(intermediate_weights[4][1]) - all_test_true = exec_fn( - torch_block.feed_forward.output.dense, - torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), - torch.tensor(out_dense_bias), - name="res_feed_forward_2" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_block.feed_forward.output.dense, + torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(out_dense_bias), + name="res_feed_forward_2", + ) + and all_test_true + ) # intermediate dense inter_dense_weight = np.asarray(intermediate_weights[1][0]) inter_dense_bias = np.asarray(intermediate_weights[1][1]) - all_test_true = exec_fn( - torch_block.feed_forward.dense.dense, - torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), - torch.tensor(inter_dense_bias), - name="res_feed_forward_1" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_block.feed_forward.dense.dense, + torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(inter_dense_bias), + name="res_feed_forward_1", + ) + and all_test_true + ) # layernorm 2 layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - all_test_true = exec_fn( - torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), name="layer_norm_2" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_block.feed_forward.layer_norm, + torch.tensor(layer_norm_2_weight), + torch.tensor(layer_norm_2_bias), + name="layer_norm_2", + ) + and all_test_true + ) # lsh weights + output attn_weights = weights[0][1] if len(attn_weights) < 4: - all_test_true = self._set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn) and all_test_true + all_test_true = ( + self._set_layer_weights_in_torch_lsh( + attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn + ) + and all_test_true + ) else: - all_test_true = self._set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn) and all_test_true + all_test_true = ( + self._set_layer_weights_in_torch_local( + attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn + ) + and all_test_true + ) # layernorm 1 layer_norm_1 = weights[0][0][0] layer_norm_1_weight = np.asarray(layer_norm_1[0]) layer_norm_1_bias = np.asarray(layer_norm_1[1]) - all_test_true = exec_fn( - torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), - name="layer_norm" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_block.attention.layer_norm, + torch.tensor(layer_norm_1_weight), + torch.tensor(layer_norm_1_bias), + name="layer_norm", + ) + and all_test_true + ) return all_test_true - def _set_model_weights_in_torch(self, weights, torch_model, hidden_size, set_params=True): + def _set_model_weights_in_torch( + self, weights, torch_model, hidden_size, set_params=True + ): # reformer model torch_model_reformer = torch_model.reformer @@ -1123,51 +1572,78 @@ def _set_model_weights_in_torch(self, weights, torch_model, hidden_size, set_par # output embeddings output_embed_weights = np.asarray(out_weights[2][0]) output_embed_bias = np.asarray(out_weights[2][1]) - all_test_true = exec_fn( - torch_model.lm_head.decoder, - torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), - torch.tensor(output_embed_bias), - name="lm_head" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_model.lm_head.decoder, + torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), + torch.tensor(output_embed_bias), + name="lm_head", + ) + and all_test_true + ) # output layer norm layer_norm_out_weight = np.asarray(out_weights[0][0]) layer_norm_out_bias = np.asarray(out_weights[0][1]) - all_test_true = exec_fn( - torch_model_reformer.encoder.layer_norm, - torch.tensor(layer_norm_out_weight), - torch.tensor(layer_norm_out_bias), - name="last layer norm" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_model_reformer.encoder.layer_norm, + torch.tensor(layer_norm_out_weight), + torch.tensor(layer_norm_out_bias), + name="last layer norm", + ) + and all_test_true + ) trax_layer_weights = weights[5] assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( trax_layer_weights ), "HF and trax model do not have the same number of layers" for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers[::-1]): - block_weights = trax_layer_weights[:-1][::-1][4 * layer_idx : 4 * (layer_idx + 1)][::-1] - all_test_true = self._set_block_weights_in_torch(block_weights, layer, hidden_size, exec_fn=exec_fn) and all_test_true + block_weights = trax_layer_weights[:-1][::-1][ + 4 * layer_idx : 4 * (layer_idx + 1) + ][::-1] + all_test_true = ( + self._set_block_weights_in_torch( + block_weights, layer, hidden_size, exec_fn=exec_fn + ) + and all_test_true + ) if isinstance(weights[3], tuple): position_embeddings = torch_model_reformer.embeddings.position_embeddings for emb_idx in range(len(position_embeddings.weights)): emb_weights = np.asarray(weights[3][emb_idx][0]) - assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( - position_embeddings[emb_idx] - ) + assert ( + position_embeddings.weights[emb_idx].shape == emb_weights.shape + ), "{} emb does not match".format(position_embeddings[emb_idx]) if set_params is True: - position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + position_embeddings.weights[emb_idx] = torch.nn.Parameter( + torch.tensor(emb_weights) + ) else: - if torch.allclose(position_embeddings.weights[emb_idx].grad, torch.tensor(emb_weights), atol=1e-3): + if torch.allclose( + position_embeddings.weights[emb_idx].grad, + torch.tensor(emb_weights), + atol=1e-3, + ): print("{} layer.grad is good!".format(position_embeddings)) else: - print("ERROR: {}-{} layer.grad is not good".format(position_embeddings, "axs_pos_embeds")) + print( + "ERROR: {}-{} layer.grad is not good".format( + position_embeddings, "axs_pos_embeds" + ) + ) # word embeds word_embeddings = np.asarray(weights[1]) - all_test_true = exec_fn( - torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), - name="word_embed" - ) and all_test_true + all_test_true = ( + exec_fn( + torch_model_reformer.embeddings.word_embeddings, + torch.tensor(word_embeddings), + name="word_embed", + ) + and all_test_true + ) return all_test_true From ea1126e8e9479d100c304d40f44fc3305fc59d51 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 11:51:04 +0200 Subject: [PATCH 108/162] finalize tests --- src/transformers/modeling_reformer.py | 77 +++-- tests/test_modeling_reformer.py | 441 +++++++++++++++++++++++--- 2 files changed, 456 insertions(+), 62 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 9321d8236001..e79cd5e28854 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -65,6 +65,7 @@ def mish(x): LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) +ReformerBackwardOutput = namedtuple("ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]) def create_sinusoidal_embeddings(n_pos, dim, out): @@ -358,6 +359,7 @@ def __init__(self, config): self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size + # projection matrices self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -437,7 +439,7 @@ def forward( del query_key_vectors, key_vectors, value_vectors # calculate total concatenad chunks to split gradients correctly - out_vectors, logits = UndoSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes) + out_vectors, logits = GatherSorted.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes) if num_hashes > 1: out_vectors = self._split_dim_by( @@ -491,7 +493,7 @@ def _hash_vectors(self, vectors, num_hashes): # TODO: delete later when integration tests are ok if self.hash_seed is not None: rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) - torch.manual_seed(self.hash_seed) + np.random.seed(self.hash_seed) random_rotations = torch.tensor( np.random.normal(size=rotations_shape), dtype=vectors.dtype, device=vectors.device, ) @@ -508,6 +510,17 @@ def _hash_vectors(self, vectors, num_hashes): # rotation is used batches rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) +# if self.hash_seed is not None: + # for determinism +# torch.manual_seed(self.hash_seed) +# +# rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) + # create a random self.attention_head_size x num_hashes x num_buckets/2 +# random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) +# + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 +# rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) buckets = torch.argmax(rotated_vectors, dim=-1) @@ -688,7 +701,7 @@ def _gather_by_expansion(self, vectors, idxs, num_hashes): return torch.gather(vectors, 2, expanded_idxs) -class UndoSort(Function): +class GatherSorted(Function): @staticmethod def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): # save sorted_bucket_idx for backprop @@ -751,6 +764,7 @@ def __init__(self, config): self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size + # projection matrices self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -1021,6 +1035,9 @@ def forward( ): with torch.no_grad(): + # every forward pass we sample a different seed + # for dropout and save seed for forward fn in backward + # to have correct dropout self._init_attention_seed() attn_outputs = self.attention( hidden_states=hidden_states, @@ -1038,8 +1055,11 @@ def forward( # free memory del prev_attn_output - # Y_2 = X_2 + g(Y_1) + # every forward pass we sample a different seed + # for dropout and save seed for forward fn in backward + # to have correct dropout self._init_feed_forward_seed() + # Y_2 = X_2 + g(Y_1) hidden_states = hidden_states + self.feed_forward(attn_output) return ReformerOutput( @@ -1064,9 +1084,10 @@ def backward_pass( with torch.enable_grad(): next_attn_output.requires_grad = True + # set seed to have correct dropout + torch.manual_seed(self.feed_forward_seed) # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # g(Y_1) - torch.manual_seed(self.feed_forward_seed) res_hidden_states = self.feed_forward(next_attn_output) res_hidden_states.backward(grad_hidden_states, retain_graph=True) @@ -1080,9 +1101,11 @@ def backward_pass( with torch.enable_grad(): hidden_states.requires_grad = True + + # set seed to have correct dropout + torch.manual_seed(self.attention_seed) # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention - torch.manual_seed(self.attention_seed) output = self.attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, ).hidden_states @@ -1097,7 +1120,12 @@ def backward_pass( hidden_states.grad = None hidden_states = hidden_states.detach() - return attn_output, hidden_states, grad_attn_output, grad_hidden_states + return ReformerBackwardOutput( + attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) class _ReversibleFunction(Function): @@ -1159,7 +1187,18 @@ def backward(ctx, grad_hidden_states): grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) # retrieve params from ctx for backward - attn_output, hidden_states, = ctx.saved_tensors + attn_output, hidden_states = ctx.saved_tensors + + # create tuple + output = ReformerBackwardOutput(attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) + + # free memory + del grad_attn_output, grad_hidden_states, attn_output, hidden_states + layers = ctx.layers all_buckets = ctx.all_buckets head_mask = ctx.head_mask @@ -1171,22 +1210,18 @@ def backward(ctx, grad_hidden_states): all_buckets = all_buckets[:-1] # backprop - (attn_output, hidden_states, grad_attn_output, grad_hidden_states,) = layer.backward_pass( - next_attn_output=attn_output, - hidden_states=hidden_states, - grad_attn_output=grad_attn_output, - grad_hidden_states=grad_hidden_states, + output = layer.backward_pass( + next_attn_output=output.attn_output, + hidden_states=output.hidden_states, + grad_attn_output=output.grad_attn_output, + grad_hidden_states=output.grad_hidden_states, head_mask=head_mask[len(layers) - idx - 1], attention_mask=attention_mask, buckets=buckets, ) assert all_buckets == (), "buckets have to be empty after backpropagation" - - # free memory - del attn_output, hidden_states - - grad_hidden_states = torch.cat([grad_attn_output, grad_hidden_states], dim=-1) + grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1) # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args @@ -1511,8 +1546,10 @@ def forward( if labels is not None: # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() +# shift_logits = logits[..., :-1, :].contiguous() +# shift_labels = labels[..., 1:].contiguous() + shift_logits = logits.contiguous() + shift_labels = labels.contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 70f809f6a479..c80117314363 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -15,6 +15,13 @@ import unittest +from transformers import is_torch_available + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from .utils import CACHE_DIR, require_torch, slow, torch_device + + import gin import numpy as np @@ -24,18 +31,14 @@ from transformers import is_torch_available from trax.shapes import ShapeDtype as trax_ShapeDtype -from .test_configuration_common import ConfigTester -from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor -from .utils import CACHE_DIR, require_torch, slow, torch_device - if is_torch_available(): from transformers import ( ReformerConfig, - ReformerAttention, ReformerModel, ReformerModelWithLMHead, ReformerLayer, + ReformerAttention ) # from transformers.modeling_reformer import REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP @@ -44,7 +47,6 @@ @require_torch class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () ) @@ -182,6 +184,15 @@ def create_and_check_reformer_model( [self.batch_size, self.seq_length, 2 * self.hidden_size], ) + def create_and_check_reformer_model_with_lm_backward( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] + loss.backward() + def create_and_check_reformer_with_lm( self, config, input_ids, input_mask, ): @@ -346,6 +357,12 @@ def test_reformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def test_reformer_lm_model_backward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_lm_backward( + *config_and_inputs + ) + def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_with_attn_mask( @@ -378,7 +395,6 @@ def test_reformer_model_fp16_forward(self): @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () ) @@ -519,6 +535,15 @@ def create_and_check_reformer_model( [self.batch_size, self.seq_length, 2 * self.hidden_size], ) + def create_and_check_reformer_model_with_lm_backward( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] + loss.backward() + def create_and_check_reformer_with_lm( self, config, input_ids, input_mask, ): @@ -683,6 +708,12 @@ def test_reformer_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def test_reformer_lm_model_backward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_lm_backward( + *config_and_inputs + ) + def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_with_attn_mask( @@ -723,7 +754,7 @@ def test_reformer_model_fp16_generate(self): @require_torch class ReformerIntegrationTests(unittest.TestCase): """ - These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_add_model`. + These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`. """ def _get_basic_config_and_input(self): @@ -747,16 +778,16 @@ def _get_basic_config_and_input(self): "hidden_dropout_prob": 0.0, "lsh_attention_probs_dropout_prob": 0.0, "local_attention_probs_dropout_prob": 0.0, - "max_position_embeddings": 16, + "max_position_embeddings": 32, "initializer_range": 0.02, "axial_norm_std": 1.0, "layer_norm_eps": 1e-12, "sinusoidal_pos_embds": False, "axial_pos_embds": True, - "axial_pos_shape": [4, 4], + "axial_pos_shape": [4, 8], "axial_pos_embds_dim": [8, 8], "hash_seed": 0, - "is_decoder": True + "is_decoder": True, } return config @@ -845,19 +876,176 @@ def _get_hidden_states(self): def _get_attn_mask(self): return torch.tensor([[0, 1, 0, 0]], dtype=torch.long, device=torch_device) - def _get_input_ids(self): - pass + def _get_input_ids_and_mask(self): + mask = torch.tensor( + [ + [ + 1, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 1, + 0, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 0, + ], + ], + dtype=torch.long, + device=torch_device, + ) + + input_ids = torch.tensor( + [ + [ + 89, + 279, + 286, + 84, + 194, + 316, + 182, + 28, + 283, + 37, + 169, + 7, + 253, + 267, + 107, + 250, + 44, + 7, + 102, + 62, + 3, + 243, + 171, + 265, + 302, + 48, + 164, + 264, + 148, + 229, + 280, + 150, + ], + [ + 9, + 192, + 66, + 112, + 163, + 83, + 135, + 70, + 224, + 96, + 31, + 80, + 196, + 80, + 63, + 22, + 85, + 100, + 47, + 283, + 0, + 163, + 126, + 143, + 195, + 82, + 53, + 82, + 18, + 27, + 182, + 52, + ], + ], + dtype=torch.long, + device=torch_device, + ) + + return input_ids, mask def test_lsh_layer_forward(self): config = self._get_basic_config_and_input() config["attn_layers"] = ["lsh"] + config["is_decoder"] = False hidden_states = self._get_hidden_states() torch.manual_seed(0) - hf_layer = ReformerLayer(ReformerConfig(**config)) - hf_layer.eval() - reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states) + layer = ReformerLayer(ReformerConfig(**config)) + layer.eval() + reformer_output = layer( + prev_attn_output=hidden_states, hidden_states=hidden_states + ) output_slice = reformer_output.hidden_states[0, 0, :5] - expected_output_slice = torch.tensor([ 1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device) + expected_output_slice = torch.tensor( + [1.6879, -1.3083, -0.4708, 1.3555, -0.6292], + dtype=torch.float, + device=torch_device, + ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) pass @@ -868,23 +1056,38 @@ def test_lsh_layer_forward_complex(self): attn_mask = self._get_attn_mask() hidden_states = self._get_hidden_states() torch.manual_seed(0) - hf_layer = ReformerLayer(ReformerConfig(**config)) - hf_layer.eval() - reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask) + layer = ReformerLayer(ReformerConfig(**config)) + layer.eval() + reformer_output = layer( + prev_attn_output=hidden_states, + hidden_states=hidden_states, + attention_mask=attn_mask, + ) output_slice = reformer_output.hidden_states[0, 0, :5] - expected_output_slice = torch.tensor([1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device) + expected_output_slice = torch.tensor( + [1.6439, -1.2306, -0.5108, 1.3006, -0.6537], + dtype=torch.float, + device=torch_device, + ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) def test_local_layer_forward(self): config = self._get_basic_config_and_input() config["attn_layers"] = ["local"] + config["is_decoder"] = False hidden_states = self._get_hidden_states() torch.manual_seed(0) - hf_layer = ReformerLayer(ReformerConfig(**config)) - hf_layer.eval() - reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states) + layer = ReformerLayer(ReformerConfig(**config)) + layer.eval() + reformer_output = layer( + prev_attn_output=hidden_states, hidden_states=hidden_states + ) output_slice = reformer_output.hidden_states[0, 0, :5] - expected_output_slice = torch.tensor([ 1.3856, -2.1086, -0.8903, 1.3960, -0.0667], dtype=torch.float, device=torch_device) + expected_output_slice = torch.tensor( + [1.4212, -2.0576, -0.9688, 1.4599, -0.1344], + dtype=torch.float, + device=torch_device, + ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) def test_local_layer_forward_complex(self): @@ -893,13 +1096,180 @@ def test_local_layer_forward_complex(self): attn_mask = self._get_attn_mask() hidden_states = self._get_hidden_states() torch.manual_seed(0) - hf_layer = ReformerLayer(ReformerConfig(**config)) - hf_layer.eval() - reformer_output = hf_layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask) + layer = ReformerLayer(ReformerConfig(**config)) + layer.eval() + reformer_output = layer( + prev_attn_output=hidden_states, + hidden_states=hidden_states, + attention_mask=attn_mask, + ) output_slice = reformer_output.hidden_states[0, 0, :5] - expected_output_slice = torch.tensor([ 1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device) + expected_output_slice = torch.tensor( + [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_lsh_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"] + config["num_buckets"] = [2, 4] + torch.manual_seed(0) + model = ReformerModel(ReformerConfig(**config)) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], + dtype=torch.float, + device=torch_device, + ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + def test_local_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "local", "local", "local"] + torch.manual_seed(0) + model = ReformerModel(ReformerConfig(**config)) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_lm_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "lsh", "local", "lsh", "local", "lsh"] + config["num_buckets"] = [2, 4] + config["is_decoder"] = False + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[1, -1, :5] + expected_output_slice = torch.tensor( + [0.0324, -0.0121, 0.0615, 0.0031, -0.0297], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_lm_model_grad(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "local", "local", "local"] + config["hidden_dropout_prob"] = 0.0 + config["local_attention_probs_dropout_prob"] = 0.0 + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)) + model.train() + model.zero_grad() + input_ids, _ = self._get_input_ids_and_mask() + loss = model(input_ids=input_ids, labels=input_ids)[0] + loss.backward() + # check last grads to cover all proable errors + grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + expected_grad_slice_word = torch.tensor( + [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], + dtype=torch.float, + device=torch_device, + ) + grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[ + 0 + ][ + 1, 0, -5: + ] + expected_grad_slice_pos_fac_1 = torch.tensor( + [0.0037, -1.3793, -1.0231, -1.5230, -2.5306], + dtype=torch.float, + device=torch_device, + ) + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[ + 1 + ][ + 0, 1, :5 + ] + expected_grad_slice_pos_fac_2 = torch.tensor( + [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue( + torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3) + ) + self.assertTrue( + torch.allclose( + grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3 + ) + ) + self.assertTrue( + torch.allclose( + grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3 + ) + ) + + def test_lsh_lm_model_grad(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"] + config["hidden_dropout_prob"] = 0.0 + config["lsh_attention_probs_dropout_prob"] = 0.0 + config["num_buckets"] = [2, 4] + config["num_hashes"] = 6 + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)) + model.train() + model.zero_grad() + input_ids, _ = self._get_input_ids_and_mask() + loss = model(input_ids=input_ids, labels=input_ids)[0] + loss.backward() + # check last grads to cover all proable errors + grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + expected_grad_slice_word = torch.tensor( + [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], + dtype=torch.float, + device=torch_device, + ) + grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[ + 0 + ][ + 1, 0, -5: + ] + expected_grad_slice_pos_fac_1 = torch.tensor( + [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], + dtype=torch.float, + device=torch_device, + ) + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[ + 1 + ][ + 0, 1, :5 + ] + expected_grad_slice_pos_fac_2 = torch.tensor( + [0.4626, -0.0231, -0.0172, 0.1081, 0.3805], + dtype=torch.float, + device=torch_device, + ) + self.assertTrue( + torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3) + ) + self.assertTrue( + torch.allclose( + grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3 + ) + ) + self.assertTrue( + torch.allclose( + grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3 + ) + ) + @require_torch class ReformerIntegrationTestsDynamic(unittest.TestCase): @@ -1016,9 +1386,6 @@ def test_reformer_lm_model(self): if mode == "train": # uncomment line to fix hf_input_shifting in ReformerWithLMHead hf_model.train() - loss = hf_model(hf_input, attention_mask=attention_mask, labels=hf_labels)[ - 0 - ] # Trax does not really use attention masks in their layers in this setup. The just # mask the final loss loss = hf_model(hf_input, labels=hf_labels)[0] @@ -1092,7 +1459,6 @@ def no_test_pretrained_crime_and_punishment_lm_model(self): def load_lsh_layer(self, config, mode="eval"): gin_config = """ import trax.layers - # Parameters for LSHSelfAttention: # ============================================================================== LSHSelfAttention.n_heads = {} @@ -1129,7 +1495,6 @@ def load_lsh_layer(self, config, mode="eval"): def load_local_layer(self, config, mask=False, mode="eval"): gin_config = """ import trax.layers - # Parameters for SelfAttention: # ============================================================================== SelfAttention.n_heads = {} @@ -1183,7 +1548,6 @@ def load_reformer_lm_model(self, config, mode="eval"): gin_config = """ import trax.layers import trax.models - # Parameters for LSHSelfAttention: # ============================================================================== LSHSelfAttention.chunk_len = {} @@ -1196,7 +1560,6 @@ def load_reformer_lm_model(self, config, mode="eval"): LSHSelfAttention.lsh_seed = {} LSHSelfAttention.causal= {} LSHSelfAttention.use_reference_code = True - # Parameters for SelfAttention: # ============================================================================== SelfAttention.chunk_len = {} @@ -1205,7 +1568,6 @@ def load_reformer_lm_model(self, config, mode="eval"): SelfAttention.causal= {} SelfAttention.use_reference_code = True SelfAttention.share_qk = False - # Parameters for ReformerLM: # ============================================================================== ReformerLM.vocab_size = {} @@ -1222,7 +1584,6 @@ def load_reformer_lm_model(self, config, mode="eval"): ReformerLM.ff_activation = @trax.layers.{} ReformerLM.attention_type = @trax.layers.{} ReformerLM.dropout = 0.0 - ReformerLM.n_chunks = 0 ReformerLM.ff_use_sru = 0 """.format( @@ -1267,7 +1628,6 @@ def load_crime_and_punishment_model( import trax.optimizers import trax.supervised.inputs import trax.supervised.trainer_lib - # Parameters that will vary between experiments: # ============================================================================== train.model = @trax.models.ReformerLM @@ -1286,14 +1646,12 @@ def load_crime_and_punishment_model( attn_kv = 64 dropout = 0.05 n_tokens = 524288 - # Parameters for SelfAttention: # ============================================================================== SelfAttention.chunk_len = 64 SelfAttention.n_chunks_before = 1 SelfAttention.n_parallel_heads = 1 SelfAttention.share_qk = False - # Parameters for LSHSelfAttention: # ============================================================================== LSHSelfAttention.chunk_len = 64 @@ -1305,7 +1663,6 @@ def load_crime_and_punishment_model( LSHSelfAttention.predict_drop_len = 32 # different from original to make code equal LSHSelfAttention.predict_mem_len = 64 # different from original to make code equal LSHSelfAttention.lsh_seed = 0 - # Parameters for ReformerLM: # ============================================================================== ReformerLM.attention_type = %attn_type From d4bc3c6237035737baf025df1848325f46d44185 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 12:11:53 +0200 Subject: [PATCH 109/162] add more tests --- src/transformers/__init__.py | 2 +- src/transformers/modeling_reformer.py | 40 +- tests/test_modeling_reformer.py | 769 +------------------------- 3 files changed, 33 insertions(+), 778 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6fcad4e06db0..062c78890423 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -323,7 +323,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ) - from .modeling_reformer import ReformerAttention, ReformerLayer, ReformerModel, ReformerModelWithLMHead + from .modeling_reformer import ReformerAttention, ReformerLayer, ReformerModel, ReformerModelWithLMHead, REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP # Optimization from .optimization import ( diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index e79cd5e28854..486be046a3c4 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -490,36 +490,16 @@ def _hash_vectors(self, vectors, num_hashes): # remove gradient vectors = vectors.detach() - # TODO: delete later when integration tests are ok if self.hash_seed is not None: - rotations_shape = (vectors.shape[-1], num_hashes, rotation_size // 2) - np.random.seed(self.hash_seed) - random_rotations = torch.tensor( - np.random.normal(size=rotations_shape), dtype=vectors.dtype, device=vectors.device, - ) - rotated_vectors = torch.einsum("bmtd,dhr->bmhtr", vectors, random_rotations) - else: - rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) - # create a random self.attention_head_size x num_hashes x num_buckets/2 - random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) - - # rotated_vectors has dim: - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 - # TODO: IMPORTANT: At the moment we use the same random rotation over all batches - # -> is that bad? It seems like in original reformer a different random - # rotation is used batches - rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) - -# if self.hash_seed is not None: # for determinism -# torch.manual_seed(self.hash_seed) -# -# rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) + torch.manual_seed(self.hash_seed) + + rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x num_buckets/2 -# random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) -# + random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 -# rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) + rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) @@ -1546,10 +1526,10 @@ def forward( if labels is not None: # Shift so that tokens < n predict n -# shift_logits = logits[..., :-1, :].contiguous() -# shift_labels = labels[..., 1:].contiguous() - shift_logits = logits.contiguous() - shift_labels = labels.contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() +# shift_logits = logits.contiguous() +# shift_labels = labels.contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c80117314363..b8a06680db0e 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -21,27 +21,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from .utils import CACHE_DIR, require_torch, slow, torch_device - -import gin -import numpy as np - -# trax imports - to be deleted later -import trax -import jax -from transformers import is_torch_available -from trax.shapes import ShapeDtype as trax_ShapeDtype - - if is_torch_available(): from transformers import ( ReformerConfig, ReformerModel, ReformerModelWithLMHead, + ReformerTokenizer, ReformerLayer, - ReformerAttention + REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP ) - - # from transformers.modeling_reformer import REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP import torch @@ -750,6 +738,12 @@ def test_reformer_model_fp16_generate(self): *config_and_inputs ) + @slow + def test_model_from_pretrained(self): + for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + model = ReformerModelWithLMHead.from_pretrained(model_name, cache_dir=CACHE_DIR) + self.assertIsNotNone(model) + @require_torch class ReformerIntegrationTests(unittest.TestCase): @@ -1173,7 +1167,10 @@ def test_local_lm_model_grad(self): model.zero_grad() input_ids, _ = self._get_input_ids_and_mask() loss = model(input_ids=input_ids, labels=input_ids)[0] + + self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3)) loss.backward() + # check last grads to cover all proable errors grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] expected_grad_slice_word = torch.tensor( @@ -1228,6 +1225,8 @@ def test_lsh_lm_model_grad(self): model.zero_grad() input_ids, _ = self._get_input_ids_and_mask() loss = model(input_ids=input_ids, labels=input_ids)[0] + + self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3)) loss.backward() # check last grads to cover all proable errors grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] @@ -1270,737 +1269,13 @@ def test_lsh_lm_model_grad(self): ) ) + @slow + def test_pretrained_generate_crime_and_punish(self): + model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR) + tokenizer = ReformerTokenizer.from_pretrained("patrickvonplaten/reformer-crime-and-punish") + model.eval() -@require_torch -class ReformerIntegrationTestsDynamic(unittest.TestCase): - # This code has to be used with patrickvonplaten's fork of trax to work - def test_lsh_layer(self): - config = ReformerConfig(hash_seed=0) - shape = (3, 64, config.hidden_size) # Batch x SeqLen x hiddenSize - np_input = np.random.rand(*shape) - - trax_layer = self.load_lsh_layer(config) - input_signature = trax_ShapeDtype(shape, np.float32) - trax_weights, trax_state = trax_layer.init(input_signature) - - mask = np.ones(shape[:-1], dtype=np.int32) - - trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) - - trax_torch_output = torch.tensor(np.asarray(trax_output)) - - hf_input = torch.tensor(np_input, dtype=torch.float) - config.attn_layers = ["lsh"] - hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_lsh(trax_weights, hf_layer, config.hidden_size) - hf_layer.eval() - - hf_attention_all_heads = hf_layer.self_attention( - hf_input, attention_mask=torch.tensor(mask) - )[0] - hf_output = hf_layer.output(hf_attention_all_heads) - - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - - def test_local_layer(self): - config = ReformerConfig(hash_seed=0) - shape = (1, 64, config.hidden_size) # Batch x SeqLen x hiddenSize - np_input = np.random.rand(*shape) - - trax_layer = self.load_local_layer(config) - input_signature = trax_ShapeDtype(shape, np.float32) - trax_weights, trax_state = trax_layer.init(input_signature) - mask = np.ones(shape[:-1], dtype=np.int32) - - trax_output = trax_layer(np_input, weights=trax_weights, state=trax_state) - - hf_input = torch.tensor(np_input, dtype=torch.float) - config.attn_layers = ["local"] - hf_layer = ReformerAttention(config) - self._set_layer_weights_in_torch_local( - trax_weights, hf_layer, config.hidden_size - ) - hf_layer.eval() - - hf_attention_all_heads = hf_layer.self_attention( - hf_input, attention_mask=torch.tensor(mask) - )[0] - hf_output = hf_layer.output(hf_attention_all_heads) - - trax_torch_output = torch.tensor(np.asarray(trax_output)) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3)) - - def test_reformer_lm_model(self): - config = ReformerConfig(axial_pos_embds=True, hash_seed=0, is_decoder=True) - - shape = (1, 512) # Batch x SeqLen x ModelDimPerHead - - np_input = np.random.randint(0, config.vocab_size, size=shape) - np_input_2 = np.asarray(np_input, np.float32) - mask = np.ones_like(np_input, dtype=np.float32) - mask[0, 10:20] = 0 - np_zeros = np.zeros((shape[0], 1), dtype=np.int) - - mode = "train" - - trax_model = self.load_reformer_lm_model(config, mode=mode) - - assert ( - config.is_decoder is True - ), "trax can only test casaul mask for ReformerLM. Use tests for layers to test non-casaul mask" - - if mode == "train": - trax_model = trax.layers.Serial(trax_model, trax.layers.CrossEntropyLoss()) - input_signature = ( - trax_ShapeDtype(shape, np.int32), - trax_ShapeDtype(shape, np.float32), - trax_ShapeDtype(shape, np.float32), - ) - trax_weights, trax_state = trax_model.init(input_signature) - trax_input = (np_input, np_input_2, mask) - torch_trax_weights = trax_weights[0] - else: - input_signature = trax_ShapeDtype(shape, np.int32) - trax_weights, trax_state = trax_model.init(input_signature) - trax_input = np_input - trax_output = trax_model(trax_input, weights=trax_weights, state=trax_state) - trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - torch_trax_weights = trax_weights - - if mode != "predict": - hf_input = torch.cat( - [torch.tensor(np_zeros), torch.tensor(np_input[:, :-1])], dim=-1 - ) - attention_mask = torch.tensor(mask) - hf_labels = ( - -100 * (1 - attention_mask) + torch.tensor(np_input) * attention_mask - ).to(dtype=torch.long) - else: - hf_input = torch.tensor(np_input) - - hf_model = ReformerModelWithLMHead(config) - self._set_model_weights_in_torch( - torch_trax_weights, hf_model, config.hidden_size - ) - - if mode == "train": - # uncomment line to fix hf_input_shifting in ReformerWithLMHead - hf_model.train() - # Trax does not really use attention masks in their layers in this setup. The just - # mask the final loss - loss = hf_model(hf_input, labels=hf_labels)[0] - else: - hf_output = hf_model(hf_input) - hf_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-4)) - - if mode == "train": - hf_model.zero_grad() - loss.backward() - - def model_and_loss_call(weights, batch, state): - res = trax_model(batch, weights=weights, state=state) - return res, trax_model.state - - grad_fn = jax.grad(model_and_loss_call, has_aux=True) - grads, state = grad_fn(trax_weights, trax_input, trax_state) - - all_test_correct = self._set_model_weights_in_torch( - grads[0], hf_model, config.hidden_size, set_params=False - ) - self.assertTrue(all_test_correct) - - def test_backprop_lm_model(self): - config = ReformerConfig() - - shape = (1, 192) # Batch x SeqLen x ModelDimPerHead - input_ids = torch.tensor( - np.random.randint(0, config.vocab_size, size=shape), - dtype=torch.long, - device=torch_device, - ) - - model = ReformerModelWithLMHead(config) - loss = model(input_ids, labels=input_ids)[0] - loss.backward() - - # use github old branch to make this test work. Pretrained weights - # cannot be loaded into new code anymore - def no_test_pretrained_crime_and_punishment_lm_model(self): - hf_model = ReformerModelWithLMHead.from_pretrained( - "patrickvonplaten/reformer-crime-and-punish" - ) - config = hf_model.config - - trax_model_path = ( - "/home/patrick/hugging_face/models/trained_reformer_colab/model.pkl" - ) - - shape = (1, 512) - np_input = np.random.randint(0, config.vocab_size, size=shape) - - hf_input = torch.tensor(np_input) - - input_signature = trax_ShapeDtype(shape, np.int32) - trax_model = self.load_crime_and_punishment_model( - trax_model_path, input_signature - ) - - hf_output = hf_model(hf_input) - log_softmax_output = torch.nn.functional.log_softmax(hf_output[0], dim=-1) - - trax_output = trax_model(np_input) - trax_torch_output = torch.tensor(np.asarray(trax_output[0])) - - self.assertTrue( - torch.allclose(log_softmax_output, trax_torch_output, atol=1e-3) - ) - - def load_lsh_layer(self, config, mode="eval"): - gin_config = """ - import trax.layers - # Parameters for LSHSelfAttention: - # ============================================================================== - LSHSelfAttention.n_heads = {} - LSHSelfAttention.d_qk = {} - LSHSelfAttention.d_v = {} - LSHSelfAttention.chunk_len = {} - LSHSelfAttention.n_chunks_before = {} - LSHSelfAttention.n_chunks_after = {} - LSHSelfAttention.n_hashes = {} - LSHSelfAttention.n_buckets = {} - LSHSelfAttention.attention_dropout = {} - LSHSelfAttention.output_dropout = {} - LSHSelfAttention.lsh_seed = {} - LSHSelfAttention.causal= {} - LSHSelfAttention.use_reference_code = True - """.format( - config.num_attention_heads, - config.attention_head_size, - config.attention_head_size, - config.lsh_attn_chunk_length, - config.lsh_num_chunks_before, - config.lsh_num_chunks_after, - config.num_hashes, - config.num_buckets, - config.lsh_attention_probs_dropout_prob, - config.hidden_dropout_prob, - config.hash_seed, - config.is_decoder, - ) - gin.parse_config(gin_config) - layer = trax.layers.LSHSelfAttention(mode=mode) - return layer - - def load_local_layer(self, config, mask=False, mode="eval"): - gin_config = """ - import trax.layers - # Parameters for SelfAttention: - # ============================================================================== - SelfAttention.n_heads = {} - SelfAttention.d_qk = {} - SelfAttention.d_v = {} - SelfAttention.chunk_len = {} - SelfAttention.n_chunks_before = {} - SelfAttention.n_chunks_after = {} - SelfAttention.attention_dropout = {} - SelfAttention.output_dropout = {} - SelfAttention.causal = {} - SelfAttention.masked= {} - SelfAttention.use_reference_code = True - """.format( - config.num_attention_heads, - config.attention_head_size, - config.attention_head_size, - config.local_attn_chunk_length, - config.local_num_chunks_before, - config.local_num_chunks_after, - config.local_attention_probs_dropout_prob, - config.hidden_dropout_prob, - config.is_decoder, - mask, - ) - gin.parse_config(gin_config) - layer = trax.layers.SelfAttention(mode=mode) - return layer - - def load_reformer_lm_model(self, config, mode="eval"): - if config.hidden_act == "gelu": - hidden_act = "Gelu" - elif config.hidden_act == "relu": - hidden_act = "Relu" - else: - raise ValueError() - attn_type = config.attn_layers[0] - if attn_type == "lsh": - attn_type = "LSHSelfAttention" - elif attn_type == "local": - attn_type = "SelfAttention" - else: - raise ValueError() - if config.sinusoidal_pos_embds is True: - axial_pos_shape = () - d_axial_pos_embs = None - else: - axial_pos_shape = config.axial_pos_shape - d_axial_pos_embs = config.axial_pos_embds_dim - - gin_config = """ - import trax.layers - import trax.models - # Parameters for LSHSelfAttention: - # ============================================================================== - LSHSelfAttention.chunk_len = {} - LSHSelfAttention.predict_mem_len = {} - LSHSelfAttention.predict_drop_len = {} - LSHSelfAttention.n_chunks_before = {} - LSHSelfAttention.n_chunks_after = {} - LSHSelfAttention.n_hashes = {} - LSHSelfAttention.n_buckets = {} - LSHSelfAttention.lsh_seed = {} - LSHSelfAttention.causal= {} - LSHSelfAttention.use_reference_code = True - # Parameters for SelfAttention: - # ============================================================================== - SelfAttention.chunk_len = {} - SelfAttention.n_chunks_before = {} - SelfAttention.n_chunks_after = {} - SelfAttention.causal= {} - SelfAttention.use_reference_code = True - SelfAttention.share_qk = False - # Parameters for ReformerLM: - # ============================================================================== - ReformerLM.vocab_size = {} - ReformerLM.d_model = {} - ReformerLM.d_ff = {} - ReformerLM.d_attention_key = {} - ReformerLM.d_attention_value = {} - ReformerLM.n_layers = {} - ReformerLM.n_heads = {} - ReformerLM.max_len = {} - ReformerLM.axial_pos_shape = {} - ReformerLM.d_axial_pos_embs = {} - ReformerLM.ff_chunk_size = {} - ReformerLM.ff_activation = @trax.layers.{} - ReformerLM.attention_type = @trax.layers.{} - ReformerLM.dropout = 0.0 - ReformerLM.n_chunks = 0 - ReformerLM.ff_use_sru = 0 - """.format( - config.lsh_attn_chunk_length, - config.lsh_attn_chunk_length, - config.lsh_attn_chunk_length // 2, - config.lsh_num_chunks_before, - config.lsh_num_chunks_after, - config.num_hashes, - config.num_buckets, - config.hash_seed, - config.is_decoder, - config.local_attn_chunk_length, - config.local_num_chunks_before, - config.local_num_chunks_after, - config.is_decoder, - config.vocab_size, - config.hidden_size, - config.feed_forward_size, - config.attention_head_size, - config.attention_head_size, - config.num_hidden_layers, - config.num_attention_heads, - config.max_position_embeddings, - axial_pos_shape, - d_axial_pos_embs, - config.chunk_size_feed_forward, - hidden_act, - attn_type, - ) - gin.parse_config(gin_config) - model = trax.models.ReformerLM(mode=mode) - return model - - def load_crime_and_punishment_model( - self, trax_model_path, input_signature, mode="predict" - ): - gin.parse_config( - """ - import trax.layers - import trax.models - import trax.optimizers - import trax.supervised.inputs - import trax.supervised.trainer_lib - # Parameters that will vary between experiments: - # ============================================================================== - train.model = @trax.models.ReformerLM - # Our model will have 6 layers, alternating between the LSH attention proposed - # in the Reformer paper and local attention within a certain context window. - n_layers = 6 - attn_type = [ - @SelfAttention, - @LSHSelfAttention, - @SelfAttention, - @LSHSelfAttention, - @SelfAttention, - @LSHSelfAttention, - ] - n_heads = 2 - attn_kv = 64 - dropout = 0.05 - n_tokens = 524288 - # Parameters for SelfAttention: - # ============================================================================== - SelfAttention.chunk_len = 64 - SelfAttention.n_chunks_before = 1 - SelfAttention.n_parallel_heads = 1 - SelfAttention.share_qk = False - # Parameters for LSHSelfAttention: - # ============================================================================== - LSHSelfAttention.chunk_len = 64 - LSHSelfAttention.n_buckets = [64, 128] - LSHSelfAttention.n_chunks_after = 0 - LSHSelfAttention.n_chunks_before = 1 - LSHSelfAttention.n_hashes = 1 - LSHSelfAttention.n_parallel_heads = 1 - LSHSelfAttention.predict_drop_len = 32 # different from original to make code equal - LSHSelfAttention.predict_mem_len = 64 # different from original to make code equal - LSHSelfAttention.lsh_seed = 0 - # Parameters for ReformerLM: - # ============================================================================== - ReformerLM.attention_type = %attn_type - ReformerLM.d_attention_key = %attn_kv - ReformerLM.d_attention_value = %attn_kv - ReformerLM.d_model = 256 - ReformerLM.d_ff = 512 - ReformerLM.dropout = %dropout - ReformerLM.ff_activation = @trax.layers.Relu - ReformerLM.max_len = %n_tokens - ReformerLM.mode = 'train' - ReformerLM.n_heads = %n_heads - ReformerLM.n_layers = %n_layers - ReformerLM.vocab_size = 320 - ReformerLM.axial_pos_shape = (512, 1024) - ReformerLM.d_axial_pos_embs= (64, 192) - """ - ) - trax_model = trax.models.ReformerLM(mode=mode) - trax_model.init(input_signature) - trax_model.init_from_file(trax_model_path, weights_only=True) - return trax_model - - def _set_param(self, torch_layer, weight, bias=None, name=None): - with torch.no_grad(): - assert ( - torch_layer.weight.shape == weight.shape - ), "{} layer.weight does not match".format(torch_layer) - torch_layer.weight = torch.nn.Parameter(weight) - if bias is not None: - assert ( - torch_layer.bias.shape == bias.shape - ), "{} layer.bias does not match".format(torch_layer) - torch_layer.bias = torch.nn.Parameter(bias) - return True - - def _test_param(self, torch_layer, grad, bias_grad=None, name=""): - assert ( - torch_layer.weight.grad.shape == grad.shape - ), "{} layer.grad does not match".format(torch_layer) - if torch.allclose(torch_layer.weight.grad, grad, atol=1e-3): - print("{}-{} layer.grad is good!".format(name, torch_layer)) - else: - print("ERROR {}-{} layer.grad is not good!".format(name, torch_layer)) - return False - if bias_grad is not None: - assert ( - torch_layer.bias.grad.shape == bias_grad.shape - ), "{} layer.bias does not match".format(torch_layer) - if torch.allclose(torch_layer.bias.grad, bias_grad, atol=1e-3): - print("{}-{} layer.grad bias is good!".format(name, torch_layer)) - else: - print( - "ERROR {}-{} layer.grad bias is not good!".format(name, torch_layer) - ) - return False - return True - - def _set_layer_weights_in_torch_lsh( - self, weights, torch_layer, hidden_size, exec_fn=None - ): - all_test_true = True - if exec_fn is None: - exec_fn = self._set_param - - # set torch weights for 1-to-1 comparison - np_query_key = np.asarray(weights[0]) - np_value = np.asarray(weights[1]) - np_dense = np.asarray(weights[2]) - - all_test_true = ( - exec_fn( - torch_layer.self_attention.query_key, - torch.tensor(np_query_key) - .transpose(1, 2) - .contiguous() - .view(-1, hidden_size), - name="attn_query_key", - ) - and all_test_true - ) - - all_test_true = ( - exec_fn( - torch_layer.self_attention.value, - torch.tensor(np_value) - .transpose(1, 2) - .contiguous() - .view(-1, hidden_size), - name="attn_value", - ) - and all_test_true - ) - - all_test_true = ( - exec_fn( - torch_layer.output.dense, - torch.tensor(np_dense) - .view(-1, hidden_size) - .contiguous() - .transpose(0, 1), - name="attn_dense", - ) - and all_test_true - ) - return all_test_true - - def _set_layer_weights_in_torch_local( - self, weights, torch_layer, hidden_size, exec_fn=None - ): - all_test_true = True - - if exec_fn is None: - exec_fn = self._set_param - - # set torch weights for 1-to-1 comparison - np_query = np.asarray(weights[0]) - np_key = np.asarray(weights[1]) - np_value = np.asarray(weights[2]) - np_dense = np.asarray(weights[3]) - - all_test_true = ( - exec_fn( - torch_layer.self_attention.query, - torch.tensor(np_query) - .transpose(1, 2) - .contiguous() - .view(-1, hidden_size), - ) - and all_test_true - ) - all_test_true = ( - exec_fn( - torch_layer.self_attention.key, - torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), - ) - and all_test_true - ) - all_test_true = ( - exec_fn( - torch_layer.self_attention.value, - torch.tensor(np_value) - .transpose(1, 2) - .contiguous() - .view(-1, hidden_size), - ) - and all_test_true - ) - all_test_true = ( - exec_fn( - torch_layer.output.dense, - torch.tensor(np_dense) - .view(-1, hidden_size) - .contiguous() - .transpose(0, 1), - ) - and all_test_true - ) - return all_test_true - - def _set_block_weights_in_torch( - self, weights, torch_block, hidden_size, exec_fn=None - ): - all_test_true = True - - if exec_fn is None: - exec_fn = self._set_param - - # intermediate weighs - # intermediate_weights = weights[2][0][2][2] - intermediate_weights = weights[2][0][1][2] - - # Chunked Feed Forward - if len(intermediate_weights) == 4: - intermediate_weights = intermediate_weights[2] - - # intermediate out - out_dense_weight = np.asarray(intermediate_weights[4][0]) - out_dense_bias = np.asarray(intermediate_weights[4][1]) - all_test_true = ( - exec_fn( - torch_block.feed_forward.output.dense, - torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), - torch.tensor(out_dense_bias), - name="res_feed_forward_2", - ) - and all_test_true - ) - - # intermediate dense - inter_dense_weight = np.asarray(intermediate_weights[1][0]) - inter_dense_bias = np.asarray(intermediate_weights[1][1]) - all_test_true = ( - exec_fn( - torch_block.feed_forward.dense.dense, - torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), - torch.tensor(inter_dense_bias), - name="res_feed_forward_1", - ) - and all_test_true - ) - - # layernorm 2 - layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) - layer_norm_2_bias = np.asarray(intermediate_weights[0][1]) - all_test_true = ( - exec_fn( - torch_block.feed_forward.layer_norm, - torch.tensor(layer_norm_2_weight), - torch.tensor(layer_norm_2_bias), - name="layer_norm_2", - ) - and all_test_true - ) - - # lsh weights + output - attn_weights = weights[0][1] - if len(attn_weights) < 4: - all_test_true = ( - self._set_layer_weights_in_torch_lsh( - attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn - ) - and all_test_true - ) - else: - all_test_true = ( - self._set_layer_weights_in_torch_local( - attn_weights, torch_block.attention, hidden_size, exec_fn=exec_fn - ) - and all_test_true - ) - - # layernorm 1 - layer_norm_1 = weights[0][0][0] - layer_norm_1_weight = np.asarray(layer_norm_1[0]) - layer_norm_1_bias = np.asarray(layer_norm_1[1]) - all_test_true = ( - exec_fn( - torch_block.attention.layer_norm, - torch.tensor(layer_norm_1_weight), - torch.tensor(layer_norm_1_bias), - name="layer_norm", - ) - and all_test_true - ) - - return all_test_true - - def _set_model_weights_in_torch( - self, weights, torch_model, hidden_size, set_params=True - ): - # reformer model - torch_model_reformer = torch_model.reformer - - all_test_true = True - if set_params is True: - exec_fn = self._set_param - else: - exec_fn = self._test_param - - # output weights - out_weights = weights[6] - - # output embeddings - output_embed_weights = np.asarray(out_weights[2][0]) - output_embed_bias = np.asarray(out_weights[2][1]) - all_test_true = ( - exec_fn( - torch_model.lm_head.decoder, - torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), - torch.tensor(output_embed_bias), - name="lm_head", - ) - and all_test_true - ) - - # output layer norm - layer_norm_out_weight = np.asarray(out_weights[0][0]) - layer_norm_out_bias = np.asarray(out_weights[0][1]) - all_test_true = ( - exec_fn( - torch_model_reformer.encoder.layer_norm, - torch.tensor(layer_norm_out_weight), - torch.tensor(layer_norm_out_bias), - name="last layer norm", - ) - and all_test_true - ) - - trax_layer_weights = weights[5] - assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( - trax_layer_weights - ), "HF and trax model do not have the same number of layers" - for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers[::-1]): - block_weights = trax_layer_weights[:-1][::-1][ - 4 * layer_idx : 4 * (layer_idx + 1) - ][::-1] - all_test_true = ( - self._set_block_weights_in_torch( - block_weights, layer, hidden_size, exec_fn=exec_fn - ) - and all_test_true - ) - - if isinstance(weights[3], tuple): - position_embeddings = torch_model_reformer.embeddings.position_embeddings - for emb_idx in range(len(position_embeddings.weights)): - emb_weights = np.asarray(weights[3][emb_idx][0]) - assert ( - position_embeddings.weights[emb_idx].shape == emb_weights.shape - ), "{} emb does not match".format(position_embeddings[emb_idx]) - if set_params is True: - position_embeddings.weights[emb_idx] = torch.nn.Parameter( - torch.tensor(emb_weights) - ) - else: - if torch.allclose( - position_embeddings.weights[emb_idx].grad, - torch.tensor(emb_weights), - atol=1e-3, - ): - print("{} layer.grad is good!".format(position_embeddings)) - else: - print( - "ERROR: {}-{} layer.grad is not good".format( - position_embeddings, "axs_pos_embeds" - ) - ) - - # word embeds - word_embeddings = np.asarray(weights[1]) - all_test_true = ( - exec_fn( - torch_model_reformer.embeddings.word_embeddings, - torch.tensor(word_embeddings), - name="word_embed", - ) - and all_test_true - ) - - return all_test_true + input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) + output_ids = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False) + output_text = tokenizer.decode(output_ids[0]) + self.assertEqual(output_text, "A few months later state expression in his ideas, at the first entrance. He was positively for an inst") From 94086ace1f976fd452fe2643f3d548a5b9b779bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 12:36:39 +0200 Subject: [PATCH 110/162] finish tests --- tests/test_modeling_reformer.py | 83 +++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index b8a06680db0e..b604c770d26b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -304,7 +304,26 @@ def create_and_check_reformer_layer_dropout_seed( atol=1e-3, ) ) - # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) + + def create_and_check_reformer_feed_forward_chunking( + self, config, input_ids, input_mask + ): + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] + + config.chunk_size_lm_head = input_ids.shape[-1] + config.chunk_size_feed_forward = input_ids.shape[-1] + + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + + hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) def create_and_check_reformer_model_fp16_forward( self, config, input_ids, input_mask @@ -373,6 +392,10 @@ def test_reformer_layer_training_dropout(self): *config_and_inputs, False ) + def test_reformer_chunking_forward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -418,6 +441,7 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, sinusoidal_pos_embds=True, + axial_pos_embds=True, attn_layers=["lsh", "lsh", "lsh", "lsh"], pad_token_id=0, eos_token_id=2, @@ -449,13 +473,13 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.sinusoidal_pos_embds = sinusoidal_pos_embds self.chunk_size_lm_head = chunk_size_lm_head self.chunk_size_feed_forward = chunk_size_feed_forward self.scope = scope self.attn_layers = attn_layers self.hash_seed = hash_seed self.pad_token_id = pad_token_id + self.axial_pos_embds = axial_pos_embds self.encoder_seq_length = seq_length // lsh_attn_chunk_length + ( seq_length % lsh_attn_chunk_length != 0 @@ -485,7 +509,7 @@ def prepare_config_and_inputs(self): lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, is_decoder=self.is_decoder, - sinusoidal_pos_embds=self.sinusoidal_pos_embds, + axial_pos_embds=self.axial_pos_embds, num_hashes=self.num_hashes, num_buckets=self.num_buckets, lsh_attn_chunk_length=self.lsh_attn_chunk_length, @@ -657,25 +681,54 @@ def create_and_check_reformer_layer_dropout_seed( atol=1e-3, ) ) - # TODO(PVP) Should add some tests here for backprop function as well (maybe in different test function though) - def create_and_check_reformer_model_fp16_forward( + def create_and_check_reformer_feed_backward_chunking( self, config, input_ids, input_mask ): - model = ReformerModel(config=config) + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) model.to(torch_device) - model.half() - model.eval() - model(input_ids, attention_mask=input_mask) + model.train() + model.zero_grad() + loss_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] + loss_no_chunk.backward() + grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[ + 0 + ][ + 1, 0, -5: + ] + grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[ + 1 + ][ + 0, 1, :5 + ] - def create_and_check_reformer_model_fp16_generate( - self, config, input_ids, input_mask - ): + config.chunk_size_lm_head = input_ids.shape[-1] + config.chunk_size_feed_forward = input_ids.shape[-1] + + torch.manual_seed(0) model = ReformerModelWithLMHead(config=config) model.to(torch_device) - model.half() - model.eval() - model.generate(input_ids, attention_mask=input_mask, do_sample=False) + model.train() + model.zero_grad() + loss_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] + loss_chunk.backward() + grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[ + 0 + ][ + 1, 0, -5: + ] + grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[ + 1 + ][ + 0, 1, :5 + ] + self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() From 01f40746fc1b66381dfb81ebe22d543af4e93437 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 May 2020 10:03:04 +0200 Subject: [PATCH 111/162] fix --- src/transformers/trainer.py | 22 +++------------------- src/transformers/training_args.py | 4 ---- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 68822dcb1861..6191f324b316 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -197,25 +197,9 @@ def get_optimizers( }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon, betas=(self.args.adam_beta_1, self.args.adam_beta_2)) - - if self.args.scheduler == "linear": - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps - ) - elif self.args.scheduler == "cosine_decay": - scheduler = get_cosine_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine_decay - ) - elif self.args.scheduler == "cosine_decay_hard_restarts": - scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps, num_cycles=self.args.num_cycles_cosine_decay - ) - elif self.args.scheduler == "constant": - scheduler = get_constant_schedule_with_warmup( - optimizer, num_warmup_steps=self.args.warmup_steps - ) - else: - raise NotImplementedError("The scheduler {} does not exist. Please choose one of the following schedulers ['linear', 'cosine_decay', 'cosine_decay_hard_restarts', 'constant']".format(self.args.scheduler)) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps + ) return optimizer, scheduler diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 0e1e1167c49c..5bffd44b0d3c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -55,8 +55,6 @@ class TrainingArguments: learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) - adam_beta_1: float = field(default=0.9, metadata={"helf": "Beta 1 for Adam optimizer."}) - adam_beta_2: float = field(default=0.999, metadata={"helf": "Beta 2 for Adam optimizer."}) max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) @@ -65,8 +63,6 @@ class TrainingArguments: metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) - scheduler: str = field(default="linear", metadata={"help": "Name of learning rate scheduler to use. Choose between ['linear', 'cosine_decay', 'cosine_decay_hard_restarts', 'constant']"}) - num_cycles_cosine_decay: float = field(default=1.0, metadata={"help": "Number of cosine cycles when using cosine decay schedules"}) logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"}) From 9dafbc23906d5b3ce16aec4f8218684550be6639 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 12:44:13 +0200 Subject: [PATCH 112/162] fix type trainer --- src/transformers/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6191f324b316..2fcfe4531d7b 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -196,11 +196,10 @@ def get_optimizers( "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon, betas=(self.args.adam_beta_1, self.args.adam_beta_2)) + optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps ) - return optimizer, scheduler def _setup_wandb(self): From de08a57037b1c7a5cbbbddf57d17c4707a2bfe63 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 13:05:42 +0200 Subject: [PATCH 113/162] fix fp16 tests --- tests/test_modeling_reformer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index b604c770d26b..db336b80f951 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -730,6 +730,24 @@ def create_and_check_reformer_feed_backward_chunking( self.parent.assertTrue(torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)) self.parent.assertTrue(torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)) + def create_and_check_reformer_model_fp16_forward( + self, config, input_ids, input_mask + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + model(input_ids, attention_mask=input_mask) + + def create_and_check_reformer_model_fp16_generate( + self, config, input_ids, input_mask + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + model.generate(input_ids, attention_mask=input_mask, do_sample=False) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask,) = config_and_inputs From aa570dc7c86956c22c299869893a28b468871631 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 13:08:46 +0200 Subject: [PATCH 114/162] fix tests --- tests/test_modeling_reformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index db336b80f951..85838144cf4c 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1103,7 +1103,7 @@ def test_lsh_layer_forward(self): layer = ReformerLayer(ReformerConfig(**config)) layer.eval() reformer_output = layer( - prev_attn_output=hidden_states, hidden_states=hidden_states + prev_attn_output=hidden_states.clone(), hidden_states=hidden_states ) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( @@ -1112,7 +1112,6 @@ def test_lsh_layer_forward(self): device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - pass def test_lsh_layer_forward_complex(self): config = self._get_basic_config_and_input() @@ -1124,7 +1123,7 @@ def test_lsh_layer_forward_complex(self): layer = ReformerLayer(ReformerConfig(**config)) layer.eval() reformer_output = layer( - prev_attn_output=hidden_states, + prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask, ) From a681d19371c8d7dde375ad14c91d55ef9d5a3068 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 13:10:33 +0200 Subject: [PATCH 115/162] fix tests --- tests/test_modeling_reformer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 85838144cf4c..fbb317ca7dc5 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1100,7 +1100,7 @@ def test_lsh_layer_forward(self): config["is_decoder"] = False hidden_states = self._get_hidden_states() torch.manual_seed(0) - layer = ReformerLayer(ReformerConfig(**config)) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() reformer_output = layer( prev_attn_output=hidden_states.clone(), hidden_states=hidden_states @@ -1120,7 +1120,7 @@ def test_lsh_layer_forward_complex(self): attn_mask = self._get_attn_mask() hidden_states = self._get_hidden_states() torch.manual_seed(0) - layer = ReformerLayer(ReformerConfig(**config)) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() reformer_output = layer( prev_attn_output=hidden_states.clone(), @@ -1141,7 +1141,7 @@ def test_local_layer_forward(self): config["is_decoder"] = False hidden_states = self._get_hidden_states() torch.manual_seed(0) - layer = ReformerLayer(ReformerConfig(**config)) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() reformer_output = layer( prev_attn_output=hidden_states, hidden_states=hidden_states @@ -1160,7 +1160,7 @@ def test_local_layer_forward_complex(self): attn_mask = self._get_attn_mask() hidden_states = self._get_hidden_states() torch.manual_seed(0) - layer = ReformerLayer(ReformerConfig(**config)) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() reformer_output = layer( prev_attn_output=hidden_states, @@ -1180,7 +1180,7 @@ def test_lsh_model_forward(self): config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"] config["num_buckets"] = [2, 4] torch.manual_seed(0) - model = ReformerModel(ReformerConfig(**config)) + model = ReformerModel(ReformerConfig(**config)).to(torch_device) model.eval() input_ids, attn_mask = self._get_input_ids_and_mask() hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] @@ -1196,7 +1196,7 @@ def test_local_model_forward(self): config = self._get_basic_config_and_input() config["attn_layers"] = ["local", "local", "local", "local"] torch.manual_seed(0) - model = ReformerModel(ReformerConfig(**config)) + model = ReformerModel(ReformerConfig(**config)).to(torch_device) model.eval() input_ids, attn_mask = self._get_input_ids_and_mask() hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] @@ -1214,7 +1214,7 @@ def test_lm_model_forward(self): config["num_buckets"] = [2, 4] config["is_decoder"] = False torch.manual_seed(0) - model = ReformerModelWithLMHead(ReformerConfig(**config)) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) model.eval() input_ids, attn_mask = self._get_input_ids_and_mask() hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] @@ -1232,7 +1232,7 @@ def test_local_lm_model_grad(self): config["hidden_dropout_prob"] = 0.0 config["local_attention_probs_dropout_prob"] = 0.0 torch.manual_seed(0) - model = ReformerModelWithLMHead(ReformerConfig(**config)) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) model.train() model.zero_grad() input_ids, _ = self._get_input_ids_and_mask() @@ -1290,7 +1290,7 @@ def test_lsh_lm_model_grad(self): config["num_buckets"] = [2, 4] config["num_hashes"] = 6 torch.manual_seed(0) - model = ReformerModelWithLMHead(ReformerConfig(**config)) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) model.train() model.zero_grad() input_ids, _ = self._get_input_ids_and_mask() @@ -1341,7 +1341,7 @@ def test_lsh_lm_model_grad(self): @slow def test_pretrained_generate_crime_and_punish(self): - model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR) + model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR).to(torch_device) tokenizer = ReformerTokenizer.from_pretrained("patrickvonplaten/reformer-crime-and-punish") model.eval() From 320c045af4fdcd21b20309c985f71bea17edfcc9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 13:11:33 +0200 Subject: [PATCH 116/162] fix tests --- tests/test_modeling_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index fbb317ca7dc5..815b693dee05 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -812,7 +812,7 @@ def test_reformer_model_fp16_generate(self): @slow def test_model_from_pretrained(self): for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = ReformerModelWithLMHead.from_pretrained(model_name, cache_dir=CACHE_DIR) + model = ReformerModelWithLMHead.from_pretrained(model_name, cache_dir=CACHE_DIR).to(torch_device) self.assertIsNotNone(model) From 482c6cd6cce27fd777d33cd447f6627d409664af Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 15:29:22 +0200 Subject: [PATCH 117/162] fix issue with dropout --- src/transformers/modeling_reformer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 486be046a3c4..5e362ffe8c93 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -997,11 +997,12 @@ def __init__(self, config, layer_id=0): self.feed_forward_seed = None def _init_attention_seed(self): - self.attention_seed = torch.randint(sys.maxsize, size=(1, 1)).item() + # randomize seeds + self.attention_seed = int(torch.random.seed() % sys.maxsize) torch.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): - self.feed_forward_seed = torch.randint(sys.maxsize, size=(1, 1)).item() + self.feed_forward_seed = int(torch.random.seed() % sys.maxsize) torch.manual_seed(self.feed_forward_seed) def forward( @@ -1013,10 +1014,9 @@ def forward( num_hashes=None, do_output_attentions=False, ): - with torch.no_grad(): # every forward pass we sample a different seed - # for dropout and save seed for forward fn in backward + # for dropout and save seed for forward fn in backward # to have correct dropout self._init_attention_seed() attn_outputs = self.attention( @@ -1035,8 +1035,8 @@ def forward( # free memory del prev_attn_output - # every forward pass we sample a different seed - # for dropout and save seed for forward fn in backward + # every forward pass we sample a different seed + # for dropout and save seed for forward fn in backward # to have correct dropout self._init_feed_forward_seed() # Y_2 = X_2 + g(Y_1) @@ -1170,7 +1170,7 @@ def backward(ctx, grad_hidden_states): attn_output, hidden_states = ctx.saved_tensors # create tuple - output = ReformerBackwardOutput(attn_output=attn_output, + output = ReformerBackwardOutput(attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, From d7905dde1cd81b74a63c8b52698bb33ce2de4213 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 15:43:15 +0200 Subject: [PATCH 118/162] fix dropout seeds --- src/transformers/modeling_reformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 5e362ffe8c93..f40043f17f77 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -998,11 +998,16 @@ def __init__(self, config, layer_id=0): def _init_attention_seed(self): # randomize seeds - self.attention_seed = int(torch.random.seed() % sys.maxsize) + torch.random.seed() + # set new random seed + self.attention_seed = torch.randint(sys.maxsize, size=(1, 1)).item() torch.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): - self.feed_forward_seed = int(torch.random.seed() % sys.maxsize) + # randomize seeds + torch.random.seed() + # set new random seed + self.feed_forward_seed = torch.randint(sys.maxsize, size=(1, 1)).item() torch.manual_seed(self.feed_forward_seed) def forward( From 764e06ed9eb8015d401acf65f200e647c5e99f0b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 14:30:34 +0000 Subject: [PATCH 119/162] correct random seed on gpu --- examples/summarization/t5/output.txt | 0 src/transformers/modeling_reformer.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) create mode 100644 examples/summarization/t5/output.txt diff --git a/examples/summarization/t5/output.txt b/examples/summarization/t5/output.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index f40043f17f77..3537187eb470 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -998,17 +998,16 @@ def __init__(self, config, layer_id=0): def _init_attention_seed(self): # randomize seeds - torch.random.seed() - # set new random seed - self.attention_seed = torch.randint(sys.maxsize, size=(1, 1)).item() - torch.manual_seed(self.attention_seed) + device_idx = torch.cuda.current_device() + self.attention_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): # randomize seeds - torch.random.seed() - # set new random seed - self.feed_forward_seed = torch.randint(sys.maxsize, size=(1, 1)).item() - torch.manual_seed(self.feed_forward_seed) + device_idx = torch.cuda.current_device() + self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.feed_forward_seed) + def forward( self, From a3e0f59c4a6be946bc7e3a96e9ce1c90bd9fb7a0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 16:50:53 +0200 Subject: [PATCH 120/162] finalize random seed for dropout --- src/transformers/modeling_reformer.py | 26 +++++++++++----- tests/test_modeling_reformer.py | 44 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3537187eb470..d79184f91860 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -998,16 +998,28 @@ def __init__(self, config, layer_id=0): def _init_attention_seed(self): # randomize seeds - device_idx = torch.cuda.current_device() - self.attention_seed = torch.cuda.default_generators[device_idx].seed() - torch.cuda.manual_seed(self.attention_seed) + if next(self.parameters()).device.type == "cuda": + # GPU + device_idx = torch.cuda.current_device() + self.attention_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.attention_seed) + else: + # CPU + self.attention_seed = int(torch.seed() % sys.maxsize) + torch.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): # randomize seeds - device_idx = torch.cuda.current_device() - self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() - torch.cuda.manual_seed(self.feed_forward_seed) - + if next(self.parameters()).device.type == "cuda": + # GPU + device_idx = torch.cuda.current_device() + device_idx = torch.cuda.current_device() + self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.feed_forward_seed) + else: + # CPU + self.feed_forward_seed = int(torch.seed() % sys.maxsize) + torch.manual_seed(self.feed_forward_seed) def forward( self, diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 815b693dee05..e09a3699110a 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -682,6 +682,44 @@ def create_and_check_reformer_layer_dropout_seed( ) ) + @slow + def create_and_check_reformer_random_seed( + self, config, input_ids, input_mask + ): + layer = ReformerLayer(config).to(torch_device) + layer.train() + + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + hidden_states = floats_tensor(shape) + attn_output = floats_tensor(shape) + + seeds = [] + for _ in range(100): + layer_outputs = layer( + attn_output, hidden_states, attention_mask=input_mask + ) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.attention_seed) + seeds.append(layer.attention_seed) + self.parent.assertGreater(len(set(seeds)), 50) + + seeds = [] + for _ in range(100): + layer_outputs = layer( + attn_output, hidden_states, attention_mask=input_mask + ) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.feed_forward_seed) + seeds.append(layer.feed_forward_seed) + self.parent.assertGreater(len(set(seeds)), 50) + def create_and_check_reformer_feed_backward_chunking( self, config, input_ids, input_mask ): @@ -795,6 +833,12 @@ def test_reformer_layer_training_dropout(self): *config_and_inputs, False ) + def test_dropout_random_seed_is_changing(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_random_seed( + *config_and_inputs + ) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() From c48f88a88bcf951ab0d4c3fbc25bb18aa112915c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 16:51:37 +0200 Subject: [PATCH 121/162] finalize random seed for dropout --- tests/test_modeling_reformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index e09a3699110a..55f09c5fb74f 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -707,7 +707,7 @@ def create_and_check_reformer_random_seed( hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.attention_seed) seeds.append(layer.attention_seed) - self.parent.assertGreater(len(set(seeds)), 50) + self.parent.assertGreater(len(set(seeds)), 70) seeds = [] for _ in range(100): @@ -718,7 +718,7 @@ def create_and_check_reformer_random_seed( hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.feed_forward_seed) seeds.append(layer.feed_forward_seed) - self.parent.assertGreater(len(set(seeds)), 50) + self.parent.assertGreater(len(set(seeds)), 70) def create_and_check_reformer_feed_backward_chunking( self, config, input_ids, input_mask From ce87cb6755e80c15207b7066e674490061322e4e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Apr 2020 16:57:29 +0200 Subject: [PATCH 122/162] remove duplicate line --- src/transformers/modeling_reformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index d79184f91860..9446da0d6ec4 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1013,7 +1013,6 @@ def _init_feed_forward_seed(self): if next(self.parameters()).device.type == "cuda": # GPU device_idx = torch.cuda.current_device() - device_idx = torch.cuda.current_device() self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() torch.cuda.manual_seed(self.feed_forward_seed) else: From d418dd089e2dfdef4a9bc82775ddff793baab78d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 1 May 2020 08:48:26 +0000 Subject: [PATCH 123/162] correct half precision bug --- src/transformers/modeling_reformer.py | 45 +++++++++++++++++---------- tests/test_modeling_reformer.py | 18 ++++------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 9446da0d6ec4..c4f1f213d1b8 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -365,6 +365,12 @@ def __init__(self, config): self.dropout = config.lsh_attention_probs_dropout_prob + # save mask value here + self.register_buffer("self_mask_value_float16", torch.tensor(-1e3)) + self.register_buffer("self_mask_value_float32", torch.tensor(-1e5)) + self.register_buffer("mask_value_float16", torch.tensor(-1e4)) + self.register_buffer("mask_value_float32", torch.tensor(-1e9)) + def forward( self, hidden_states, @@ -612,17 +618,21 @@ def _attend( # free memory del attn_mask + # get correct mask values depending on precision + if query_key_dots.dtype == torch.float16: + self_mask_value = self.self_mask_value_float16 + mask_value = self.mask_value_float16 + else: + self_mask_value = self.self_mask_value_float32 + mask_value = self.mask_value_float32 + + # if attention_mask and/or casaul mask apply here if mask is not None: - query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device), - ) + query_key_dots = torch.where(mask, query_key_dots, mask_value) - # Self mask is ALWAYS applied - # Note: Causal mask probably uses higher mask value (-1e9) than Self mask (-1e5) so that token is able to attend to itself when it has no other valid attention targets. - # Note: Self mask is used because Q and K projection weights are shared. + # Self mask is ALWAYS applied. # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): - # " While attention to the future is not allowed, typical implementations of the # Transformer do allow a position to attend to itself. # Such behavior is undesirable in a shared-QK formulation because the dot-product @@ -631,9 +641,7 @@ def _attend( # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) - query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e5, dtype=query_key_dots.dtype, device=query_key_dots.device), - ) + query_key_dots = torch.where(mask, query_key_dots, self_mask_value) # free memory del mask @@ -643,8 +651,6 @@ def _attend( # free memory del query_key_dots - # TODO(PVP): discuss with thom. Trax uses special dropout here where same dropout mask is applied for all "num_hashes * seq_len // chunk_length" dim. - # should be fine with normal dropout, no? # dropout dots = nn.functional.dropout(dots, self.dropout, self.training) @@ -751,6 +757,10 @@ def __init__(self, config): self.dropout = config.local_attention_probs_dropout_prob + # save mask value here + self.register_buffer("mask_value_float16", torch.tensor(-1e4)) + self.register_buffer("mask_value_float32", torch.tensor(-1e9)) + def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): # get SeqLen and BatchSize @@ -829,8 +839,14 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ del attn_mask, attention_mask, attention_mask_key if mask is not None: + # get mask tensor depending on half precision or not + if query_key_dots.dtype == torch.float16: + mask_value = self.mask_value_float16 + else: + mask_value = self.mask_value_float32 + query_key_dots = torch.where( - mask, query_key_dots, torch.tensor(-1e9, dtype=query_key_dots.dtype, device=query_key_dots.device), + mask, query_key_dots, mask_value ) # free memory @@ -1488,7 +1504,6 @@ def __init__(self, config): def forward(self, hidden_states): return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) - # TODO(PVP): Does this work with backpropagation? def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states @@ -1507,8 +1522,6 @@ def get_output_embeddings(self): return self.lm_head.decoder def tie_weights(self): - # TODO(PVP): output and input embeddings are - # apparently not tied so skip this step pass def forward( diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 55f09c5fb74f..908fd53dabcc 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -332,16 +332,8 @@ def create_and_check_reformer_model_fp16_forward( model.to(torch_device) model.half() model.eval() - model(input_ids, attention_mask=input_mask) - - def create_and_check_reformer_model_fp16_generate( - self, config, input_ids, input_mask - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.half() - model.eval() - model.generate(input_ids, attention_mask=input_mask, do_sample=False) + output = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -775,7 +767,8 @@ def create_and_check_reformer_model_fp16_forward( model.to(torch_device) model.half() model.eval() - model(input_ids, attention_mask=input_mask) + output = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) def create_and_check_reformer_model_fp16_generate( self, config, input_ids, input_mask @@ -784,7 +777,8 @@ def create_and_check_reformer_model_fp16_generate( model.to(torch_device) model.half() model.eval() - model.generate(input_ids, attention_mask=input_mask, do_sample=False) + output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) + self.parent.assertFalse(torch.isnan(output).any().item()) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() From 3248e675a52f65f9e882922bddf399e201a1048b Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 13:55:55 +0200 Subject: [PATCH 124/162] make style --- src/transformers/__init__.py | 8 +- src/transformers/activations.py | 2 +- src/transformers/modeling_reformer.py | 212 ++++++----- src/transformers/modeling_utils.py | 13 +- tests/test_modeling_reformer.py | 502 ++++++-------------------- 5 files changed, 257 insertions(+), 480 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 062c78890423..52adf8a41238 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -323,7 +323,13 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ) - from .modeling_reformer import ReformerAttention, ReformerLayer, ReformerModel, ReformerModelWithLMHead, REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + from .modeling_reformer import ( + ReformerAttention, + ReformerLayer, + ReformerModel, + ReformerModelWithLMHead, + REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, + ) # Optimization from .optimization import ( diff --git a/src/transformers/activations.py b/src/transformers/activations.py index f79b4055a16f..5238427d3b67 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -45,7 +45,7 @@ def gelu_fast(x): "gelu": gelu, "tanh": torch.tanh, "gelu_new": gelu_new, - "gelu_fast": gelu_fast + "gelu_fast": gelu_fast, } diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index c4f1f213d1b8..239be03cf773 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,23 +15,21 @@ # limitations under the License. """PyTorch REFORMER model. """ -import sys - import inspect import logging +import sys from collections import namedtuple from functools import reduce from operator import mul from typing import Callable -# DELETE later import numpy as np import torch from torch import nn from torch.autograd.function import Function from torch.nn import CrossEntropyLoss -from .activations import gelu, gelu_new, swish, gelu_fast +from .activations import gelu, gelu_fast, gelu_new, swish from .configuration_reformer import ReformerConfig from .modeling_utils import PreTrainedModel @@ -65,7 +63,10 @@ def mish(x): LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) -ReformerBackwardOutput = namedtuple("ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]) +ReformerBackwardOutput = namedtuple( + "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] +) +ReformerEncoderOutput = namedtuple("ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions"]) def create_sinusoidal_embeddings(n_pos, dim, out): @@ -389,6 +390,7 @@ def forward( query_key_vectors = self.query_key(hidden_states) value_vectors = self.value(hidden_states) + # free memory del hidden_states @@ -445,7 +447,9 @@ def forward( del query_key_vectors, key_vectors, value_vectors # calculate total concatenad chunks to split gradients correctly - out_vectors, logits = GatherSorted.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes) + out_vectors, logits = GatherSorted.apply( + out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes + ) if num_hashes > 1: out_vectors = self._split_dim_by( @@ -489,7 +493,9 @@ def _hash_vectors(self, vectors, num_hashes): # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for bucket_factor in self.num_buckets: - assert bucket_factor % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format(bucket_factor) + assert bucket_factor % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( + bucket_factor + ) rotation_size = rotation_size + bucket_factor num_buckets = num_buckets * bucket_factor @@ -571,13 +577,7 @@ def _set_num_buckets_on_the_fly(self, sequence_length): self.num_buckets = num_buckets def _attend( - self, - query_vectors, - key_vectors, - value_vectors, - sorted_bucket_idx, - attention_mask, - head_mask, + self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, ): key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) @@ -626,7 +626,6 @@ def _attend( self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - # if attention_mask and/or casaul mask apply here if mask is not None: query_key_dots = torch.where(mask, query_key_dots, mask_value) @@ -642,24 +641,27 @@ def _attend( # where a token has no other valid attention targets (e.g. the first token in a sequence) " mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) query_key_dots = torch.where(mask, query_key_dots, self_mask_value) + # free memory del mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` - dots = torch.exp(query_key_dots - logits) + attention_probs = torch.exp(query_key_dots - logits) + # free memory del query_key_dots # dropout - dots = nn.functional.dropout(dots, self.dropout, self.training) + attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) # Mask heads if we want to if head_mask is not None: - dots = dots * head_mask + attention_probs = attention_probs * head_mask # attend values - out_vectors = torch.matmul(dots, value_vectors) + out_vectors = torch.matmul(attention_probs, value_vectors) + # free memory del value_vectors @@ -667,7 +669,7 @@ def _attend( logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) - return out_vectors, logits, dots + return out_vectors, logits, attention_probs def _len_and_dim_norm(self, vectors): vectors = self._len_norm(vectors) @@ -718,7 +720,9 @@ def backward(ctx, grad_out_vectors, grad_logits): # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen grad_logits = torch.reshape(grad_logits, (grad_logits_shape[:2] + (num_hashes, -1))) # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen - grad_out_vectors = torch.reshape(grad_out_vectors, (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])) + grad_out_vectors = torch.reshape( + grad_out_vectors, (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:]) + ) sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) @@ -818,6 +822,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # get logits and dots # query_key_dots shape is `[batch_size, num_attn_heads, seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + # free memory del query_vectors, key_vectors @@ -835,6 +840,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ mask = mask * attn_mask else: mask = attn_mask + # free memory del attn_mask, attention_mask, attention_mask_key @@ -845,15 +851,14 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ else: mask_value = self.mask_value_float32 - query_key_dots = torch.where( - mask, query_key_dots, mask_value - ) + query_key_dots = torch.where(mask, query_key_dots, mask_value) # free memory del mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) + # free memory del logits @@ -866,6 +871,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # attend values out_vectors = torch.matmul(attention_probs, value_vectors) + # free memory del value_vectors @@ -1201,7 +1207,8 @@ def backward(ctx, grad_hidden_states): attn_output, hidden_states = ctx.saved_tensors # create tuple - output = ReformerBackwardOutput(attn_output=attn_output, + output = ReformerBackwardOutput( + attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, @@ -1281,12 +1288,9 @@ def forward( # Apply dropout hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) - outputs = (hidden_states,) - if do_output_hidden_states: - outputs = outputs + (tuple(all_hidden_states),) - if do_output_attentions: - outputs = outputs + (tuple(all_attentions),) - return outputs # last-layer hidden state, (all hidden states), (all attentions) + return ReformerEncoderOutput( + hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions + ) class ReformerPreTrainedModel(PreTrainedModel): @@ -1312,13 +1316,15 @@ def _init_weights(self, module): if module.weight.requires_grad: module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # TODO(PVP): discuss with Thom if necessary here to use different init -# torch.nn.init.xavier_uniform_(module.weight) + # torch.nn.init.xavier_uniform_(module.weight) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() # TODO(PVP): discuss with Thom if necessary here to use different init + + # module.bias.data.normal_(mean=0.0, std=1e-6) @@ -1418,55 +1424,24 @@ def forward( # prepare head mask head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True) - real_sequence_length = input_shape[-1] + # original sequence length for padding + orig_sequence_length = input_shape[-1] # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) - if input_shape[-1] % least_common_mult_chunk_length != 0: - - padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length - assert ( - self.training is False - ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if training. Please consider padding the input to a length of {}.".format( - input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length + has_to_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 + if has_to_pad_to_match_chunk_length is True: + # pad input + input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length( + input_ids, + inputs_embeds, + attention_mask, + position_ids, + input_shape, + least_common_mult_chunk_length, + device, ) - logger.info( - "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( - input_shape[-1], input_shape[-1] + padding_length, least_common_mult_chunk_length - ) - ) - - padded_input_ids = torch.full( - (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, - ) - - # Extend `attention_mask` - if attention_mask is not None: - attention_mask = torch.cat( - [attention_mask, torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,)], - dim=-1, - ) - else: - attention_mask = torch.cat( - [ - torch.ones(input_shape, device=device, dtype=torch.uint8), - torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8), - ], - dim=-1, - ) - - # Extend `input_ids` with padding to match least common multiple chunk_length - if input_ids is not None: - input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) - input_shape = input_ids.size() - - # Extend `input_embeds` with padding to match least common multiple chunk_length - if inputs_embeds is not None: - padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) - inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) - input_shape = inputs_embeds.size() - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( @@ -1477,15 +1452,77 @@ def forward( do_output_hidden_states=do_output_hidden_states, do_output_attentions=do_output_attentions, ) - sequence_output = encoder_outputs[0] + sequence_output = encoder_outputs.hidden_states # if padding was applied - if real_sequence_length < input_shape[-1]: - sequence_output = sequence_output[:, :real_sequence_length] + if has_to_pad_to_match_chunk_length is True: + sequence_output = sequence_output[:, :orig_sequence_length] - # add hidden_states and attentions if they are here - outputs = (sequence_output,) + encoder_outputs[1:] - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + outputs = (sequence_output,) + + # TODO(PVP): Replace by named tuple after namedtuples are introduced in the library. + if do_output_hidden_states is True: + outputs = outputs + (encoder_outputs.all_hidden_states,) + if do_output_attentions is True: + outputs = outputs + (encoder_outputs.all_attentions,) + return outputs + + def _pad_to_mult_of_chunk_length( + self, input_ids, inputs_embeds, attention_mask, position_ids, input_shape, padded_seq_length, device + ): + padding_length = padded_seq_length - input_shape[-1] % padded_seq_length + assert ( + self.training is False + ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if training. Please consider padding the input to a length of {}.".format( + input_shape[-2], padded_seq_length, input_shape[-2] + padding_length + ) + + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( + input_shape[-1], input_shape[-1] + padding_length, padded_seq_length + ) + ) + + padded_input_ids = torch.full( + (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, + ) + + # Extend `attention_mask` + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask, + torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,), + ], + dim=-1, + ) + else: + attention_mask = torch.cat( + [ + torch.ones(input_shape, device=device, dtype=torch.uint8), + torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8), + ], + dim=-1, + ) + + # Extend `input_ids` with padding to match least common multiple chunk_length + if input_ids is not None: + input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) + input_shape = input_ids.size() + + # Pad position ids if given + if position_ids is not None: + padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device) + padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length) + position_ids = torch.cat([position_ids, padded_position_ids], dim=-1) + + # Extend `input_embeds` with padding to match least common multiple chunk_length + if inputs_embeds is not None: + padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) + inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) + input_shape = inputs_embeds.size() + + return input_ids, inputs_embeds, attention_mask, position_ids, input_shape class ReformerOnlyLMHead(nn.Module): @@ -1556,8 +1593,6 @@ def forward( # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() -# shift_logits = logits.contiguous() -# shift_labels = labels.contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) @@ -1567,4 +1602,9 @@ def forward( def prepare_inputs_for_generation(self, input_ids, past, **kwargs): # TODO(PVP): Add smart caching - return {"input_ids": input_ids} + inputs_dict = {"input_ids": input_ids} + + if "num_hashes" in kwargs: + inputs_dict["num_hashes"] = kwargs["num_hashes"] + + return inputs_dict diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dd582042a353..7c260791b2ca 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -788,6 +788,8 @@ def generate( attention_mask=None, decoder_start_token_id=None, use_cache=None, + num_hashes=None, + **model_specific_kwargs ): r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. @@ -865,6 +867,9 @@ def generate( use_cache: (`optional`) bool If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`. + model_specific_kwargs: (`optional`) dict + Additional model specific kwargs will be forwarded to the `forward` function of the model. + Return: output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` @@ -1118,6 +1123,7 @@ def generate( encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, ) else: output = self._generate_no_beam_search( @@ -1140,6 +1146,7 @@ def generate( encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, ) return output @@ -1165,6 +1172,7 @@ def _generate_no_beam_search( encoder_outputs, attention_mask, use_cache, + model_specific_kwargs, ): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -1177,7 +1185,7 @@ def _generate_no_beam_search( while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) outputs = self(**model_inputs) @@ -1290,6 +1298,7 @@ def _generate_beam_search( encoder_outputs, attention_mask, use_cache, + model_specific_kwargs, ): """ Generate sequences for each example with beam search. """ @@ -1316,7 +1325,7 @@ def _generate_beam_search( while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 908fd53dabcc..212cdcb4f258 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -21,6 +21,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from .utils import CACHE_DIR, require_torch, slow, torch_device + if is_torch_available(): from transformers import ( ReformerConfig, @@ -28,16 +29,14 @@ ReformerModelWithLMHead, ReformerTokenizer, ReformerLayer, - REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, ) import torch @require_torch class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - ) + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -120,9 +119,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = ids_tensor( - [self.batch_size, self.seq_length], vocab_size=2 - ) + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) config = ReformerConfig( vocab_size=self.vocab_size, @@ -168,8 +165,7 @@ def create_and_check_reformer_model( } # 2 * hidden_size because we use reversible resnet layers self.parent.assertListEqual( - list(result["sequence_output"].size()), - [self.batch_size, self.seq_length, 2 * self.hidden_size], + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], ) def create_and_check_reformer_model_with_lm_backward( @@ -187,22 +183,17 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model( - input_ids, attention_mask=input_mask, labels=input_ids - ) + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) result = { "loss": loss, "prediction_scores": prediction_scores, } self.parent.assertListEqual( - list(result["prediction_scores"].size()), - [self.batch_size, self.seq_length, self.vocab_size], + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -213,9 +204,7 @@ def create_and_check_reformer_model_with_attn_mask( # set all position encodings to zero so that postions don't matter with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter( - torch.zeros(embedding.weight.shape).to(torch_device) - ) + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -224,45 +213,28 @@ def create_and_check_reformer_model_with_attn_mask( half_input_ids = input_ids[:, :half_seq_len] # normal padded - attn_mask = torch.cat( - [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], - dim=-1, - ) + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) input_ids_padded = torch.cat( - [ - half_input_ids, - ids_tensor((self.batch_size, half_seq_len), self.vocab_size), - ], - dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [ - half_input_ids, - ids_tensor((self.batch_size, half_seq_len), self.vocab_size), - ], - dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][ - :, :half_seq_len + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ + :, roll : half_seq_len + roll ] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[ - 0 - ][:, roll : half_seq_len + roll] - self.parent.assertTrue( - torch.allclose(output_padded, output_padded_rolled, atol=1e-3) - ) + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - def create_and_check_reformer_layer_dropout_seed( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): config.is_decoder = is_decoder layer = ReformerLayer(config).to(torch_device) layer.train() @@ -278,9 +250,7 @@ def create_and_check_reformer_layer_dropout_seed( # now the random seeds for attention and feed forward is initialized # forward tensors with dropout - layer_outputs = layer( - prev_attn_output, hidden_states, attention_mask=input_mask - ) + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) next_attn_output = layer_outputs.attn_output next_hidden_states = layer_outputs.hidden_states @@ -288,26 +258,16 @@ def create_and_check_reformer_layer_dropout_seed( torch.manual_seed(layer.attention_seed) attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) self.parent.assertTrue( - torch.allclose( - prev_attn_output + attn_outputs.hidden_states, - next_attn_output, - atol=1e-3, - ) + torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) ) torch.manual_seed(layer.feed_forward_seed) feed_forward_hidden_states = layer.feed_forward(next_attn_output) self.parent.assertTrue( - torch.allclose( - next_hidden_states, - hidden_states + feed_forward_hidden_states, - atol=1e-3, - ) + torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) ) - def create_and_check_reformer_feed_forward_chunking( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): torch.manual_seed(0) model = ReformerModel(config=config) model.to(torch_device) @@ -325,9 +285,7 @@ def create_and_check_reformer_feed_forward_chunking( hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) - def create_and_check_reformer_model_fp16_forward( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): model = ReformerModel(config=config) model.to(torch_device) model.half() @@ -342,12 +300,8 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict def setUp(self): - self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester( - self - ) - self.config_tester = ConfigTester( - self, config_class=ReformerConfig, hidden_size=37 - ) + self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester(self) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ -358,18 +312,12 @@ def test_reformer_model(self): def test_reformer_lm_model_backward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_lm_backward( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs) def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask( - *config_and_inputs, True - ) - self.model_tester.create_and_check_reformer_model_with_attn_mask( - *config_and_inputs, False - ) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -377,12 +325,8 @@ def test_reformer_with_lm(self): def test_reformer_layer_training_dropout(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed( - *config_and_inputs, True - ) - self.model_tester.create_and_check_reformer_layer_dropout_seed( - *config_and_inputs, False - ) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) def test_reformer_chunking_forward_equality(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -391,16 +335,12 @@ def test_reformer_chunking_forward_equality(self): @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_forward( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - ) + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -452,9 +392,7 @@ def __init__( self.num_attention_heads = num_attention_heads self.num_hashes = num_hashes self.num_hidden_layers = len(attn_layers) - self.num_buckets = ( - tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets - ) + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets self.lsh_attn_chunk_length = lsh_attn_chunk_length self.lsh_num_chunks_after = lsh_num_chunks_after self.lsh_num_chunks_before = lsh_num_chunks_before @@ -473,12 +411,8 @@ def __init__( self.pad_token_id = pad_token_id self.axial_pos_embds = axial_pos_embds - self.encoder_seq_length = seq_length // lsh_attn_chunk_length + ( - seq_length % lsh_attn_chunk_length != 0 - ) - self.key_length = ( - self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1 - ) * lsh_attn_chunk_length + self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) + self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length self.chunk_length = lsh_attn_chunk_length def prepare_config_and_inputs(self): @@ -486,9 +420,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = ids_tensor( - [self.batch_size, self.seq_length], vocab_size=2 - ) + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) config = ReformerConfig( vocab_size=self.vocab_size, @@ -535,8 +467,7 @@ def create_and_check_reformer_model( } # 2 * hidden_size because we use reversible resnet layers self.parent.assertListEqual( - list(result["sequence_output"].size()), - [self.batch_size, self.seq_length, 2 * self.hidden_size], + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], ) def create_and_check_reformer_model_with_lm_backward( @@ -554,22 +485,17 @@ def create_and_check_reformer_with_lm( model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model( - input_ids, attention_mask=input_mask, labels=input_ids - ) + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) result = { "loss": loss, "prediction_scores": prediction_scores, } self.parent.assertListEqual( - list(result["prediction_scores"].size()), - [self.batch_size, self.seq_length, self.vocab_size], + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], ) self.check_loss_output(result) - def create_and_check_reformer_model_with_attn_mask( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): # no special position embeddings config.axial_pos_embds = False config.is_decoder = is_decoder @@ -584,9 +510,7 @@ def create_and_check_reformer_model_with_attn_mask( with torch.no_grad(): embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter( - torch.zeros(embedding.weight.shape).to(torch_device) - ) + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) embedding.weight.requires_grad = False half_seq_len = self.seq_length // 2 @@ -595,43 +519,26 @@ def create_and_check_reformer_model_with_attn_mask( half_input_ids = input_ids[:, :half_seq_len] # normal padded - attn_mask = torch.cat( - [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], - dim=-1, - ) + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) input_ids_padded = torch.cat( - [ - half_input_ids, - ids_tensor((self.batch_size, half_seq_len), self.vocab_size), - ], - dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [ - half_input_ids, - ids_tensor((self.batch_size, half_seq_len), self.vocab_size), - ], - dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][ - :, :half_seq_len + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ + :, roll : half_seq_len + roll ] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[ - 0 - ][:, roll : half_seq_len + roll] - self.parent.assertTrue( - torch.allclose(output_padded, output_padded_rolled, atol=1e-3) - ) + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - def create_and_check_reformer_layer_dropout_seed( - self, config, input_ids, input_mask, is_decoder - ): + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): config.is_decoder = is_decoder layer = ReformerLayer(config).to(torch_device) layer.train() @@ -647,9 +554,7 @@ def create_and_check_reformer_layer_dropout_seed( # now the random seeds for attention and feed forward is initialized # forward tensors with dropout - layer_outputs = layer( - prev_attn_output, hidden_states, attention_mask=input_mask - ) + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) next_attn_output = layer_outputs.attn_output next_hidden_states = layer_outputs.hidden_states @@ -657,27 +562,17 @@ def create_and_check_reformer_layer_dropout_seed( torch.manual_seed(layer.attention_seed) attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) self.parent.assertTrue( - torch.allclose( - prev_attn_output + attn_outputs.hidden_states, - next_attn_output, - atol=1e-3, - ) + torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) ) torch.manual_seed(layer.feed_forward_seed) feed_forward_hidden_states = layer.feed_forward(next_attn_output) self.parent.assertTrue( - torch.allclose( - next_hidden_states, - hidden_states + feed_forward_hidden_states, - atol=1e-3, - ) + torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) ) @slow - def create_and_check_reformer_random_seed( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): layer = ReformerLayer(config).to(torch_device) layer.train() @@ -692,9 +587,7 @@ def create_and_check_reformer_random_seed( seeds = [] for _ in range(100): - layer_outputs = layer( - attn_output, hidden_states, attention_mask=input_mask - ) + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.attention_seed) @@ -703,18 +596,14 @@ def create_and_check_reformer_random_seed( seeds = [] for _ in range(100): - layer_outputs = layer( - attn_output, hidden_states, attention_mask=input_mask - ) + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.feed_forward_seed) seeds.append(layer.feed_forward_seed) self.parent.assertGreater(len(set(seeds)), 70) - def create_and_check_reformer_feed_backward_chunking( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): torch.manual_seed(0) model = ReformerModelWithLMHead(config=config) model.to(torch_device) @@ -723,16 +612,8 @@ def create_and_check_reformer_feed_backward_chunking( loss_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] loss_no_chunk.backward() grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[ - 0 - ][ - 1, 0, -5: - ] - grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[ - 1 - ][ - 0, 1, :5 - ] + grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] config.chunk_size_lm_head = input_ids.shape[-1] config.chunk_size_feed_forward = input_ids.shape[-1] @@ -745,24 +626,18 @@ def create_and_check_reformer_feed_backward_chunking( loss_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] loss_chunk.backward() grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[ - 0 - ][ - 1, 0, -5: - ] - grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[ - 1 - ][ - 0, 1, :5 - ] + grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) - self.parent.assertTrue(torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)) - self.parent.assertTrue(torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) + ) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) + ) - def create_and_check_reformer_model_fp16_forward( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): model = ReformerModel(config=config) model.to(torch_device) model.half() @@ -770,9 +645,7 @@ def create_and_check_reformer_model_fp16_forward( output = model(input_ids, attention_mask=input_mask)[0] self.parent.assertFalse(torch.isnan(output).any().item()) - def create_and_check_reformer_model_fp16_generate( - self, config, input_ids, input_mask - ): + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() @@ -788,9 +661,7 @@ def prepare_config_and_inputs_for_common(self): def setUp(self): self.model_tester = ReformerLSHAttnModelTest.ReformerLSHAttnModelTester(self) - self.config_tester = ConfigTester( - self, config_class=ReformerConfig, hidden_size=37 - ) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ -801,18 +672,12 @@ def test_reformer_model(self): def test_reformer_lm_model_backward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_lm_backward( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs) def test_reformer_model_attn_masking(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask( - *config_and_inputs, True - ) - self.model_tester.create_and_check_reformer_model_with_attn_mask( - *config_and_inputs, False - ) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) def test_reformer_with_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -820,32 +685,22 @@ def test_reformer_with_lm(self): def test_reformer_layer_training_dropout(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed( - *config_and_inputs, True - ) - self.model_tester.create_and_check_reformer_layer_dropout_seed( - *config_and_inputs, False - ) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_random_seed( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_forward( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_reformer_model_fp16_generate(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_generate( - *config_and_inputs - ) + self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) @slow def test_model_from_pretrained(self): @@ -982,74 +837,8 @@ def _get_attn_mask(self): def _get_input_ids_and_mask(self): mask = torch.tensor( [ - [ - 1, - 0, - 0, - 1, - 1, - 0, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 0, - 1, - 0, - 1, - 1, - 0, - 1, - 1, - 0, - 0, - 1, - 1, - 0, - 0, - 1, - 1, - 1, - 1, - ], - [ - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 0, - 1, - 0, - 0, - 1, - 1, - 1, - 1, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 1, - 0, - 0, - 0, - 1, - 0, - ], + [1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1,], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,], ], dtype=torch.long, device=torch_device, @@ -1140,14 +929,10 @@ def test_lsh_layer_forward(self): torch.manual_seed(0) layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() - reformer_output = layer( - prev_attn_output=hidden_states.clone(), hidden_states=hidden_states - ) + reformer_output = layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [1.6879, -1.3083, -0.4708, 1.3555, -0.6292], - dtype=torch.float, - device=torch_device, + [1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1161,15 +946,11 @@ def test_lsh_layer_forward_complex(self): layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() reformer_output = layer( - prev_attn_output=hidden_states.clone(), - hidden_states=hidden_states, - attention_mask=attn_mask, + prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask, ) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [1.6439, -1.2306, -0.5108, 1.3006, -0.6537], - dtype=torch.float, - device=torch_device, + [1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1181,14 +962,10 @@ def test_local_layer_forward(self): torch.manual_seed(0) layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() - reformer_output = layer( - prev_attn_output=hidden_states, hidden_states=hidden_states - ) + reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [1.4212, -2.0576, -0.9688, 1.4599, -0.1344], - dtype=torch.float, - device=torch_device, + [1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1200,16 +977,10 @@ def test_local_layer_forward_complex(self): torch.manual_seed(0) layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) layer.eval() - reformer_output = layer( - prev_attn_output=hidden_states, - hidden_states=hidden_states, - attention_mask=attn_mask, - ) + reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,) output_slice = reformer_output.hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], - dtype=torch.float, - device=torch_device, + [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1224,9 +995,7 @@ def test_lsh_model_forward(self): hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] output_slice = hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], - dtype=torch.float, - device=torch_device, + [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1240,9 +1009,7 @@ def test_local_model_forward(self): hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] output_slice = hidden_states[0, 0, :5] expected_output_slice = torch.tensor( - [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], - dtype=torch.float, - device=torch_device, + [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1258,9 +1025,7 @@ def test_lm_model_forward(self): hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] output_slice = hidden_states[1, -1, :5] expected_output_slice = torch.tensor( - [0.0324, -0.0121, 0.0615, 0.0031, -0.0297], - dtype=torch.float, - device=torch_device, + [0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device, ) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) @@ -1282,43 +1047,19 @@ def test_local_lm_model_grad(self): # check last grads to cover all proable errors grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] expected_grad_slice_word = torch.tensor( - [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], - dtype=torch.float, - device=torch_device, + [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], dtype=torch.float, device=torch_device, ) - grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[ - 0 - ][ - 1, 0, -5: - ] + grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] expected_grad_slice_pos_fac_1 = torch.tensor( - [0.0037, -1.3793, -1.0231, -1.5230, -2.5306], - dtype=torch.float, - device=torch_device, + [0.0037, -1.3793, -1.0231, -1.5230, -2.5306], dtype=torch.float, device=torch_device, ) - grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[ - 1 - ][ - 0, 1, :5 - ] + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] expected_grad_slice_pos_fac_2 = torch.tensor( - [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], - dtype=torch.float, - device=torch_device, - ) - self.assertTrue( - torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3) - ) - self.assertTrue( - torch.allclose( - grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3 - ) - ) - self.assertTrue( - torch.allclose( - grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3 - ) + [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], dtype=torch.float, device=torch_device, ) + self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3)) def test_lsh_lm_model_grad(self): config = self._get_basic_config_and_input() @@ -1339,51 +1080,32 @@ def test_lsh_lm_model_grad(self): # check last grads to cover all proable errors grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] expected_grad_slice_word = torch.tensor( - [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], - dtype=torch.float, - device=torch_device, + [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], dtype=torch.float, device=torch_device, ) - grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[ - 0 - ][ - 1, 0, -5: - ] + grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] expected_grad_slice_pos_fac_1 = torch.tensor( - [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], - dtype=torch.float, - device=torch_device, + [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], dtype=torch.float, device=torch_device, ) - grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[ - 1 - ][ - 0, 1, :5 - ] + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] expected_grad_slice_pos_fac_2 = torch.tensor( - [0.4626, -0.0231, -0.0172, 0.1081, 0.3805], - dtype=torch.float, - device=torch_device, - ) - self.assertTrue( - torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3) - ) - self.assertTrue( - torch.allclose( - grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3 - ) - ) - self.assertTrue( - torch.allclose( - grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3 - ) + [0.4626, -0.0231, -0.0172, 0.1081, 0.3805], dtype=torch.float, device=torch_device, ) + self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3)) @slow def test_pretrained_generate_crime_and_punish(self): - model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR).to(torch_device) + model = ReformerModelWithLMHead.from_pretrained( + "patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR + ).to(torch_device) tokenizer = ReformerTokenizer.from_pretrained("patrickvonplaten/reformer-crime-and-punish") model.eval() input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) output_ids = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False) output_text = tokenizer.decode(output_ids[0]) - self.assertEqual(output_text, "A few months later state expression in his ideas, at the first entrance. He was positively for an inst") + self.assertEqual( + output_text, + "A few months later state expression in his ideas, at the first entrance. He was positively for an inst", + ) From 6fe0648dd2da4f776aa0d95951c82e2f36e8d1d6 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 14:18:26 +0200 Subject: [PATCH 125/162] refactor --- src/transformers/modeling_reformer.py | 105 +++++++++++++++----------- 1 file changed, 63 insertions(+), 42 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 239be03cf773..47ec483a8ea1 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -59,6 +59,7 @@ def mish(x): ReformerLayerNorm = torch.nn.LayerNorm +# Define named tuples for nn.Modules here LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) @@ -153,9 +154,10 @@ def __init__(self, config): super().__init__() self.axial_pos_shape = config.axial_pos_shape self.axial_pos_embds_dim = config.axial_pos_embds_dim + self.dropout = config.hidden_dropout_prob + self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config) self.weights = nn.ParameterList() - self.dropout = config.hidden_dropout_prob assert ( sum(self.axial_pos_embds_dim) == config.hidden_size @@ -195,6 +197,7 @@ def forward(self, position_ids): # drop entire matrix of last two dims (prev dims 1 and 2) drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) drop_weights = drop_perm_weights.permute(0, 3, 2, 1) + position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) else: @@ -209,6 +212,7 @@ def forward(self, position_ids): ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length, ) + # reshape axial encodings and use only until sequence_length position_encodings = torch.cat(broadcasted_weights, dim=-1) position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[ @@ -224,17 +228,18 @@ class PositionEmbeddings(nn.Module): def __init__(self, config): super().__init__() - self.config = config self.dropout = config.hidden_dropout_prob + self.is_sinusoidal_pos_embds = config.sinusoidal_pos_embds + self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) - if self.config.sinusoidal_pos_embds is True: + if config.sinusoidal_pos_embds is True: create_sinusoidal_embeddings( n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.embedding.weight, ) def forward(self, position_ids): position_embeddings = self.embedding(position_ids) - if self.config.sinusoidal_pos_embds is False: + if self.is_sinusoidal_pos_embds is False: position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) return position_embeddings @@ -245,6 +250,9 @@ class ReformerEmbeddings(nn.Module): def __init__(self, config): super().__init__() + self.max_position_embeddings = config.max_position_embeddings + self.dropout = config.hidden_dropout_prob + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = ( AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) @@ -252,18 +260,16 @@ def __init__(self, config): assert not ( config.sinusoidal_pos_embds and config.axial_pos_embds ), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" - self.max_position_embeddings = config.max_position_embeddings - self.dropout = config.hidden_dropout_prob def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): - if input_ids is not None: input_shape = input_ids.size() + device = input_ids.device else: input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).expand(input_shape) @@ -276,8 +282,8 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): ), "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format( position_ids.shape[-1], self.max_position_embeddings ) - embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) + embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings @@ -295,9 +301,12 @@ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): Args: vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + num_chunks_before: chunks before current chunk to include in attention + num_chunks_after: chunks after current chunk to include in attention + Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). + array of shape [num_chunks, N * chunk_length, ...], where + N = (1 + num_chunks_before + num_chunks_after). """ if num_chunks_before == 0 and num_chunks_after == 0: return vectors @@ -310,16 +319,25 @@ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) return torch.cat(slices, dim=3) - def _transpose_for_scores(self, x, num_attn_heads, attn_head_size): + def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size): + """ + splits hidden_size dim into attn_head_size and num_attn_heads + """ new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) x = x.view(*new_x_shape) return x.transpose(2, 1) - def _transpose_for_output(self, x, num_attn_heads, attn_head_size): + def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size): + """ + merges attn_head_size dim and num_attn_heads dim into hidden_size + """ x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) - def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): + def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): + """ + splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims + """ batch_size = vectors.shape[0] split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) @@ -330,7 +348,10 @@ def _split_dim_by(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, att else: raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - def _merge_by_middle_dim(self, vectors, num_attn_heads): + def _merge_seq_length_dims(self, vectors, num_attn_heads): + """ + merges factorized sequence length dims into one dim + """ batch_size = vectors.shape[0] new_dim_shape = (batch_size, num_attn_heads, -1) @@ -345,7 +366,6 @@ def _merge_by_middle_dim(self, vectors, num_attn_heads): class LSHSelfAttention(nn.Module, EfficientAttentionUtils): def __init__(self, config): super().__init__() - self.num_attention_heads = config.num_attention_heads self.hash_seed = config.hash_seed self.num_hashes = config.num_hashes @@ -356,6 +376,8 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings + self.dropout = config.lsh_attention_probs_dropout_prob + self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size @@ -364,8 +386,6 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - self.dropout = config.lsh_attention_probs_dropout_prob - # save mask value here self.register_buffer("self_mask_value_float16", torch.tensor(-1e3)) self.register_buffer("self_mask_value_float32", torch.tensor(-1e5)) @@ -382,28 +402,29 @@ def forward( buckets=None, **kwargs ): - # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] + # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes + # project hidden_states to query_key and value query_key_vectors = self.query_key(hidden_states) value_vectors = self.value(hidden_states) # free memory del hidden_states - query_key_vectors = self._transpose_for_scores( + query_key_vectors = self._split_hidden_size_dim( query_key_vectors, self.num_attention_heads, self.attention_head_size ) - value_vectors = self._transpose_for_scores(value_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) assert query_key_vectors.shape[-1] == self.attention_head_size assert value_vectors.shape[-1] == self.attention_head_size + # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: - # set `num_buckets` on the fly self._set_num_buckets_on_the_fly(sequence_length) # use cached buckets for backprop only @@ -423,10 +444,10 @@ def forward( query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) - query_key_vectors = self._split_dim_by( + query_key_vectors = self._split_seq_length_dim_to( query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) - value_vectors = self._split_dim_by( + value_vectors = self._split_seq_length_dim_to( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) @@ -452,10 +473,10 @@ def forward( ) if num_hashes > 1: - out_vectors = self._split_dim_by( + out_vectors = self._split_seq_length_dim_to( out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ) - logits = self._split_dim_by( + logits = self._split_seq_length_dim_to( logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ).unsqueeze(-1) @@ -469,7 +490,7 @@ def forward( assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) - out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) if do_output_attentions is False: attention_probs = () @@ -538,7 +559,7 @@ def _hash_vectors(self, vectors, num_hashes): # repeat same values for Batch_size and Num_Attn_Heads offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_by_middle_dim(buckets + offsets, self.num_attention_heads) + offset_buckets = self._merge_seq_length_dims(buckets + offsets, self.num_attention_heads) return offset_buckets @@ -587,7 +608,7 @@ def _attend( # free memory del query_vectors, key_vectors - query_bucket_idx = self._split_dim_by(sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads) + query_bucket_idx = self._split_seq_length_dim_to(sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) mask = None @@ -666,8 +687,8 @@ def _attend( del value_vectors # merge chunk length - logits = self._merge_by_middle_dim(logits, self.num_attention_heads).squeeze(-1) - out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) + logits = self._merge_seq_length_dims(logits, self.num_attention_heads).squeeze(-1) + out_vectors = self._merge_seq_length_dims(out_vectors, self.num_attention_heads) return out_vectors, logits, attention_probs @@ -775,9 +796,9 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) - query_vectors = self._transpose_for_scores(query_vectors, self.num_attention_heads, self.attention_head_size) - key_vectors = self._transpose_for_scores(key_vectors, self.num_attention_heads, self.attention_head_size) - value_vectors = self._transpose_for_scores(value_vectors, self.num_attention_heads, self.attention_head_size) + query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) + key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) assert query_vectors.shape[-1] == self.attention_head_size assert key_vectors.shape[-1] == self.attention_head_size @@ -791,13 +812,13 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ ) # chunk vectors - query_vectors = self._split_dim_by( + query_vectors = self._split_seq_length_dim_to( query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size - key_vectors = self._split_dim_by( + key_vectors = self._split_seq_length_dim_to( key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size - value_vectors = self._split_dim_by( + value_vectors = self._split_seq_length_dim_to( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size @@ -805,8 +826,8 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) - query_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) - key_value_indices = self._split_dim_by(indices, -1, self.chunk_length, self.num_attention_heads) + query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + key_value_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) # append chunks before and after key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) @@ -816,7 +837,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attention_mask = self._split_dim_by(attention_mask, -1, self.chunk_length, 1) + attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) # get logits and dots @@ -876,11 +897,11 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ del value_vectors # merge chunk length - out_vectors = self._merge_by_middle_dim(out_vectors, self.num_attention_heads) + out_vectors = self._merge_seq_length_dims(out_vectors, self.num_attention_heads) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) - out_vectors = self._transpose_for_output(out_vectors, self.num_attention_heads, self.attention_head_size) + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) if do_output_attentions is False: attention_probs = () From c3031b8306a3b0e5091126d88346916b5d8e9dff Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 18:48:04 +0200 Subject: [PATCH 126/162] refactor --- src/transformers/modeling_reformer.py | 235 ++++++++++++++++---------- 1 file changed, 145 insertions(+), 90 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 47ec483a8ea1..b4be2409b907 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -366,18 +366,18 @@ def _merge_seq_length_dims(self, vectors, num_attn_heads): class LSHSelfAttention(nn.Module, EfficientAttentionUtils): def __init__(self, config): super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hash_seed = config.hash_seed + self.chunk_length = config.lsh_attn_chunk_length self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets - self.chunk_length = config.lsh_attn_chunk_length self.num_chunks_before = config.lsh_num_chunks_before self.num_chunks_after = config.lsh_num_chunks_after + self.hash_seed = config.hash_seed self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings self.dropout = config.lsh_attention_probs_dropout_prob + self.num_attention_heads = config.num_attention_heads self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size @@ -386,7 +386,7 @@ def __init__(self, config): self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) - # save mask value here + # save mask value here. Need fp32 and fp16 mask values self.register_buffer("self_mask_value_float16", torch.tensor(-1e3)) self.register_buffer("self_mask_value_float32", torch.tensor(-1e5)) self.register_buffer("mask_value_float16", torch.tensor(-1e4)) @@ -420,8 +420,8 @@ def forward( ) value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) - assert query_key_vectors.shape[-1] == self.attention_head_size - assert value_vectors.shape[-1] == self.attention_head_size + assert query_key_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(query_key_vectors.shape[-1], self.attention_head_size) + assert value_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(value_vectors.shape[-1], self.attention_head_size) # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: @@ -432,7 +432,7 @@ def forward( # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors, num_hashes) - assert int(buckets.shape[-1]) == num_hashes * sequence_length + assert int(buckets.shape[-1]) == num_hashes * sequence_length, "lash dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sequence_length, buckets, num_hashes @@ -441,6 +441,7 @@ def forward( # make sure bucket idx is not longer then sequence length sorted_bucket_idx = sorted_bucket_idx % sequence_length + # cluster query key value vectors according to hashed buckets query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) @@ -452,10 +453,12 @@ def forward( ) if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + assert self.num_chunks_before == 0 and self.num_chunks_after == 0, "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + # scale key vectors key_vectors = self._len_and_dim_norm(query_key_vectors) + # get attention probs out_vectors, logits, attention_probs = self._attend( query_vectors=query_key_vectors, key_vectors=key_vectors, @@ -467,11 +470,12 @@ def forward( # free memory del query_key_vectors, key_vectors, value_vectors - # calculate total concatenad chunks to split gradients correctly - out_vectors, logits = GatherSorted.apply( + # sort clusters back to correct ordering + out_vectors, logits = ReverseSort.apply( out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes ) + # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, @@ -488,7 +492,7 @@ def forward( # free memory del logits - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`." out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) @@ -503,7 +507,6 @@ def _hash_vectors(self, vectors, num_hashes): # See https://arxiv.org/pdf/1509.02897.pdf # We sample a different random rotation for each round of hashing to # decrease the probability of hash misses. - if isinstance(self.num_buckets, int): assert ( self.num_buckets % 2 == 0 @@ -605,40 +608,13 @@ def _attend( # get logits and dots query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + # free memory del query_vectors, key_vectors query_bucket_idx = self._split_seq_length_dim_to(sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) - mask = None - # Causal mask - if self.is_decoder: - mask = torch.ge(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( - query_bucket_idx.device - ) - - # Attention mask: chunk, look up correct mask value from key_value_bucket_idx - # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. - if attention_mask is not None: - attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] - # expand attn_mask to fit with key_value_bucket_idx shape - attention_mask = attention_mask.expand(query_bucket_idx.shape[:-1] + (-1,)) - key_attn_mask = torch.gather(attention_mask, -1, key_value_bucket_idx) - query_attn_mask = torch.gather(attention_mask, -1, query_bucket_idx) - # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk - attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) - # free memory - del query_attn_mask, key_attn_mask, attention_mask - - # multiply by casaul mask if necessary - if mask is not None: - mask = mask * attn_mask - else: - mask = attn_mask - # free memory - del attn_mask - # get correct mask values depending on precision if query_key_dots.dtype == torch.float16: self_mask_value = self.self_mask_value_float16 @@ -647,10 +623,14 @@ def _attend( self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - # if attention_mask and/or casaul mask apply here + mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask) + if mask is not None: query_key_dots = torch.where(mask, query_key_dots, mask_value) + # free memory + del mask + # Self mask is ALWAYS applied. # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): # " While attention to the future is not allowed, typical implementations of the @@ -660,11 +640,14 @@ def _attend( # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " - mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) - query_key_dots = torch.where(mask, query_key_dots, self_mask_value) + + self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) + + # apply self_mask + query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value) # free memory - del mask + del self_mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` @@ -692,7 +675,40 @@ def _attend( return out_vectors, logits, attention_probs + def _compute_attn_mask(self, query_indices, key_indices, attention_mask): + mask = None + + # Causal mask + if self.is_decoder: + mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to( + query_indices.device + ) + + # Attention mask: chunk, look up correct mask value from key_value_bucket_idx + # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. + if attention_mask is not None: + attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] + # expand attn_mask to fit with key_value_bucket_idx shape + attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) + key_attn_mask = torch.gather(attention_mask, -1, key_indices) + query_attn_mask = torch.gather(attention_mask, -1, query_indices) + # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk + attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) + # free memory + del query_attn_mask, key_attn_mask, attention_mask + + # multiply by casaul mask if necessary + if mask is not None: + mask = mask * attn_mask + else: + mask = attn_mask + + return mask + def _len_and_dim_norm(self, vectors): + """ + length and attention head size dim normalization + """ vectors = self._len_norm(vectors) vectors = vectors / torch.sqrt( torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype) @@ -700,17 +716,30 @@ def _len_and_dim_norm(self, vectors): return vectors def _len_norm(self, x, epsilon=1e-6): + """ + length normalization + """ variance = torch.mean(x ** 2, -1, keepdim=True) norm_x = x / torch.sqrt(variance + epsilon) return norm_x def _gather_by_expansion(self, vectors, idxs, num_hashes): + """ + expand dims of idxs and vectors for all hashes and gather + """ expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) vectors = vectors.repeat(1, 1, num_hashes, 1) return torch.gather(vectors, 2, expanded_idxs) -class GatherSorted(Function): +class ReverseSort(Function): + """ + After chunked attention is applied which sorted clusters, + original ordering has to be restored. + Since customized backward function is used for Reformer, + the gradients of the output vectors have to be explicitely + sorted here. + """ @staticmethod def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): # save sorted_bucket_idx for backprop @@ -745,10 +774,10 @@ def backward(ctx, grad_out_vectors, grad_logits): grad_out_vectors, (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:]) ) + # reshape and expand sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) - - # reverse sort of forward expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) + # reverse sort of forward grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices) grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx) @@ -787,83 +816,61 @@ def __init__(self, config): self.register_buffer("mask_value_float32", torch.tensor(-1e9)) def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): - - # get SeqLen and BatchSize sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] + # project hidden_states to query, key and value query_vectors = self.query(hidden_states) key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) + # split last dim into `config.num_attention_heads` and `config.attention_head_size` query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) - assert query_vectors.shape[-1] == self.attention_head_size - assert key_vectors.shape[-1] == self.attention_head_size - assert value_vectors.shape[-1] == self.attention_head_size + assert query_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(query_vectors.shape[-1], self.attention_head_size) + assert key_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(key_vectors.shape[-1], self.attention_head_size) + assert value_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(value_vectors.shape[-1], self.attention_head_size) if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0 + assert self.num_chunks_before == 0 and self.num_chunks_after == 0, "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + # normalize key vectors key_vectors = key_vectors / torch.sqrt( torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype) ) # chunk vectors + # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size query_vectors = self._split_seq_length_dim_to( query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + ) key_vectors = self._split_seq_length_dim_to( key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + ) value_vectors = self._split_seq_length_dim_to( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, - ) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + ) # chunk indices indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) - key_value_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) # append chunks before and after key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) - key_value_indices = self._look_adjacent(key_value_indices, self.num_chunks_before, self.num_chunks_after) - - # chunk attention mask and look before and after - if attention_mask is not None: - attention_mask = attention_mask.to(torch.uint8)[:, None, :] - attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) - attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) - # get logits and dots - # query_key_dots shape is `[batch_size, num_attn_heads, seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory del query_vectors, key_vectors - mask = None - # Causal mask - if self.is_decoder is True: - mask = torch.ge(query_indices.unsqueeze(-1), key_value_indices.unsqueeze(-2)).to(query_indices.device) - - # Attention mask - if attention_mask is not None: - # create attn_mask - attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots.shape) - # multiply by casaul mask if necessary - if mask is not None: - mask = mask * attn_mask - else: - mask = attn_mask - - # free memory - del attn_mask, attention_mask, attention_mask_key + mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape) if mask is not None: # get mask tensor depending on half precision or not @@ -877,6 +884,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ # free memory del mask + # softmax logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) @@ -908,14 +916,39 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) + def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape): + mask = None + + # chunk attention mask and look before and after + if attention_mask is not None: + attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) + attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + + # Causal mask + if self.is_decoder is True: + mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # Attention mask + if attention_mask is not None: + # create attn_mask + attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape) + # multiply by casaul mask if necessary + if mask is not None: + mask = mask * attn_mask + else: + mask = attn_mask + return mask + class ReformerSelfOutput(nn.Module): def __init__(self, config): super().__init__() all_head_size = config.num_attention_heads * config.attention_head_size - self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) self.dropout = config.hidden_dropout_prob + self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) @@ -926,15 +959,16 @@ class ReformerAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.layer_id = layer_id - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attn_layers = config.attn_layers + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": self.self_attention = LocalSelfAttention(config) elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]): - # get correct attn_layers + # get correct attn layers if self.attn_layers[self.layer_id] == "lsh": self.self_attention = LSHSelfAttention(config) else: @@ -983,13 +1017,15 @@ def forward( class ReformerFeedForwardDense(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) self.dropout = config.hidden_dropout_prob + if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act + self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) @@ -1000,9 +1036,10 @@ def forward(self, hidden_states): class ReformerFeedForwardOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) self.dropout = config.hidden_dropout_prob + self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) @@ -1012,8 +1049,9 @@ def forward(self, hidden_states): class ChunkReformerFeedForward(nn.Module): def __init__(self, config): super().__init__() - self.seq_len_dim = 1 self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) @@ -1033,13 +1071,22 @@ class ReformerLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.attention = ReformerAttention(config, layer_id) - self.feed_forward = ChunkReformerFeedForward(config) # dropout requires to have the same # seed for forward and backward pass self.attention_seed = None self.feed_forward_seed = None + self.feed_forward = ChunkReformerFeedForward(config) + def _init_attention_seed(self): + """ + This function sets a new seed for the + attention layer to make dropout deterministic + for both forward calls: 1 normal forward + call and 1 forward call in backward + to recalculate activations. + """ + # randomize seeds if next(self.parameters()).device.type == "cuda": # GPU @@ -1052,6 +1099,14 @@ def _init_attention_seed(self): torch.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): + """ + This function sets a new seed for the + feed forward layer to make dropout deterministic + for both forward calls: 1 normal forward + call and 1 forward call in backward + to recalculate activations. + """ + # randomize seeds if next(self.parameters()).device.type == "cuda": # GPU From 6c2be30c0291f0cd55cfaac1032e0b31176d8964 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 20:23:28 +0200 Subject: [PATCH 127/162] docstring --- src/transformers/modeling_reformer.py | 280 +++++++++++++++++++------- 1 file changed, 212 insertions(+), 68 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index b4be2409b907..f3fdf4bb9492 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -31,17 +31,15 @@ from .activations import gelu, gelu_fast, gelu_new, swish from .configuration_reformer import ReformerConfig +from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel logger = logging.getLogger(__name__) -REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {} -# TODO: fill with pretrained model weights -# "reformer-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-uncased-pytorch_model.bin", -# "reformer-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-uncased-pytorch_model.bin", -# "reformer-base-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-base-cased-pytorch_model.bin", -# "reformer-large-cased": "https://s3.amazonaws.com/models.huggingface.co/reformer/reformer-large-cased-pytorch_model.bin", +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin" +} def mish(x): @@ -420,8 +418,16 @@ def forward( ) value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) - assert query_key_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(query_key_vectors.shape[-1], self.attention_head_size) - assert value_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(value_vectors.shape[-1], self.attention_head_size) + assert ( + query_key_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + query_key_vectors.shape[-1], self.attention_head_size + ) + assert ( + value_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + value_vectors.shape[-1], self.attention_head_size + ) # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: @@ -432,7 +438,9 @@ def forward( # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors, num_hashes) - assert int(buckets.shape[-1]) == num_hashes * sequence_length, "lash dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) + assert ( + int(buckets.shape[-1]) == num_hashes * sequence_length + ), "lash dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sequence_length, buckets, num_hashes @@ -453,7 +461,9 @@ def forward( ) if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0, "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + assert ( + self.num_chunks_before == 0 and self.num_chunks_after == 0 + ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." # scale key vectors key_vectors = self._len_and_dim_norm(query_key_vectors) @@ -492,7 +502,12 @@ def forward( # free memory del logits - assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`." + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`." out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) @@ -612,7 +627,9 @@ def _attend( # free memory del query_vectors, key_vectors - query_bucket_idx = self._split_seq_length_dim_to(sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads) + query_bucket_idx = self._split_seq_length_dim_to( + sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads + ) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) # get correct mask values depending on precision @@ -641,7 +658,9 @@ def _attend( # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " - self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(query_bucket_idx.device) + self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( + query_bucket_idx.device + ) # apply self_mask query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value) @@ -680,9 +699,7 @@ def _compute_attn_mask(self, query_indices, key_indices, attention_mask): # Causal mask if self.is_decoder: - mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to( - query_indices.device - ) + mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) # Attention mask: chunk, look up correct mask value from key_value_bucket_idx # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. @@ -740,6 +757,7 @@ class ReverseSort(Function): the gradients of the output vectors have to be explicitely sorted here. """ + @staticmethod def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): # save sorted_bucket_idx for backprop @@ -829,12 +847,26 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) - assert query_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(query_vectors.shape[-1], self.attention_head_size) - assert key_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(key_vectors.shape[-1], self.attention_head_size) - assert value_vectors.shape[-1] == self.attention_head_size, "last dim of query_key_vectors is {} but should be {}.".format(value_vectors.shape[-1], self.attention_head_size) + assert ( + query_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + query_vectors.shape[-1], self.attention_head_size + ) + assert ( + key_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + key_vectors.shape[-1], self.attention_head_size + ) + assert ( + value_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + value_vectors.shape[-1], self.attention_head_size + ) if self.chunk_length is None: - assert self.num_chunks_before == 0 and self.num_chunks_after == 0, "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + assert ( + self.num_chunks_before == 0 and self.num_chunks_after == 0 + ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." # normalize key vectors key_vectors = key_vectors / torch.sqrt( @@ -1172,6 +1204,9 @@ def backward_pass( head_mask=None, buckets=None, ): + # Implements the backward pass for reversible ResNets. + # A good blog post on how this works can be found here: + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py with torch.enable_grad(): @@ -1179,7 +1214,6 @@ def backward_pass( # set seed to have correct dropout torch.manual_seed(self.feed_forward_seed) - # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # g(Y_1) res_hidden_states = self.feed_forward(next_attn_output) res_hidden_states.backward(grad_hidden_states, retain_graph=True) @@ -1223,6 +1257,10 @@ def backward_pass( class _ReversibleFunction(Function): """ + To prevent PyTorch from performing the usual backpropagation, + a customized backward function is implemented here. This way + it is made sure that no memory expensive activations are + saved during the forward pass. This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py """ @@ -1239,9 +1277,11 @@ def forward( do_output_hidden_states, do_output_attentions, ): - all_buckets = () + + # split duplicated tensor hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) + for layer, layer_head_mask in zip(layers, head_mask): if do_output_hidden_states is True: all_hidden_states.append(hidden_states) @@ -1325,11 +1365,12 @@ def backward(ctx, grad_hidden_states): class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() + self.dropout = config.hidden_dropout_prob + self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) - self.dropout = config.hidden_dropout_prob def forward( self, @@ -1344,7 +1385,7 @@ def forward( all_hidden_states = [] all_attentions = [] - # Make this work + # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) hidden_states = _ReversibleFunction.apply( hidden_states, @@ -1369,6 +1410,27 @@ def forward( ) +class ReformerOnlyLMHead(nn.Module): + def __init__(self, config): + super().__init__() + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.seq_len_dim = 1 + self.chunk_size_lm_head = config.chunk_size_lm_head + self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + class ReformerPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -1378,6 +1440,16 @@ class ReformerPreTrainedModel(PreTrainedModel): pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP base_model_prefix = "reformer" + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + } + return dummy_inputs + def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, AxialPositionEmbeddings): @@ -1392,35 +1464,79 @@ def _init_weights(self, module): if module.weight.requires_grad: module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # TODO(PVP): discuss with Thom if necessary here to use different init - # torch.nn.init.xavier_uniform_(module.weight) + # torch.nn.init.xavier_uniform_(module.weight) elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() # TODO(PVP): discuss with Thom if necessary here to use different init + # module.bias.data.normal_(mean=0.0, std=1e-6) -# module.bias.data.normal_(mean=0.0, std=1e-6) - - -class ReformerModel(ReformerPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. +REFORMER_START_DOCSTRING = r""" + Reformer was proposed in + `Reformer: The Efficient Transformer`_ + by Nikita Kitaev, Ɓukasz Kaiser, Anselm Levskaya. - To behave as an decoder the model needs to be initialized with the - :obj:`is_decoder` argument of the configuration set to :obj:`True`; an - :obj:`encoder_hidden_states` is expected as an input to the forward pass. + .. _`Reformer: The Efficient Transformer`: + https://arxiv.org/abs/2001.04451 - .. _`Attention is all you need`: - https://arxiv.org/abs/1706.03762 + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. - """ + Parameters: + config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" +REFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + During training the input_ids sequence_length has to be a multiple of the relevant model's + chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically + padded to be a multiple of the chunk length. + + Indices can be obtained using :class:`transformers.ReformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + num_hashes (:obj:`int`, `optional`, defaults to :obj:`None`): + `num_hashes` is the number of hashing rounds that should be performed during + bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined + in `config.num_hashes`. + For more information, see `num_hashes` in :class:`transformers.ReformerConfig`. +""" + + +@add_start_docstrings( + "The bare Reformer Model transformer outputting raw hidden-states" "without any specific head on top.", + REFORMER_START_DOCSTRING, +) +class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config @@ -1447,6 +1563,7 @@ def _prune_heads(self, heads_to_prune): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) def forward( self, input_ids=None, @@ -1463,12 +1580,12 @@ def forward( :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): + all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. @@ -1477,7 +1594,18 @@ def forward( Examples:: + from transformers import ReformerModel, ReformerTokenizer + import torch + + tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased') + model = ReformerModel.from_pretrained('bert-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple """ + # TODO(PVP): delete when PR to change output_attentions is made do_output_attentions = self.config.output_attentions do_output_hidden_states = self.config.output_hidden_states @@ -1535,7 +1663,6 @@ def forward( sequence_output = sequence_output[:, :orig_sequence_length] outputs = (sequence_output,) - # TODO(PVP): Replace by named tuple after namedtuples are introduced in the library. if do_output_hidden_states is True: outputs = outputs + (encoder_outputs.all_hidden_states,) @@ -1597,35 +1724,13 @@ def _pad_to_mult_of_chunk_length( padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) input_shape = inputs_embeds.size() - return input_ids, inputs_embeds, attention_mask, position_ids, input_shape -class ReformerOnlyLMHead(nn.Module): - def __init__(self, config): - super().__init__() - # Reformer is using Rev Nets, thus last layer outputs are concatenated and - # Layer Norm is done over 2 * hidden_size - self.seq_len_dim = 1 - self.chunk_size_lm_head = config.chunk_size_lm_head - self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) - - def forward_chunk(self, hidden_states): - hidden_states = self.decoder(hidden_states) - return hidden_states - - +@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING) class ReformerModelWithLMHead(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) - self.reformer = ReformerModel(config) self.lm_head = ReformerOnlyLMHead(config) @@ -1635,8 +1740,10 @@ def get_output_embeddings(self): return self.lm_head.decoder def tie_weights(self): + # word embeddings are not tied in Reformer pass + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) def forward( self, input_ids=None, @@ -1649,6 +1756,44 @@ def forward( do_output_hidden_states=False, do_output_attentions=False, ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`. + All labels set to ``-100`` are ignored (masked), the loss is only + computed for labels in ``[0, ..., config.vocab_size]`` + + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided): + Classification loss (cross entropy). + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import ReformerModel, ReformerTokenizer + import torch + + tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased') + model = ReformerModel.from_pretrained('bert-base-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=input_ids) + + loss, prediction_scores = outputs[:2] + """ reformer_outputs = self.reformer( input_ids, @@ -1673,7 +1818,6 @@ def forward( loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) outputs = (loss,) + outputs - return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) def prepare_inputs_for_generation(self, input_ids, past, **kwargs): From 3d266fb9332728431def576bdb52428b12da25d1 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 20:27:48 +0200 Subject: [PATCH 128/162] remove sinusoidal position encodings for reformer --- src/transformers/modeling_reformer.py | 51 ++++++++------------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index f3fdf4bb9492..006dad1b136b 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -68,14 +68,6 @@ def mish(x): ReformerEncoderOutput = namedtuple("ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions"]) -def create_sinusoidal_embeddings(n_pos, dim, out): - position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) - out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) - out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - out.requires_grad = False - - def _get_least_common_mult_chunk_len(config): attn_types = config.attn_layers if len(set(attn_types)) == 1 and attn_types[0] == "lsh": @@ -227,18 +219,11 @@ class PositionEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.dropout = config.hidden_dropout_prob - self.is_sinusoidal_pos_embds = config.sinusoidal_pos_embds - self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) - if config.sinusoidal_pos_embds is True: - create_sinusoidal_embeddings( - n_pos=config.max_position_embeddings, dim=config.hidden_size, out=self.embedding.weight, - ) def forward(self, position_ids): position_embeddings = self.embedding(position_ids) - if self.is_sinusoidal_pos_embds is False: - position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) + position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) return position_embeddings @@ -255,9 +240,6 @@ def __init__(self, config): self.position_embeddings = ( AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) ) - assert not ( - config.sinusoidal_pos_embds and config.axial_pos_embds - ), "Select either config.sinusoidal_pos_embds or config.axial_pos_embds" def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): if input_ids is not None: @@ -1112,10 +1094,10 @@ def __init__(self, config, layer_id=0): def _init_attention_seed(self): """ - This function sets a new seed for the + This function sets a new seed for the attention layer to make dropout deterministic - for both forward calls: 1 normal forward - call and 1 forward call in backward + for both forward calls: 1 normal forward + call and 1 forward call in backward to recalculate activations. """ @@ -1132,10 +1114,10 @@ def _init_attention_seed(self): def _init_feed_forward_seed(self): """ - This function sets a new seed for the + This function sets a new seed for the feed forward layer to make dropout deterministic - for both forward calls: 1 normal forward - call and 1 forward call in backward + for both forward calls: 1 normal forward + call and 1 forward call in backward to recalculate activations. """ @@ -1257,9 +1239,9 @@ def backward_pass( class _ReversibleFunction(Function): """ - To prevent PyTorch from performing the usual backpropagation, - a customized backward function is implemented here. This way - it is made sure that no memory expensive activations are + To prevent PyTorch from performing the usual backpropagation, + a customized backward function is implemented here. This way + it is made sure that no memory expensive activations are saved during the forward pass. This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py """ @@ -1463,15 +1445,12 @@ def _init_weights(self, module): # cf https://github.com/pytorch/pytorch/pull/5617 if module.weight.requires_grad: module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - # TODO(PVP): discuss with Thom if necessary here to use different init - # torch.nn.init.xavier_uniform_(module.weight) + elif isinstance(module, ReformerLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - # TODO(PVP): discuss with Thom if necessary here to use different init - # module.bias.data.normal_(mean=0.0, std=1e-6) REFORMER_START_DOCSTRING = r""" @@ -1496,7 +1475,7 @@ def _init_weights(self, module): Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. - During training the input_ids sequence_length has to be a multiple of the relevant model's + During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. @@ -1525,9 +1504,9 @@ def _init_weights(self, module): This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. num_hashes (:obj:`int`, `optional`, defaults to :obj:`None`): - `num_hashes` is the number of hashing rounds that should be performed during - bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined - in `config.num_hashes`. + `num_hashes` is the number of hashing rounds that should be performed during + bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined + in `config.num_hashes`. For more information, see `num_hashes` in :class:`transformers.ReformerConfig`. """ From 1be343f627f824d19e428dbbcbb3349a641f98ae Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 20:32:40 +0200 Subject: [PATCH 129/162] move chunking to modeling_utils --- src/transformers/modeling_reformer.py | 55 +---------------------- src/transformers/modeling_utils.py | 52 +++++++++++++++++++++ src/transformers/tokenization_reformer.py | 6 +-- 3 files changed, 54 insertions(+), 59 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 006dad1b136b..6c795582ae55 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -15,13 +15,11 @@ # limitations under the License. """PyTorch REFORMER model. """ -import inspect import logging import sys from collections import namedtuple from functools import reduce from operator import mul -from typing import Callable import numpy as np import torch @@ -32,7 +30,7 @@ from .activations import gelu, gelu_fast, gelu_new, swish from .configuration_reformer import ReformerConfig from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_utils import PreTrainedModel +from .modeling_utils import PreTrainedModel, apply_chunking_to_forward logger = logging.getLogger(__name__) @@ -84,57 +82,6 @@ def _get_least_common_mult_chunk_len(config): ) -def apply_chunking_to_forward( - chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors -) -> torch.Tensor: - """ - This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. - It then applies a layer `forward_fn` to each chunk independently to save memory. - If the `forward_fn` is independent across the `chunk_dim` this function will yield the - same result as not applying it. - - Args: - chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size` - chunk_dim: int - the dimension over which the input_tensors should be chunked - forward_fn: fn - the forward fn of the model - input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked - Returns: - a Tensor with the same shape the foward_fn would have given if applied - """ - - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) - tensor_shape = input_tensors[0].shape - assert all( - input_tensor.shape == tensor_shape for input_tensor in input_tensors - ), "All input tenors have to be of the same shape" - - # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability - num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) - assert num_args_in_forward_chunk_fn == len( - input_tensors - ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( - num_args_in_forward_chunk_fn, len(input_tensors) - ) - - if chunk_size > 0: - assert ( - input_tensors[0].shape[chunk_dim] % chunk_size == 0 - ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( - input_tensors[0][chunk_dim], chunk_size - ) - - num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size - - # chunk input tensor into tuples - input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) - # apply forward fn to every tuple - output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) - # concatenate output at same dimension - return torch.cat(output_chunks, dim=chunk_dim) - - return forward_fn(*input_tensors) - - class AxialPositionEmbeddings(nn.Module): """Constructs axial position embeddings. Useful for very long input sequences to stay memory and computational efficient diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7c260791b2ca..3a6a105bff2c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -15,6 +15,7 @@ # limitations under the License. """PyTorch BERT model.""" +import inspect import logging import os from typing import Callable, Tuple @@ -2098,3 +2099,54 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) + + +def apply_chunking_to_forward( + chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors +) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. + It then applies a layer `forward_fn` to each chunk independently to save memory. + If the `forward_fn` is independent across the `chunk_dim` this function will yield the + same result as not applying it. + + Args: + chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size` + chunk_dim: int - the dimension over which the input_tensors should be chunked + forward_fn: fn - the forward fn of the model + input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked + Returns: + a Tensor with the same shape the foward_fn would have given if applied + """ + + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) + tensor_shape = input_tensors[0].shape + assert all( + input_tensor.shape == tensor_shape for input_tensor in input_tensors + ), "All input tenors have to be of the same shape" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + assert num_args_in_forward_chunk_fn == len( + input_tensors + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( + num_args_in_forward_chunk_fn, len(input_tensors) + ) + + if chunk_size > 0: + assert ( + input_tensors[0].shape[chunk_dim] % chunk_size == 0 + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( + input_tensors[0][chunk_dim], chunk_size + ) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index 3f4e0ee85e78..a68b2e556790 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 Reformer Authors and HuggingFace Inc. team. +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,10 +27,6 @@ SPIECE_UNDERLINE = "▁" -# IMPORTANT: This is just a copy-paste from T5 and not compared to the original Reformer -# Tokenizer yet!!!! - - #################################################### # Mapping from the keyword arguments names of Tokenizer `__init__` # to file names for serializing Tokenizer instances From a10eb2e2f2eb35e2eaa060db87c8165eb7c13105 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 20:39:06 +0200 Subject: [PATCH 130/162] make style --- src/transformers/configuration_reformer.py | 7 +++---- tests/test_modeling_reformer.py | 20 +++++++++----------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 3034a158c4de..caef839df051 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -30,8 +30,7 @@ class ReformerConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a :class:`~transformers.ReformerModel`. It is used to instantiate an Reformer model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the Reformer `bert-base-uncased `__ architecture. + architecture. Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` @@ -97,7 +96,7 @@ class ReformerConfig(PretrainedConfig): def __init__( self, - vocab_size=10, + vocab_size=320, attention_head_size=32, hidden_size=64, num_attention_heads=1, @@ -124,7 +123,7 @@ def __init__( axial_pos_embds=False, axial_pos_shape=[32, 16], axial_pos_embds_dim=[32, 32], - attn_layers=["lsh", "lsh", "lsh", "lsh"], + attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], is_decoder=False, pad_token_id=0, eos_token_id=2, diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 212cdcb4f258..5f7b37aebf7c 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -19,7 +19,7 @@ from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor -from .utils import CACHE_DIR, require_torch, slow, torch_device +from .utils import require_torch, slow, torch_device if is_torch_available(): @@ -215,12 +215,12 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu # normal padded attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) @@ -521,12 +521,12 @@ def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, inpu # normal padded attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, ) # shifted padded input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size),], dim=-1, + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, ) input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) @@ -705,7 +705,7 @@ def test_reformer_model_fp16_generate(self): @slow def test_model_from_pretrained(self): for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = ReformerModelWithLMHead.from_pretrained(model_name, cache_dir=CACHE_DIR).to(torch_device) + model = ReformerModelWithLMHead.from_pretrained(model_name) self.assertIsNotNone(model) @@ -837,8 +837,8 @@ def _get_attn_mask(self): def _get_input_ids_and_mask(self): mask = torch.tensor( [ - [1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1,], - [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,], + [1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0], ], dtype=torch.long, device=torch_device, @@ -1096,9 +1096,7 @@ def test_lsh_lm_model_grad(self): @slow def test_pretrained_generate_crime_and_punish(self): - model = ReformerModelWithLMHead.from_pretrained( - "patrickvonplaten/reformer-crime-and-punish", cache_dir=CACHE_DIR - ).to(torch_device) + model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish").to(torch_device) tokenizer = ReformerTokenizer.from_pretrained("patrickvonplaten/reformer-crime-and-punish") model.eval() From f31b57018c560c5a4e0730d41f55c371be50a4d5 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 21:02:11 +0200 Subject: [PATCH 131/162] clean config --- src/transformers/configuration_reformer.py | 118 ++++++++++++--------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index caef839df051..b473d16e4ad9 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -36,27 +36,18 @@ class ReformerConfig(PretrainedConfig): to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. - Args: - vocab_size (:obj:`int`, optional, defaults to 30522): - Vocabulary size of the Reformer model. Defines the different tokens that - can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`. - hidden_size (:obj:`int`, optional, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. - num_hidden_layers (:obj:`int`, optional, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (:obj:`int`, optional, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. - num_buckets (:obj:`int`, optional, defaults to ): - TODO (PVP) - num_hashes (:obj:`int`, optional, defaults to ): - TODO (PVP) - chunk_length (:obj:`int`, optional, defaults to ): + attention_head_size=64, + attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], + axial_pos_embds=True, + axial_norm_std=1.0, + axial_pos_shape=[512, 1024], + axial_pos_embds_dim=[512, 1024], TODO (PVP) - num_chunks_before (:obj:`int`, optional, defaults to ): TODO (PVP) - num_chunks_after (:obj:`int`, optional, defaults to ): + chunk_size_lm_head=0, TODO (PVP) + chunk_size_feed_forward=0, feed_forward_size (:obj:`int`, optional, defaults to 3072): Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the Transformer encoder. hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): @@ -64,24 +55,49 @@ class ReformerConfig(PretrainedConfig): If string, "gelu", "relu", "swish" and "gelu_new" are supported. hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): - The dropout ratio for the attention probabilities. - max_position_embeddings (:obj:`int`, optional, defaults to 512): - The maximum sequence length that this model might ever be used with. - Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + hidden_size (:obj:`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. initializer_range (:obj:`float`, optional, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. + local_chunk_length (:obj:`int`, optional, defaults to ): + TODO (PVP) + local_num_chunks_before (:obj:`int`, optional, defaults to ): + TODO (PVP) + local_num_chunks_after (:obj:`int`, optional, defaults to ): + TODO (PVP) + local_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + lsh_chunk_length (:obj:`int`, optional, defaults to ): + TODO (PVP) + lsh_num_chunks_before (:obj:`int`, optional, defaults to ): + TODO (PVP) + lsh_num_chunks_after (:obj:`int`, optional, defaults to ): + TODO (PVP) + lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + num_attention_heads (:obj:`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_buckets (:obj:`int`, optional, defaults to ): + TODO (PVP) + num_hashes (:obj:`int`, optional, defaults to ): + TODO (PVP) + vocab_size (:obj:`int`, optional, defaults to 30522): + Vocabulary size of the Reformer model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`. Example:: from transformers import ReformerModel, ReformerConfig - # Initializing a Reformer bert-base-uncased style configuration + # Initializing a Reformer configuration configuration = ReformerConfig() - # Initializing a model from the bert-base-uncased style configuration + # Initializing a Reformer model model = ReformerModel(configuration) # Accessing the model configuration @@ -96,38 +112,37 @@ class ReformerConfig(PretrainedConfig): def __init__( self, - vocab_size=320, - attention_head_size=32, - hidden_size=64, - num_attention_heads=1, - num_buckets=[2, 4], - num_hashes=4, - lsh_attn_chunk_length=64, - local_attn_chunk_length=64, - lsh_num_chunks_before=1, - lsh_num_chunks_after=0, - local_num_chunks_before=1, - local_num_chunks_after=0, + attention_head_size=64, + attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], + axial_norm_std=1.0, + axial_pos_embds=True, + axial_pos_shape=[512, 1024], + axial_pos_embds_dim=[512, 1024], chunk_size_lm_head=0, chunk_size_feed_forward=0, - feed_forward_size=128, - hidden_act="gelu", - hidden_dropout_prob=0.0, - lsh_attention_probs_dropout_prob=0.0, - local_attention_probs_dropout_prob=0.0, - max_position_embeddings=512, + eos_token_id=2, + feed_forward_size=512, + hash_seed=None, + hidden_act="relu", + hidden_dropout_prob=0.05, + hidden_size=256, initializer_range=0.02, - axial_norm_std=1.0, + is_decoder=True, layer_norm_eps=1e-12, - sinusoidal_pos_embds=False, - axial_pos_embds=False, - axial_pos_shape=[32, 16], - axial_pos_embds_dim=[32, 32], - attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], - is_decoder=False, + local_num_chunks_before=1, + local_num_chunks_after=0, + local_attention_probs_dropout_prob=0.05, + local_attn_chunk_length=64, + lsh_attn_chunk_length=64, + lsh_attention_probs_dropout_prob=0.0, + lsh_num_chunks_before=1, + lsh_num_chunks_after=0, + max_position_embeddings=524288, + num_attention_heads=2, + num_buckets=[64, 128], + num_hashes=1, pad_token_id=0, - eos_token_id=2, - hash_seed=None, + vocab_size=320, **kwargs ): super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs) @@ -154,7 +169,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.sinusoidal_pos_embds = sinusoidal_pos_embds self.axial_pos_embds = axial_pos_embds self.axial_pos_shape = tuple(axial_pos_shape) self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) From b2a660f1eb2c49c1d25e47ff0b92ce85d742bc03 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 1 May 2020 23:56:13 +0200 Subject: [PATCH 132/162] make style --- src/transformers/configuration_reformer.py | 110 +++++++++++++-------- 1 file changed, 71 insertions(+), 39 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index b473d16e4ad9..4d62747c3527 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -37,56 +37,88 @@ class ReformerConfig(PretrainedConfig): for more information. Args: - attention_head_size=64, - attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], - axial_pos_embds=True, - axial_norm_std=1.0, - axial_pos_shape=[512, 1024], - axial_pos_embds_dim=[512, 1024], - TODO (PVP) - TODO (PVP) - chunk_size_lm_head=0, - TODO (PVP) - chunk_size_feed_forward=0, - feed_forward_size (:obj:`int`, optional, defaults to 3072): - Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the Transformer encoder. - hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): - The non-linear activation function (function or string) in the encoder and pooler. - If string, "gelu", "relu", "swish" and "gelu_new" are supported. - hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): + attention_head_size (:obj:`int`, optional, defaults to 64): + Dimensionality of the projected key, query and value vectors + attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]): + List of attention layer types in ascending order. It can be chosen between a + LSHSelfAttention layer ("lsh") and a LocalSelfAttention layer ("local"). + For more information on LSHSelfAttention layer, see `Local Sensitive Hashing Self Attention `__ . + For more information on LocalSelfAttention layer, see `Local Self Attention `__ . + axial_pos_embds (:obj:`bool`, optional, defaults to True): + If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings `__ + axial_norm_std (:obj:`float`, optional, defaluts to 1.0): + The standard deviation of the normal_initializer for initializing the weight matrices of the axial positional encodings. + axial_pos_shape (:obj:`list(int)`, optional, defaults to [512, 1024]): + The position dims of the axial position encodings. + During training the product of the position dims has to equal the sequence length. + For more information on how axial position embeddings work, see `Axial Position Encodings `__ncodings. + axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to [512, 1024]): + The embedding dims of the axial position encodings. + The sum of the embedding dims has to equal the hidden size. + + For more information on how axial position embeddings work, see `Axial Position Encodings `__ncodings. + chunk_size_lm_head (:obj:`int`, optional, defaults to 0): + The chunk size of the final language model feed forward head layer. + A chunk size of 0 means that the feed forward layer is not chunked. + A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. + For more information on feed forward chunking, see `How does Feed Forward Chunking work <../glossary.html#feed-forward-chunking>`__ . + chunk_size_feed_forward (:obj:`int`, optional, defaults to 0): + The chunk size of all feed forward layers in the residual attention blocks. + A chunk size of 0 means that the feed forward layer is not chunked. + A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. + + For more information on feed forward chunking, see `How does Feed Forward Chunking work <../glossary.html#feed-forward-chunking>`__ . + eos_token_id (:obj: `int`, optional, defaults to 2): + The token id for the token. + feed_forward_size (:obj:`int`, optional, defaults to 512): + Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the residual attention block. + hash_seed (:obj: `int`, optional, defaults to `None`): + Seed that can be used to make local sensitive hashing in LSHSelfAttention deterministic. This should only be set for testing purposed. For evaluation and training purposes `hash_seed` should be set to `None` to ensure fully random rotations in local sensitive hashing scheme. + hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"): + The non-linear activation function (function or string) in the feed forward layer in the residual attention block. + If string, "gelu", "relu", "swish", "gelu_new" and "gelu_fast" are supported. + hidden_dropout_prob (:obj:`float`, optional, defaults to 0.05): The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. - hidden_size (:obj:`int`, optional, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. + hidden_size (:obj:`int`, optional, defaults to 256): + Dimensionality of the output hidden states of the residual attention blocks. initializer_range (:obj:`float`, optional, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + is_decoder (:obj:`bool`, optional, defaults to True): + If `is_decoder` is True, a causal mask is used in addition to `attention_mask`. + When using the Reformer for casaul language modeling, `is_decoder` is set to `True`. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. - local_chunk_length (:obj:`int`, optional, defaults to ): - TODO (PVP) - local_num_chunks_before (:obj:`int`, optional, defaults to ): - TODO (PVP) - local_num_chunks_after (:obj:`int`, optional, defaults to ): - TODO (PVP) + local_chunk_length (:obj:`int`, optional, defaults to 64): + Length of chunk which attends to itself in LocalSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention). + local_num_chunks_before (:obj:`int`, optional, defaults to 1): + Number of previous neighbouring chunks to attend to in LocalSelfAttention layer to itself. + local_num_chunks_after (:obj:`int`, optional, defaults to 0): + Number of following neighbouring chunks to attend to in LocalSelfAttention layer in addition to itself. local_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): - The dropout ratio for the attention probabilities. - lsh_chunk_length (:obj:`int`, optional, defaults to ): - TODO (PVP) - lsh_num_chunks_before (:obj:`int`, optional, defaults to ): - TODO (PVP) - lsh_num_chunks_after (:obj:`int`, optional, defaults to ): - TODO (PVP) + The dropout ratio for the attention probabilities in LocalSelfAttention. + lsh_chunk_length (:obj:`int`, optional, defaults to 64): + Length of chunk which attends to itself in LSHSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention). + lsh_num_chunks_before (:obj:`int`, optional, defaults to 1): + Number of previous neighbouring chunks to attend to in LSHSelfAttention layer to itself. + lsh_num_chunks_after (:obj:`int`, optional, defaults to 0): + Number of following neighbouring chunks to attend to in LSHSelfAttention layer to itself. lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): - The dropout ratio for the attention probabilities. + The dropout ratio for the attention probabilities in LSHSelfAttention. max_position_embeddings (:obj:`int`, optional, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). num_attention_heads (:obj:`int`, optional, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. - num_buckets (:obj:`int`, optional, defaults to ): - TODO (PVP) - num_hashes (:obj:`int`, optional, defaults to ): - TODO (PVP) - vocab_size (:obj:`int`, optional, defaults to 30522): + num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `[64, 128]`): + Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`. + The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors. + The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length. + num_hashes (:obj:`int`, optional, defaults to 1): + Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme. + The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes. + pad_token_id (:obj: `int`, optional, defaults to 0): + The token id for the token. + vocab_size (:obj:`int`, optional, defaults to 320): Vocabulary size of the Reformer model. Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`. @@ -117,7 +149,7 @@ def __init__( axial_norm_std=1.0, axial_pos_embds=True, axial_pos_shape=[512, 1024], - axial_pos_embds_dim=[512, 1024], + axial_pos_embds_dim=[64, 192], chunk_size_lm_head=0, chunk_size_feed_forward=0, eos_token_id=2, From dfc1f64994df7542a99f62d06686e7e8661afc9a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 00:33:04 +0200 Subject: [PATCH 133/162] fix tests --- src/transformers/configuration_reformer.py | 22 ++++++++++++---------- src/transformers/tokenization_reformer.py | 10 ++++++++-- tests/test_modeling_common.py | 2 +- tests/test_modeling_reformer.py | 12 +++++++++++- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 4d62747c3527..83d544d47531 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -23,7 +23,9 @@ logger = logging.getLogger(__name__) -REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json" +} class ReformerConfig(PretrainedConfig): @@ -48,11 +50,11 @@ class ReformerConfig(PretrainedConfig): If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings `__ axial_norm_std (:obj:`float`, optional, defaluts to 1.0): The standard deviation of the normal_initializer for initializing the weight matrices of the axial positional encodings. - axial_pos_shape (:obj:`list(int)`, optional, defaults to [512, 1024]): + axial_pos_shape (:obj:`list(int)`, optional, defaults to `[64, 64]`): The position dims of the axial position encodings. During training the product of the position dims has to equal the sequence length. For more information on how axial position embeddings work, see `Axial Position Encodings `__ncodings. - axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to [512, 1024]): + axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to `[64, 192]`): The embedding dims of the axial position encodings. The sum of the embedding dims has to equal the hidden size. @@ -83,7 +85,7 @@ class ReformerConfig(PretrainedConfig): Dimensionality of the output hidden states of the residual attention blocks. initializer_range (:obj:`float`, optional, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - is_decoder (:obj:`bool`, optional, defaults to True): + is_decoder (:obj:`bool`, optional, defaults to False): If `is_decoder` is True, a causal mask is used in addition to `attention_mask`. When using the Reformer for casaul language modeling, `is_decoder` is set to `True`. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): @@ -104,12 +106,12 @@ class ReformerConfig(PretrainedConfig): Number of following neighbouring chunks to attend to in LSHSelfAttention layer to itself. lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): The dropout ratio for the attention probabilities in LSHSelfAttention. - max_position_embeddings (:obj:`int`, optional, defaults to 512): + max_position_embeddings (:obj:`int`, optional, defaults to 4096): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). num_attention_heads (:obj:`int`, optional, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. - num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `[64, 128]`): + num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`): Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`. The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors. The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length. @@ -148,7 +150,7 @@ def __init__( attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], axial_norm_std=1.0, axial_pos_embds=True, - axial_pos_shape=[512, 1024], + axial_pos_shape=[64, 64], axial_pos_embds_dim=[64, 192], chunk_size_lm_head=0, chunk_size_feed_forward=0, @@ -159,7 +161,7 @@ def __init__( hidden_dropout_prob=0.05, hidden_size=256, initializer_range=0.02, - is_decoder=True, + is_decoder=False, layer_norm_eps=1e-12, local_num_chunks_before=1, local_num_chunks_after=0, @@ -169,9 +171,9 @@ def __init__( lsh_attention_probs_dropout_prob=0.0, lsh_num_chunks_before=1, lsh_num_chunks_after=0, - max_position_embeddings=524288, + max_position_embeddings=4096, num_attention_heads=2, - num_buckets=[64, 128], + num_buckets=32, num_hashes=1, pad_token_id=0, vocab_size=320, diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py index a68b2e556790..4accdcc3cfbf 100644 --- a/src/transformers/tokenization_reformer.py +++ b/src/transformers/tokenization_reformer.py @@ -37,12 +37,18 @@ # Mapping from the keyword arguments names of Tokenizer `__init__` # to pretrained vocabulary URL for all the model shortcut names. #################################################### -PRETRAINED_VOCAB_FILES_MAP = {"vocab_file": {}} +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model" + } +} #################################################### # Mapping from model shortcut names to max length of inputs #################################################### -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "google/reformer-crime-and-punishment": 524288, +} class ReformerTokenizer(PreTrainedTokenizer): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0342d22119ca..b0b3603af6c1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -489,7 +489,7 @@ def test_hidden_states_output(self): seq_length = self.model_tester.seq_length self.assertListEqual( - list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size,], + list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size], ) def test_resize_tokens_embeddings(self): diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 5f7b37aebf7c..ef49d13e4862 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -74,6 +74,7 @@ def __init__( pad_token_id=0, eos_token_id=2, scope=None, + hash_seed=0, ): self.parent = parent self.batch_size = batch_size @@ -105,6 +106,7 @@ def __init__( self.scope = scope self.attn_layers = attn_layers self.pad_token_id = pad_token_id + self.hash_seed = hash_seed self.encoder_seq_length = seq_length // local_attn_chunk_length + ( self.seq_length % local_attn_chunk_length != 0 @@ -140,6 +142,7 @@ def prepare_config_and_inputs(self): local_num_chunks_before=self.local_num_chunks_before, attn_layers=self.attn_layers, pad_token_id=self.pad_token_id, + hash_seed=self.hash_seed ) return ( @@ -371,9 +374,11 @@ def __init__( lsh_attention_probs_dropout_prob=0.1, max_position_embeddings=512, initializer_range=0.02, + axial_norm_std=1.0, layer_norm_eps=1e-12, - sinusoidal_pos_embds=True, axial_pos_embds=True, + axial_pos_shape=[2, 8], + axial_pos_embds_dim=[16, 48], attn_layers=["lsh", "lsh", "lsh", "lsh"], pad_token_id=0, eos_token_id=2, @@ -410,6 +415,9 @@ def __init__( self.hash_seed = hash_seed self.pad_token_id = pad_token_id self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length @@ -434,6 +442,8 @@ def prepare_config_and_inputs(self): max_position_embeddings=self.max_position_embeddings, is_decoder=self.is_decoder, axial_pos_embds=self.axial_pos_embds, + axial_pos_shape=self.axial_pos_shape, + axial_pos_embds_dim=self.axial_pos_embds_dim, num_hashes=self.num_hashes, num_buckets=self.num_buckets, lsh_attn_chunk_length=self.lsh_attn_chunk_length, From 2e95c17fbd5d515abc550ce1a0d390096ca4eb0e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 00:46:30 +0200 Subject: [PATCH 134/162] fix auto tests --- docs/source/glossary.rst | 5 ++ docs/source/index.rst | 1 + docs/source/model_doc/reformer.rst | 60 ++++++++++++++++++++++ src/transformers/configuration_reformer.py | 4 +- src/transformers/tokenization_auto.py | 3 ++ tests/test_modeling_reformer.py | 2 +- 6 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 docs/source/model_doc/reformer.rst diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index cfd8c50dd6bd..cb7d1f1f438c 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -143,3 +143,8 @@ positional embeddings. Absolute positional embeddings are selected in the range ``[0, config.max_position_embeddings - 1]``. Some models use other types of positional embeddings, such as sinusoidal position embeddings or relative position embeddings. + + +Feed Forward Chunking +-------------------------- +TODO diff --git a/docs/source/index.rst b/docs/source/index.rst index 32c71fba9a58..9d9005c1a8e2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -107,3 +107,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train model_doc/t5 model_doc/electra model_doc/dialogpt + model_doc/reformer diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst new file mode 100644 index 000000000000..f0294b376a53 --- /dev/null +++ b/docs/source/model_doc/reformer.rst @@ -0,0 +1,60 @@ +Reformer +---------------------------------------------------- +**DISCLAIMER:** This model is still a work in progress, if you see something strange, +file a `Github Issue `_ + +Overview +~~~~~ +The Reformer model was presented in `Reformer: The Efficient Transformer `_ by Nikita Kitaev, Ɓukasz Kaiser, Anselm Levskaya. +Here the abstract: + +*Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.* + +The Authors' code can be found `here https://github.com/google/trax/tree/master/trax/models/reformer>`_ . + +Axial Positional Encodings +~~~~~~~~~~~~~~~~~~~~ +TODO + +LSH Sensitive Hashing Self Attention +~~~~~~~~~~~~~~~~~~~~ +TODO + +Local Sensitive Hashing Self Attention +~~~~~~~~~~~~~~~~~~~~ +TODO + +Training +~~~~~~~~~~~~~~~~~~~~ +TODO + +Tips +~~~~~~~~~~~~~~~~~~~~ +TODO + +ReformerConfig +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ReformerConfig + :members: + + +ReformerTokenizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ReformerTokenizer + :members: + + +ReformerModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ReformerModel + :members: + + +ReformerModelWithLMHead +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ReformerModelWithLMHead + :members: diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 83d544d47531..f0055a81d818 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -44,7 +44,7 @@ class ReformerConfig(PretrainedConfig): attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]): List of attention layer types in ascending order. It can be chosen between a LSHSelfAttention layer ("lsh") and a LocalSelfAttention layer ("local"). - For more information on LSHSelfAttention layer, see `Local Sensitive Hashing Self Attention `__ . + For more information on LSHSelfAttention layer, see `LSH Sensitive Hashing Self Attention `__ . For more information on LocalSelfAttention layer, see `Local Self Attention `__ . axial_pos_embds (:obj:`bool`, optional, defaults to True): If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings `__ @@ -63,7 +63,7 @@ class ReformerConfig(PretrainedConfig): The chunk size of the final language model feed forward head layer. A chunk size of 0 means that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. - For more information on feed forward chunking, see `How does Feed Forward Chunking work <../glossary.html#feed-forward-chunking>`__ . + For more information on feed forward chunking, see `How does Feed Forward Chunking work ? <../glossary.html#feed-forward-chunking>`__ . chunk_size_feed_forward (:obj:`int`, optional, defaults to 0): The chunk size of all feed forward layers in the residual attention blocks. A chunk size of 0 means that the feed forward layer is not chunked. diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index f26e4fa0fdf8..9982e31cea9d 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -30,6 +30,7 @@ FlaubertConfig, GPT2Config, OpenAIGPTConfig, + ReformerConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -49,6 +50,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast +from .tokenization_reformer import ReformerTokenizer from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast @@ -69,6 +71,7 @@ (XLMRobertaConfig, (XLMRobertaTokenizer, None)), (BartConfig, (BartTokenizer, None)), (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), + (ReformerConfig, (ReformerTokenizer, None)), (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)), (BertConfig, (BertTokenizer, BertTokenizerFast)), (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)), diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index ef49d13e4862..58f88fff019b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -142,7 +142,7 @@ def prepare_config_and_inputs(self): local_num_chunks_before=self.local_num_chunks_before, attn_layers=self.attn_layers, pad_token_id=self.pad_token_id, - hash_seed=self.hash_seed + hash_seed=self.hash_seed, ) return ( From b95f6ae82fc0a5fb6e28819d3fd23baa13b15b76 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 00:49:28 +0200 Subject: [PATCH 135/162] pretrained models --- docs/source/pretrained_models.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 35b92ad219e7..aae5024654d6 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,3 +296,5 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From a6f69cb9c5a27c3f6d2c5aa2d8361e189d4798c2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:16:25 +0200 Subject: [PATCH 136/162] fix docstring --- src/transformers/configuration_reformer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index f0055a81d818..de6225269a97 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -57,24 +57,22 @@ class ReformerConfig(PretrainedConfig): axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to `[64, 192]`): The embedding dims of the axial position encodings. The sum of the embedding dims has to equal the hidden size. - For more information on how axial position embeddings work, see `Axial Position Encodings `__ncodings. chunk_size_lm_head (:obj:`int`, optional, defaults to 0): The chunk size of the final language model feed forward head layer. A chunk size of 0 means that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. - For more information on feed forward chunking, see `How does Feed Forward Chunking work ? <../glossary.html#feed-forward-chunking>`__ . + For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ . chunk_size_feed_forward (:obj:`int`, optional, defaults to 0): The chunk size of all feed forward layers in the residual attention blocks. A chunk size of 0 means that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time. - - For more information on feed forward chunking, see `How does Feed Forward Chunking work <../glossary.html#feed-forward-chunking>`__ . - eos_token_id (:obj: `int`, optional, defaults to 2): + For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ . + eos_token_id (:obj:`int`, optional, defaults to 2): The token id for the token. feed_forward_size (:obj:`int`, optional, defaults to 512): Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the residual attention block. - hash_seed (:obj: `int`, optional, defaults to `None`): + hash_seed (:obj:`int`, optional, defaults to `None`): Seed that can be used to make local sensitive hashing in LSHSelfAttention deterministic. This should only be set for testing purposed. For evaluation and training purposes `hash_seed` should be set to `None` to ensure fully random rotations in local sensitive hashing scheme. hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"): The non-linear activation function (function or string) in the feed forward layer in the residual attention block. @@ -118,7 +116,7 @@ class ReformerConfig(PretrainedConfig): num_hashes (:obj:`int`, optional, defaults to 1): Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme. The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes. - pad_token_id (:obj: `int`, optional, defaults to 0): + pad_token_id (:obj:`int`, optional, defaults to 0): The token id for the token. vocab_size (:obj:`int`, optional, defaults to 320): Vocabulary size of the Reformer model. Defines the different tokens that From 59868f31fc7a374afa9961f939de86845b038717 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:26:12 +0200 Subject: [PATCH 137/162] update conversion file --- .../convert_reformer_trax_checkpoint_to_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py index 6f51addf97aa..501ddbe60018 100755 --- a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -21,7 +21,6 @@ import numpy as np import torch -from tensorflow.compat.v1.io.gfile import GFile from transformers import ReformerConfig, ReformerModelWithLMHead @@ -30,6 +29,7 @@ def set_param(torch_layer, weight, bias=None): + # set parameter of one layer assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) torch_layer.weight = torch.nn.Parameter(weight) if bias is not None: @@ -180,7 +180,7 @@ def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch print("Building PyTorch model from configuration: {}".format(str(config))) model = ReformerModelWithLMHead(config) - with GFile(trax_model_pkl_path, "rb") as f: + with open(trax_model_pkl_path, "rb") as f: model_weights = pickle.load(f)["weights"] set_model_weights_in_torch(model_weights, model, config.hidden_size) From a81c3e091acedb2e35885a65f1d7ab52bb22cb62 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:30:33 +0200 Subject: [PATCH 138/162] Update pretrained_models.rst --- docs/source/pretrained_models.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index aae5024654d6..3fa151229753 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,5 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| | | | Train on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From c0ddf94060e1fc0c8b9d6d28e02d7abd6819f3f4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:32:57 +0200 Subject: [PATCH 139/162] fix rst --- docs/source/pretrained_models.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 3fa151229753..e04cf0763a93 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,6 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | -| | | | Train on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| | | | Train on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From 62a8eb00bc16f3ba40b25dbd9507ffc0eefe1a6b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:34:04 +0200 Subject: [PATCH 140/162] fix rst --- docs/source/pretrained_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index e04cf0763a93..a1cc818ba5af 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,6 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | | | | | Train on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From 47e5fc8a8636cfb40eca03da188d581918112505 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:37:42 +0200 Subject: [PATCH 141/162] update copyright --- examples/summarization/t5/output.txt | 0 src/transformers/configuration_reformer.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 examples/summarization/t5/output.txt diff --git a/examples/summarization/t5/output.txt b/examples/summarization/t5/output.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index de6225269a97..ba82f83a3bdc 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); From b6576c898b3521d1e132e1c0f9a79516d4ac6f07 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:40:06 +0200 Subject: [PATCH 142/162] fix test path --- tests/test_modeling_reformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 58f88fff019b..800fb3b0045f 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1106,8 +1106,8 @@ def test_lsh_lm_model_grad(self): @slow def test_pretrained_generate_crime_and_punish(self): - model = ReformerModelWithLMHead.from_pretrained("patrickvonplaten/reformer-crime-and-punish").to(torch_device) - tokenizer = ReformerTokenizer.from_pretrained("patrickvonplaten/reformer-crime-and-punish") + model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punish").to(torch_device) + tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punish") model.eval() input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) From a111720d20e9c8da55f7f4d2141f90fc98829055 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:40:53 +0200 Subject: [PATCH 143/162] fix test path --- tests/test_modeling_reformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 800fb3b0045f..77b94b1e0016 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1106,8 +1106,8 @@ def test_lsh_lm_model_grad(self): @slow def test_pretrained_generate_crime_and_punish(self): - model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punish").to(torch_device) - tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punish") + model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device) + tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment") model.eval() input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) From ff5e7835382a21e2bd91077d81d066a7f1d84520 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 01:56:16 +0200 Subject: [PATCH 144/162] fix small issue in test --- src/transformers/modeling_reformer.py | 5 +++-- src/transformers/modeling_utils.py | 1 - tests/test_modeling_reformer.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 6c795582ae55..346d97f1afca 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -1560,7 +1560,8 @@ def forward( # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) has_to_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 - if has_to_pad_to_match_chunk_length is True: + + if has_to_pad_to_match_chunk_length: # pad input input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length( input_ids, @@ -1585,7 +1586,7 @@ def forward( sequence_output = encoder_outputs.hidden_states # if padding was applied - if has_to_pad_to_match_chunk_length is True: + if has_to_pad_to_match_chunk_length: sequence_output = sequence_output[:, :orig_sequence_length] outputs = (sequence_output,) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3a6a105bff2c..81f396f0c9d3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -789,7 +789,6 @@ def generate( attention_mask=None, decoder_start_token_id=None, use_cache=None, - num_hashes=None, **model_specific_kwargs ): r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 77b94b1e0016..c7ed0e03e11e 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -1111,7 +1111,9 @@ def test_pretrained_generate_crime_and_punish(self): model.eval() input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) - output_ids = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False) + output_ids = model.generate( + input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8 + ) output_text = tokenizer.decode(output_ids[0]) self.assertEqual( output_text, From f7f949b3e4737266ccf89fa48cc24565a87362e1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 11:34:27 +0200 Subject: [PATCH 145/162] include reformer in generation tests --- tests/test_activations.py | 2 ++ tests/test_modeling_common.py | 14 +++++++++++--- tests/test_modeling_reformer.py | 8 +++++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_activations.py b/tests/test_activations.py index d6cbc1f9e564..79e9eec0184c 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -22,6 +22,8 @@ def test_get_activation(self): get_activation("swish") get_activation("relu") get_activation("tanh") + get_activation("gelu_new") + get_activation("gelu_fast") with self.assertRaises(KeyError): get_activation("bogus") with self.assertRaises(KeyError): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b0b3603af6c1..a0a8ec74ade5 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -127,7 +127,7 @@ def test_attention_outputs(self): encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) chunk_length = getattr(self.model_tester, "chunk_length", None) if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - chunk_length = self.model_tester.chunk_length * config.num_hashes + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes for model_class in self.all_model_classes: config.output_attentions = True @@ -144,7 +144,7 @@ def test_attention_outputs(self): if chunk_length is not None: self.assertListEqual( list(attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length], + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], ) else: self.assertListEqual( @@ -187,7 +187,7 @@ def test_attention_outputs(self): if chunk_length is not None: self.assertListEqual( list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length], + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], ) else: self.assertListEqual( @@ -648,9 +648,13 @@ def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] + # max length of input_ids should be < max_length + input_ids = input_ids[..., :10] + # iterate over all generative models for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) + model.eval() if config.bos_token_id is None: # if bos token id is not defined, model needs input_ids @@ -689,8 +693,12 @@ def test_lm_head_model_random_beam_search_generate(self): torch_device ) + # max length of input_ids should be < max_length + input_ids = input_ids[..., :10] + for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) + model.eval() if config.bos_token_id is None: # if bos token id is not defined mobel needs input_ids, num_return_sequences = 1 diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c7ed0e03e11e..460e698526cc 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -37,6 +37,7 @@ @require_torch class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -46,7 +47,7 @@ def __init__( self, parent, batch_size=13, - seq_length=16, + seq_length=32, is_training=True, is_decoder=False, use_input_mask=True, @@ -68,7 +69,7 @@ def __init__( axial_norm_std=1.0, layer_norm_eps=1e-12, axial_pos_embds=True, - axial_pos_shape=[4, 4], + axial_pos_shape=[4, 8], axial_pos_embds_dim=[16, 16], attn_layers=["local", "local", "local", "local"], pad_token_id=0, @@ -344,6 +345,7 @@ def test_reformer_model_fp16_forward(self): @require_torch class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -377,7 +379,7 @@ def __init__( axial_norm_std=1.0, layer_norm_eps=1e-12, axial_pos_embds=True, - axial_pos_shape=[2, 8], + axial_pos_shape=[4, 8], axial_pos_embds_dim=[16, 48], attn_layers=["lsh", "lsh", "lsh", "lsh"], pad_token_id=0, From 91472b8b06f40072ac13a04bc49aad256047293f Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sat, 2 May 2020 14:25:23 +0200 Subject: [PATCH 146/162] add docs for axial position encoding --- docs/source/model_doc/reformer.rst | 39 ++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index f0294b376a53..e076744b07ba 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -10,11 +10,46 @@ Here the abstract: *Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.* -The Authors' code can be found `here https://github.com/google/trax/tree/master/trax/models/reformer>`_ . +The Authors' code can be found `here `_ . Axial Positional Encodings ~~~~~~~~~~~~~~~~~~~~ -TODO +Axial Positional Encodings were first implemented in Google's `trax library `_ and developed by the authors of this model's paper. In Models that are treating very long input sequences, the conventional position id encodings store an embedings vector of size :math:`d` (``config.hidden_size``) for every position :math:`i, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix: + +.. math:: + X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right] + +which alone has over 500M parameters to store. Axial positional encodings factorize :math:`X_{i,j}` into two matrices: + +.. math:: + X^{1}_{i,j^1}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j^1 \in \left[1,\ldots, n_s^1\right] + +and + +.. math:: + X^{2}_{i,j^2}, \text{ with } i \in \left[1,\ldots, d^2\right] \text{ and } j^2 \in \left[1,\ldots, n_s^2\right] + +with: + +.. math:: + d = d^1 + d^2 \text{ and } n_s = n_s^1 \times n_s^2 . + +Therefore the following holds: + +.. math:: + X_{i,j} = \begin{cases} + X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\ + X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor + \end{cases} + +Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the composition of two factorized embedding vectors: :math:`x^1_{k, l} + x^2_{l, k}`, where as the ``config.max_embedding_size`` dimension :math:`j` is factorized into :math:`k \text{ and } l`. +This design ensures that each position embedding vector :math:`x_j` is unique. + +Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^10` can drastically reduced the number of parameters to :math:`2^14 + 2^15 \approx 49000` parameters. + +In praxis, the parameter ``config.axial_pos_embds_dim`` is set to ``list(d^1, d^2)`` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list(n_s^1, n_s^2)`` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``. + + LSH Sensitive Hashing Self Attention ~~~~~~~~~~~~~~~~~~~~ From 6ed2fa86f6cb223d63e026bdc5ceb4e521e57e33 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sat, 2 May 2020 16:15:11 +0200 Subject: [PATCH 147/162] finish docs --- docs/source/glossary.rst | 8 ++++- docs/source/main_classes/model.rst | 6 ++++ docs/source/model_doc/reformer.rst | 37 ++++++++++++++++------ src/transformers/__init__.py | 2 +- src/transformers/configuration_reformer.py | 2 +- src/transformers/modeling_reformer.py | 3 ++ 6 files changed, 46 insertions(+), 12 deletions(-) diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index cb7d1f1f438c..7f8fbefc052c 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -147,4 +147,10 @@ use other types of positional embeddings, such as sinusoidal position embeddings Feed Forward Chunking -------------------------- -TODO + +In transformers two feed forward layers usually follows the self attention layer in each residual attention block. The intermediate embedding size of the feed forward layers is often bigger than the hidden size of the model (*e.g.* for ``bert-base-uncased``). + +For an input of size ``[batch_size, sequence_length]``, the memory required to store the intermediate feed forward embeddings ``[batch_size, sequence_length, config.intermediate_size]`` can account for a large fraction of the memory use. The authors of `Reformer: The Efficient Transformer `_ noticed that since the computation is independent of the ``sequence_length`` dimension, it is mathematically equivalent to compute the output embeddings of both feed forward layers ``[batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n`` individually and concat them afterward to ``[batch_size, sequence_length, config.hidden_size]`` with ``n = sequence_length``, which trades increased computation time against reduced memory use, but yields a mathematically **equivalent** result. + +For models employing the function :func:`~.transformers.apply_chunking_to_forward`, the ``chunk_size`` defines the number of output embeddings that are computed in parallel and thus defines the trade-off between memory and time complexity. +If ``chunk_size`` is set to 0, no feed forward chunking is done. diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index 6e3da45bc2df..0c5ef99d21d9 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -14,6 +14,12 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav .. autoclass:: transformers.PreTrainedModel :members: +``Helper Functions`` +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: transformers.apply_chunking_to_forward + + ``TFPreTrainedModel`` ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index e076744b07ba..46102446aa56 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -47,25 +47,44 @@ This design ensures that each position embedding vector :math:`x_j` is unique. Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^10` can drastically reduced the number of parameters to :math:`2^14 + 2^15 \approx 49000` parameters. -In praxis, the parameter ``config.axial_pos_embds_dim`` is set to ``list(d^1, d^2)`` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list(n_s^1, n_s^2)`` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``. +In praxis, the parameter ``config.axial_pos_embds_dim`` is set to ``list``:math:`(d^1, d^2)` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list``:math:`(n_s^1, n_s^2)` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``. -LSH Sensitive Hashing Self Attention +LSH Self Attention ~~~~~~~~~~~~~~~~~~~~ -TODO +In Locality sensitive hashing (LSH) self attention the key and query projection weights are tied to that the key query embedding vectors are also tied. +LSH self attention uses the locality sensitive +hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance `_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the "similar" key query embedding vectors (in terms of *euclidean distance* to each other) are likely to be assigned to the same bucket. +The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` or directly the argument ``num_hashes`` of the forward function so that the output of the LSH self attention better approximates the output of the "normal" full self attention. +The buckets are then sorted and chunked into query key embedding vector chunks each of length ``config.lsh_chunk_length``. For each chunk, the query embedding vectors attend to its key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighbooring chunks and ``config.lsh_num_chunks_after`` following neighbooring chunks. +For more information, see the `original Paper _` or this great `blog post _`. -Local Sensitive Hashing Self Attention +Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory. + +It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly. + +Using LSH Self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times n_s)`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. + + +Local Self Attention ~~~~~~~~~~~~~~~~~~~~ -TODO +Local self attention is essentially a "normal" self attention layer with +key, query and value projections, but is chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attends to the key embedding vectors in its chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighbooring chunks and ``config.local_num_chunks_after`` following neighbooring chunks. + +Using LSH Self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times n_s)`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. + Training ~~~~~~~~~~~~~~~~~~~~ -TODO +During training, it has be made sure that the sequence length is set to a value that can be divided by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient so that the model can easily be trained on sequences as long as 64000 tokens. +For training, the ``ReformerModelWithLMHead`` should be used as follows: + +:: + + input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt') + loss = model(input_ids, labels=input_ids)[0] -Tips -~~~~~~~~~~~~~~~~~~~~ -TODO ReformerConfig ~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 52adf8a41238..01b868d726be 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -161,7 +161,7 @@ # Modeling if is_torch_available(): - from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering + from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering, apply_chunking_to_forward from .modeling_auto import ( AutoModel, AutoModelForPreTraining, diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index ba82f83a3bdc..db674b6f079b 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -44,7 +44,7 @@ class ReformerConfig(PretrainedConfig): attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]): List of attention layer types in ascending order. It can be chosen between a LSHSelfAttention layer ("lsh") and a LocalSelfAttention layer ("local"). - For more information on LSHSelfAttention layer, see `LSH Sensitive Hashing Self Attention `__ . + For more information on LSHSelfAttention layer, see `LSH Self Attention `__ . For more information on LocalSelfAttention layer, see `Local Self Attention `__ . axial_pos_embds (:obj:`bool`, optional, defaults to True): If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings `__ diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 346d97f1afca..bca7adebc42f 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -121,6 +121,9 @@ def forward(self, position_ids): weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights ] + import ipdb + ipdb.set_trace() + if self.training is True: assert ( reduce(mul, self.axial_pos_shape) == sequence_length From 963bb5ea4877ff11d31ef9005678781247e62438 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 16:21:20 +0200 Subject: [PATCH 148/162] Update convert_reformer_trax_checkpoint_to_pytorch.py --- src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py index 501ddbe60018..31dd0c0933b7 100755 --- a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert BERT checkpoint.""" +"""Convert Reformer checkpoint.""" import argparse From 425b18551c6d2fadc8b586106b25b8349537e29f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 2 May 2020 16:31:50 +0200 Subject: [PATCH 149/162] remove isort --- src/transformers/modeling_reformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index bca7adebc42f..346d97f1afca 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -121,9 +121,6 @@ def forward(self, position_ids): weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights ] - import ipdb - ipdb.set_trace() - if self.training is True: assert ( reduce(mul, self.axial_pos_shape) == sequence_length From 3336d8f0314e23d8a121d1921f8b382cc1d8ec30 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 3 May 2020 14:43:30 +0200 Subject: [PATCH 150/162] include sams comments --- src/transformers/modeling_reformer.py | 113 ++++++++++++++------------ 1 file changed, 61 insertions(+), 52 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 346d97f1afca..1bd2ef98adaf 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -52,7 +52,6 @@ def mish(x): "gelu_fast": gelu_fast, "mish": mish, } -ReformerLayerNorm = torch.nn.LayerNorm # Define named tuples for nn.Modules here @@ -84,7 +83,7 @@ def _get_least_common_mult_chunk_len(config): class AxialPositionEmbeddings(nn.Module): """Constructs axial position embeddings. Useful for very long input - sequences to stay memory and computational efficient + sequences to save memory and time. """ def __init__(self, config): @@ -104,12 +103,12 @@ def __init__(self, config): # create weights for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): - # create shape + # create expanded shapes ax_shape = [1] * len(self.axial_pos_shape) ax_shape[axis] = self.axial_pos_shape[axis] ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,) - # create tenser and init + # create tensor and init self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) def forward(self, position_ids): @@ -132,8 +131,8 @@ def forward(self, position_ids): # permute weights so that 2D correctly drops dims 1 and 2 perm_weigthts = weights.permute(0, 3, 2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) - drop_perm_weights = nn.functional.dropout2d(perm_weigthts, self.dropout, training=self.training) - drop_weights = drop_perm_weights.permute(0, 3, 2, 1) + drop_permuted_weights = nn.functional.dropout2d(perm_weigthts, p=self.dropout, training=self.training) + drop_weights = drop_permuted_weights.permute(0, 3, 2, 1) position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) @@ -170,7 +169,7 @@ def __init__(self, config): def forward(self, position_ids): position_embeddings = self.embedding(position_ids) - position_embeddings = nn.functional.dropout(position_embeddings, self.dropout, self.training) + position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training) return position_embeddings @@ -210,15 +209,16 @@ def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): position_ids.shape[-1], self.max_position_embeddings ) - embeddings = nn.functional.dropout(inputs_embeds, self.dropout, self.training) - position_embeddings = self.position_embeddings(position_ids) + # dropout + embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training) + # add positional embeddings + position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings - embeddings = nn.functional.dropout(embeddings, self.dropout, self.training) return embeddings -class EfficientAttentionUtils(object): +class EfficientAttentionMixin: """ A few utilities for nn.Modules in Reformer, to be used as a mixin. """ @@ -232,7 +232,7 @@ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): num_chunks_after: chunks after current chunk to include in attention Returns: - array of shape [num_chunks, N * chunk_length, ...], where + tensor of shape [num_chunks, N * chunk_length, ...], where N = (1 + num_chunks_before + num_chunks_after). """ if num_chunks_before == 0 and num_chunks_after == 0: @@ -290,7 +290,7 @@ def _merge_seq_length_dims(self, vectors, num_attn_heads): raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) -class LSHSelfAttention(nn.Module, EfficientAttentionUtils): +class LSHSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config): super().__init__() self.chunk_length = config.lsh_attn_chunk_length @@ -354,13 +354,13 @@ def forward( ) assert ( value_vectors.shape[-1] == self.attention_head_size - ), "last dim of query_key_vectors is {} but should be {}.".format( + ), "last dim of value_vectors is {} but should be {}.".format( value_vectors.shape[-1], self.attention_head_size ) # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: - self._set_num_buckets_on_the_fly(sequence_length) + self._set_num_buckets(sequence_length) # use cached buckets for backprop only if buckets is None: @@ -369,7 +369,7 @@ def forward( assert ( int(buckets.shape[-1]) == num_hashes * sequence_length - ), "lash dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) + ), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sequence_length, buckets, num_hashes @@ -532,7 +532,7 @@ def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buc return sorted_bucket_idx, undo_sorted_bucket_idx - def _set_num_buckets_on_the_fly(self, sequence_length): + def _set_num_buckets(self, sequence_length): # recommended `num_buckets` from paper num_buckets = 2 * sequence_length // self.chunk_length @@ -605,7 +605,7 @@ def _attend( del query_key_dots # dropout - attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) # Mask heads if we want to if head_mask is not None: @@ -736,7 +736,7 @@ def backward(ctx, grad_out_vectors, grad_logits): return grad_out_vectors, grad_logits, None, None, None -class LocalSelfAttention(nn.Module, EfficientAttentionUtils): +class LocalSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config): super().__init__() @@ -853,7 +853,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ del logits # dropout - attention_probs = nn.functional.dropout(attention_probs, self.dropout, self.training) + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) # Mask heads if we want to if head_mask is not None: @@ -912,7 +912,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states @@ -922,7 +922,7 @@ def __init__(self, config, layer_id=0): self.layer_id = layer_id self.attn_layers = config.attn_layers - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config) @@ -989,7 +989,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = self.act_fn(hidden_states) return hidden_states @@ -1003,7 +1003,7 @@ def __init__(self, config): def forward(self, hidden_states): hidden_states = self.dense(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states @@ -1013,7 +1013,7 @@ def __init__(self, config): self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.layer_norm = ReformerLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) @@ -1090,7 +1090,7 @@ def forward( ): with torch.no_grad(): # every forward pass we sample a different seed - # for dropout and save seed for forward fn in backward + # for dropout and save for forward fn in backward pass # to have correct dropout self._init_attention_seed() attn_outputs = self.attention( @@ -1299,7 +1299,7 @@ def __init__(self, config): self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size - self.layer_norm = ReformerLayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) def forward( self, @@ -1332,7 +1332,7 @@ def forward( hidden_states = self.layer_norm(hidden_states) # Apply dropout - hidden_states = nn.functional.dropout(hidden_states, self.dropout, self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return ReformerEncoderOutput( hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions @@ -1385,15 +1385,13 @@ def _init_weights(self, module): for weight in module.weights: torch.nn.init.normal_(weight, std=self.config.axial_norm_std) elif isinstance(module, nn.Embedding): - if module.weight.requires_grad: - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - if module.weight.requires_grad: - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, ReformerLayerNorm): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: @@ -1559,18 +1557,28 @@ def forward( # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) - has_to_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 + must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 + + if must_pad_to_match_chunk_length: + padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length + + if self.training is True: + raise ValueError( + "If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format( + input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length + ) + ) - if has_to_pad_to_match_chunk_length: # pad input input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length( input_ids, - inputs_embeds, - attention_mask, - position_ids, - input_shape, - least_common_mult_chunk_length, - device, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + input_shape=input_shape, + padding_length=padding_length, + padded_seq_length=least_common_mult_chunk_length, + device=device, ) embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) @@ -1586,7 +1594,7 @@ def forward( sequence_output = encoder_outputs.hidden_states # if padding was applied - if has_to_pad_to_match_chunk_length: + if must_pad_to_match_chunk_length: sequence_output = sequence_output[:, :orig_sequence_length] outputs = (sequence_output,) @@ -1598,15 +1606,16 @@ def forward( return outputs def _pad_to_mult_of_chunk_length( - self, input_ids, inputs_embeds, attention_mask, position_ids, input_shape, padded_seq_length, device + self, + input_ids, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + input_shape=None, + padding_length=None, + padded_seq_length=None, + device=None, ): - padding_length = padded_seq_length - input_shape[-1] % padded_seq_length - assert ( - self.training is False - ), "Sequence Length {} has to be a multiple of least common multiple chunk_length {} if training. Please consider padding the input to a length of {}.".format( - input_shape[-2], padded_seq_length, input_shape[-2] + padding_length - ) - logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( input_shape[-1], input_shape[-1] + padding_length, padded_seq_length @@ -1713,8 +1722,8 @@ def forward( from transformers import ReformerModel, ReformerTokenizer import torch - tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased') - model = ReformerModel.from_pretrained('bert-base-uncased') + tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment') + model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 outputs = model(input_ids, labels=input_ids) From 54eb6293a42ac27c3885c6cab580da1fce3930a9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 3 May 2020 14:44:34 +0200 Subject: [PATCH 151/162] remove wrong comment in utils --- src/transformers/modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 81f396f0c9d3..821f322ac4e0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch BERT model.""" import inspect import logging From e4e1e5915b61a7513f6fe0cf1d44005a791192f3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 3 May 2020 14:59:50 +0200 Subject: [PATCH 152/162] correct typos --- docs/source/model_doc/reformer.rst | 26 +++++++++---------- ...ert_reformer_trax_checkpoint_to_pytorch.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index 46102446aa56..4aea10c46327 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -14,7 +14,7 @@ The Authors' code can be found `here `_ and developed by the authors of this model's paper. In Models that are treating very long input sequences, the conventional position id encodings store an embedings vector of size :math:`d` (``config.hidden_size``) for every position :math:`i, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix: +Axial Positional Encodings were first implemented in Google's `trax library `_ and developed by the authors of this model's paper. In models that are treating very long input sequences, the conventional position id encodings store an embedings vector of size :math:`d` being the ``config.hidden_size`` for every position :math:`i, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix: .. math:: X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right] @@ -22,12 +22,12 @@ Axial Positional Encodings were first implemented in Google's `trax library `_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the "similar" key query embedding vectors (in terms of *euclidean distance* to each other) are likely to be assigned to the same bucket. +hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance `_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the more "similar" key query embedding vectors (in terms of *cosine similarity*) are to each other, the more likely they are assigned to the same bucket. The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` or directly the argument ``num_hashes`` of the forward function so that the output of the LSH self attention better approximates the output of the "normal" full self attention. -The buckets are then sorted and chunked into query key embedding vector chunks each of length ``config.lsh_chunk_length``. For each chunk, the query embedding vectors attend to its key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighbooring chunks and ``config.lsh_num_chunks_after`` following neighbooring chunks. -For more information, see the `original Paper _` or this great `blog post _`. +The buckets are then sorted and chunked into query key embedding vector chunks each of length ``config.lsh_chunk_length``. For each chunk, the query embedding vectors attend to its key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighboring chunks and ``config.lsh_num_chunks_after`` following neighboring chunks. +For more information, see the `original Paper `_ or this great `blog post `_. Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory. It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly. -Using LSH Self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times n_s)`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. +Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. Local Self Attention ~~~~~~~~~~~~~~~~~~~~ Local self attention is essentially a "normal" self attention layer with -key, query and value projections, but is chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attends to the key embedding vectors in its chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighbooring chunks and ``config.local_num_chunks_after`` following neighbooring chunks. +key, query and value projections, but is chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attends to the key embedding vectors in its chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighboring chunks and ``config.local_num_chunks_after`` following neighboring chunks. -Using LSH Self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times n_s)`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. +Using Local self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length. Training ~~~~~~~~~~~~~~~~~~~~ -During training, it has be made sure that the sequence length is set to a value that can be divided by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient so that the model can easily be trained on sequences as long as 64000 tokens. +During training, we must ensure that the sequence length is set to a value that can be divided by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient so that the model can easily be trained on sequences as long as 64000 tokens. For training, the ``ReformerModelWithLMHead`` should be used as follows: :: diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py index 31dd0c0933b7..3d88e3d78d13 100755 --- a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. +# Copyright 2020 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 5f5c89ba1469dbf9aa771c1c3b20643e078bfb19 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 3 May 2020 15:00:59 +0200 Subject: [PATCH 153/162] fix typo --- docs/source/pretrained_models.rst | 2 +- ~ | 301 ++++++++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 ~ diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index a1cc818ba5af..1333ec36d3ad 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -297,5 +297,5 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | -| | | | Train on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/~ b/~ new file mode 100644 index 000000000000..1333ec36d3ad --- /dev/null +++ b/~ @@ -0,0 +1,301 @@ +Pretrained models +================================================ + +Here is the full list of the currently provided pretrained models together with a short presentation of each model. + +For a list that includes community-uploaded models, refer to `https://huggingface.co/models `__. + ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Architecture | Shortcut name | Details of the model | ++===================+============================================================+=======================================================================================================================================+ +| BERT | ``bert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on lower-cased English text. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-uncased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | Trained on lower-cased English text. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased English text. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | Trained on cased English text. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-multilingual-uncased`` | | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on lower-cased text in the top 102 languages with the largest Wikipedias | +| | | (see `details `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-multilingual-cased`` | | (New, **recommended**) 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased text in the top 104 languages with the largest Wikipedias | +| | | (see `details `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-chinese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased Chinese Simplified and Traditional text. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-german-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased German text by Deepset.ai | +| | | (see `details on deepset.ai website `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-uncased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | Trained on lower-cased English text using Whole-Word-Masking | +| | | (see `details `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-cased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | Trained on cased English text using Whole-Word-Masking | +| | | (see `details `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-uncased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | The ``bert-large-uncased-whole-word-masking`` model fine-tuned on SQuAD | +| | | (see details of fine-tuning in the `example section `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-large-cased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters | +| | | | The ``bert-large-cased-whole-word-masking`` model fine-tuned on SQuAD | +| | | (see `details of fine-tuning in the example section `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-cased-finetuned-mrpc`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | The ``bert-base-cased`` model fine-tuned on MRPC | +| | | (see `details of fine-tuning in the example section `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-german-dbmdz-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased German text by DBMDZ | +| | | (see `details on dbmdz repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-german-dbmdz-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on uncased German text by DBMDZ | +| | | (see `details on dbmdz repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. | +| | | | `MeCab `__ is required for tokenization. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. | +| | | | `MeCab `__ is required for tokenization. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text. Text is tokenized into characters. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. | +| | | (see `details on cl-tohoku repository `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased Finnish text. | +| | | (see `details on turkunlp.org `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on uncased Finnish text. | +| | | (see `details on turkunlp.org `__). | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | Trained on cased Dutch text. | +| | | (see `details on wietsedv repository `__). | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | OpenAI GPT English model | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| GPT-2 | ``gpt2`` | | 12-layer, 768-hidden, 12-heads, 117M parameters. | +| | | | OpenAI GPT-2 English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. | +| | | | OpenAI's Medium-sized GPT-2 English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. | +| | | | OpenAI's Large-sized GPT-2 English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``gpt2-xl`` | | 48-layer, 1600-hidden, 25-heads, 1558M parameters. | +| | | | OpenAI's XL-sized GPT-2 English model | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. | +| | | | English model trained on wikitext-103 | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| XLNet | ``xlnet-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | +| | | | XLNet English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | +| | | | XLNet Large English model | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads | +| | | | XLM English model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | +| | | | XLM English-German model trained on the concatenation of English and German wikipedia | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads | +| | | | XLM English-French model trained on the concatenation of English and French wikipedia | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads | +| | | | XLM English-Romanian Multi-language model | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | | | XLM Model pre-trained with MLM on the `15 XNLI languages `__. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-tlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | | | XLM Model pre-trained with MLM + TLM on the `15 XNLI languages `__. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-clm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads | +| | | | XLM English-French model trained with CLM (Causal Language Modeling) on the concatenation of English and French wikipedia | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | +| | | | XLM English-German model trained with CLM (Causal Language Modeling) on the concatenation of English and German wikipedia | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-17-1280`` | | 16-layer, 1280-hidden, 16-heads | +| | | | XLM model trained with MLM (Masked Language Modeling) on 17 languages. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-mlm-100-1280`` | | 16-layer, 1280-hidden, 16-heads | +| | | | XLM model trained with MLM (Masked Language Modeling) on 100 languages. | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters | +| | | | RoBERTa using the BERT-base architecture | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``roberta-large`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | +| | | | RoBERTa using the BERT-large architecture | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``roberta-large-mnli`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | +| | | | ``roberta-large`` fine-tuned on `MNLI `__. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilroberta-base`` | | 6-layer, 768-hidden, 12-heads, 82M parameters | +| | | | The DistilRoBERTa model distilled from the RoBERTa model `roberta-base` checkpoint. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``roberta-base-openai-detector`` | | 12-layer, 768-hidden, 12-heads, 125M parameters | +| | | | ``roberta-base`` fine-tuned by OpenAI on the outputs of the 1.5B-parameter GPT-2 model. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``roberta-large-openai-detector`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | +| | | | ``roberta-large`` fine-tuned by OpenAI on the outputs of the 1.5B-parameter GPT-2 model. | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| DistilBERT | ``distilbert-base-uncased`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | +| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilbert-base-uncased-distilled-squad`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | +| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilbert-base-cased`` | | 6-layer, 768-hidden, 12-heads, 65M parameters | +| | | | The DistilBERT model distilled from the BERT model `bert-base-cased` checkpoint | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilbert-base-cased-distilled-squad`` | | 6-layer, 768-hidden, 12-heads, 65M parameters | +| | | | The DistilBERT model distilled from the BERT model `bert-base-cased` checkpoint, with an additional question answering layer. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilgpt2`` | | 6-layer, 768-hidden, 12-heads, 82M parameters | +| | | | The DistilGPT2 model distilled from the GPT2 model `gpt2` checkpoint. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilbert-base-german-cased`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | +| | | | The German DistilBERT model distilled from the German DBMDZ BERT model `bert-base-german-dbmdz-cased` checkpoint. | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``distilbert-base-multilingual-cased`` | | 6-layer, 768-hidden, 12-heads, 134M parameters | +| | | | The multilingual DistilBERT model distilled from the Multilingual BERT model `bert-base-multilingual-cased` checkpoint. | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters | +| | | | Salesforce's Large-sized CTRL English model | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| CamemBERT | ``camembert-base`` | | 12-layer, 768-hidden, 12-heads, 110M parameters | +| | | | CamemBERT using the BERT-base architecture | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| ALBERT | ``albert-base-v1`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | +| | | | ALBERT base model | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-large-v1`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | +| | | | ALBERT large model | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-xlarge-v1`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | +| | | | ALBERT xlarge model | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-xxlarge-v1`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | +| | | | ALBERT xxlarge model | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-base-v2`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | +| | | | ALBERT base model with no dropout, additional training data and longer training | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-large-v2`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | +| | | | ALBERT large model with no dropout, additional training data and longer training | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-xlarge-v2`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | +| | | | ALBERT xlarge model with no dropout, additional training data and longer training | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``albert-xxlarge-v2`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | +| | | | ALBERT xxlarge model with no dropout, additional training data and longer training | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| T5 | ``t5-small`` | | ~60M parameters with 6-layers, 512-hidden-state, 2048 feed-forward hidden-state, 8-heads, | +| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``t5-base`` | | ~220M parameters with 12-layers, 768-hidden-state, 3072 feed-forward hidden-state, 12-heads, | +| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``t5-large`` | | ~770M parameters with 24-layers, 1024-hidden-state, 4096 feed-forward hidden-state, 16-heads, | +| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``t5-3B`` | | ~2.8B parameters with 24-layers, 1024-hidden-state, 16384 feed-forward hidden-state, 32-heads, | +| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``t5-11B`` | | ~11B parameters with 24-layers, 1024-hidden-state, 65536 feed-forward hidden-state, 128-heads, | +| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| XLM-RoBERTa | ``xlm-roberta-base`` | | ~125M parameters with 12-layers, 768-hidden-state, 3072 feed-forward hidden-state, 8-heads, | +| | | | Trained on on 2.5 TB of newly created clean CommonCrawl data in 100 languages | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``xlm-roberta-large`` | | ~355M parameters with 24-layers, 1027-hidden-state, 4096 feed-forward hidden-state, 16-heads, | +| | | | Trained on 2.5 TB of newly created clean CommonCrawl data in 100 languages | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| FlauBERT | ``flaubert-small-cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters | +| | | | FlauBERT small architecture | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``flaubert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters | +| | | | FlauBERT base architecture with uncased vocabulary | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``flaubert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters | +| | | | FlauBERT base architecture with cased vocabulary | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``flaubert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters | +| | | | FlauBERT large architecture | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Bart | ``bart-large`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters | +| | | (see `details `_) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters | +| | | | bart-large base architecture with a classification head, finetuned on MNLI | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) | +| | | | bart-large base architecture finetuned on cnn summarization task | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters | +| | | | bart-large architecture pretrained on cc25 multilingual data , finetuned on WMT english romanian translation. | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| DialoGPT | ``DialoGPT-small`` | | 12-layer, 768-hidden, 12-heads, 124M parameters | +| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``DialoGPT-medium`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | +| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | +| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From 7fdf16bc481e795f90b7b5a3516137032cf10191 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 02:06:09 +0200 Subject: [PATCH 154/162] Update reformer.rst --- docs/source/model_doc/reformer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst index 4aea10c46327..a0d00433dfa4 100644 --- a/docs/source/model_doc/reformer.rst +++ b/docs/source/model_doc/reformer.rst @@ -45,7 +45,7 @@ Therefore the following holds: Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the composition of two factorized embedding vectors: :math:`x^1_{k, l} + x^2_{l, k}`, where as the ``config.max_embedding_size`` dimension :math:`j` is factorized into :math:`k \text{ and } l`. This design ensures that each position embedding vector :math:`x_j` is unique. -Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduced the number of parameters to :math:`2^14 + 2^15 \approx 49000` parameters. +Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduced the number of parameters to :math:`2^{14} + 2^{15} \approx 49000` parameters. In practice, the parameter ``config.axial_pos_embds_dim`` is set to ``list``:math:`(d^1, d^2)` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list``:math:`(n_s^1, n_s^2)` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``. From 7ccec6a0af5777666a0d8e0b7ec808706949c647 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 20:01:42 +0200 Subject: [PATCH 155/162] applied morgans optimization --- src/transformers/modeling_reformer.py | 103 ++++++++++++-------------- 1 file changed, 46 insertions(+), 57 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1bd2ef98adaf..98fa0c6f9389 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -67,11 +67,12 @@ def mish(x): def _get_least_common_mult_chunk_len(config): attn_types = config.attn_layers - if len(set(attn_types)) == 1 and attn_types[0] == "lsh": + attn_types_set = set(attn_types) + if len(attn_types_set) == 1 and attn_types[0] == "lsh": return config.lsh_attn_chunk_length - elif len(set(attn_types)) == 1 and attn_types[0] == "local": + elif len(attn_types_set) == 1 and attn_types[0] == "local": return config.local_attn_chunk_length - elif len(set(attn_types)) == 2 and set(attn_types) == set(["lsh", "local"]): + elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]): return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) else: raise NotImplementedError( @@ -129,12 +130,12 @@ def forward(self, position_ids): if self.dropout > 0: weights = torch.cat(broadcasted_weights, dim=-1) # permute weights so that 2D correctly drops dims 1 and 2 - perm_weigthts = weights.permute(0, 3, 2, 1) + transposed_weights = weights.transpose(2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) - drop_permuted_weights = nn.functional.dropout2d(perm_weigthts, p=self.dropout, training=self.training) - drop_weights = drop_permuted_weights.permute(0, 3, 2, 1) + dropped_transposed_weights = nn.functional.dropout2d(transposed_weights, p=self.dropout, training=self.training) + dropped_weights = dropped_transposed_weights.transpose(2, 1) - position_encodings = torch.reshape(drop_weights, (batch_size, sequence_length, -1)) + position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1)) else: position_encodings = torch.cat( @@ -275,20 +276,6 @@ def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn else: raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) - def _merge_seq_length_dims(self, vectors, num_attn_heads): - """ - merges factorized sequence length dims into one dim - """ - batch_size = vectors.shape[0] - new_dim_shape = (batch_size, num_attn_heads, -1) - - if len(vectors.shape) == 5: - return torch.reshape(vectors, new_dim_shape + (vectors.shape[-1],)) - elif len(vectors.shape) == 4: - return torch.reshape(vectors, new_dim_shape) - else: - raise ValueError("Input vector rank should be one of [4, 5], but is: {}".format(len(vectors.shape))) - class LSHSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config): @@ -476,7 +463,7 @@ def _hash_vectors(self, vectors, num_hashes): rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x num_buckets/2 - random_rotations = torch.randn(rotations_shape, device=vectors.device).to(vectors.dtype) + random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) @@ -502,33 +489,38 @@ def _hash_vectors(self, vectors, num_hashes): # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. offsets = torch.arange(num_hashes, device=vectors.device) - offsets = torch.reshape(offsets * num_buckets, (-1, 1)) + offsets = (offsets * num_buckets).view((1, 1, -1, 1)) - # repeat same values for Batch_size and Num_Attn_Heads - offsets = offsets.repeat(batch_size, self.num_attention_heads, 1, 1) - offset_buckets = self._merge_seq_length_dims(buckets + offsets, self.num_attention_heads) + # expand to batch size and num attention heads + offsets = offsets.expand((batch_size, self.num_attention_heads) + offsets.shape[-2:]) + offset_buckets = (buckets + offsets).flatten(start_dim=2, end_dim=3) return offset_buckets def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): - batch_size = buckets.shape[0] + # no gradients are needed + with torch.no_grad(): + batch_size = buckets.shape[0] - orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device) - orig_indices = orig_indices.repeat(batch_size, self.num_attention_heads, 1) + # arange and expand + orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device).view(1, 1, -1) + orig_indices = orig_indices.expand(batch_size, self.num_attention_heads, orig_indices.shape[-1]) - # scale buckets - scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) + # scale buckets + scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) - # remove gradient - scaled_buckets = scaled_buckets.detach() + # remove gradient + scaled_buckets = scaled_buckets.detach() - # Hash-based sort - sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) - undo_sorted_bucket_idx = torch.argsort(sorted_bucket_idx, dim=-1) + # Hash-based sort + sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) - # remove gradient - sorted_bucket_idx = sorted_bucket_idx.detach() - undo_sorted_bucket_idx = undo_sorted_bucket_idx.detach() + # create simple indices to scatter to, to have undo sort + indices = torch.arange(sorted_bucket_idx.shape[-1]).view(1, 1, -1).expand(sorted_bucket_idx.shape) + + # get undo sort + undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) + undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) return sorted_bucket_idx, undo_sorted_bucket_idx @@ -618,8 +610,8 @@ def _attend( del value_vectors # merge chunk length - logits = self._merge_seq_length_dims(logits, self.num_attention_heads).squeeze(-1) - out_vectors = self._merge_seq_length_dims(out_vectors, self.num_attention_heads) + logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) return out_vectors, logits, attention_probs @@ -656,7 +648,7 @@ def _len_and_dim_norm(self, vectors): length and attention head size dim normalization """ vectors = self._len_norm(vectors) - vectors = vectors / torch.sqrt( + vectors = vectors * torch.rsqrt( torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype) ) return vectors @@ -666,7 +658,7 @@ def _len_norm(self, x, epsilon=1e-6): length normalization """ variance = torch.mean(x ** 2, -1, keepdim=True) - norm_x = x / torch.sqrt(variance + epsilon) + norm_x = x * torch.rsqrt(variance + epsilon) return norm_x def _gather_by_expansion(self, vectors, idxs, num_hashes): @@ -690,15 +682,14 @@ class ReverseSort(Function): @staticmethod def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): # save sorted_bucket_idx for backprop - ctx.sorted_bucket_idx = sorted_bucket_idx - ctx.num_hashes = num_hashes - - out_vectors = out_vectors.detach() - logits = logits.detach() - # undo sort to have correct order for next layer - expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) - out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) - logits = torch.gather(logits, 2, undo_sorted_bucket_idx) + with torch.no_grad(): + ctx.sorted_bucket_idx = sorted_bucket_idx + ctx.num_hashes = num_hashes + + # undo sort to have correct order for next layer + expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_sorted_bucket_idx) return out_vectors, logits @staticmethod @@ -715,11 +706,9 @@ def backward(ctx, grad_out_vectors, grad_logits): # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen - grad_logits = torch.reshape(grad_logits, (grad_logits_shape[:2] + (num_hashes, -1))) + grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1))) # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen - grad_out_vectors = torch.reshape( - grad_out_vectors, (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:]) - ) + grad_out_vectors = grad_out_vectors.view((grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])) # reshape and expand sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) @@ -866,7 +855,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_ del value_vectors # merge chunk length - out_vectors = self._merge_seq_length_dims(out_vectors, self.num_attention_heads) + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) From 3978afa4a3d6428742c6c8201658fcca01dc3555 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 20:01:59 +0200 Subject: [PATCH 156/162] make style --- src/transformers/modeling_reformer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 98fa0c6f9389..d7c87bb085be 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -132,7 +132,9 @@ def forward(self, position_ids): # permute weights so that 2D correctly drops dims 1 and 2 transposed_weights = weights.transpose(2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) - dropped_transposed_weights = nn.functional.dropout2d(transposed_weights, p=self.dropout, training=self.training) + dropped_transposed_weights = nn.functional.dropout2d( + transposed_weights, p=self.dropout, training=self.training + ) dropped_weights = dropped_transposed_weights.transpose(2, 1) position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1)) @@ -708,7 +710,9 @@ def backward(ctx, grad_out_vectors, grad_logits): # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1))) # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen - grad_out_vectors = grad_out_vectors.view((grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])) + grad_out_vectors = grad_out_vectors.view( + (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:]) + ) # reshape and expand sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) From 01b100692efe61549dd6a2bb0b04230d9ff7e71c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 20:06:40 +0200 Subject: [PATCH 157/162] make gpu compatible --- src/transformers/modeling_reformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index d7c87bb085be..193662564109 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -518,7 +518,11 @@ def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buc sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) # create simple indices to scatter to, to have undo sort - indices = torch.arange(sorted_bucket_idx.shape[-1]).view(1, 1, -1).expand(sorted_bucket_idx.shape) + indices = ( + torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) + .view(1, 1, -1) + .expand(sorted_bucket_idx.shape) + ) # get undo sort undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) From e983a69713b045cdafe3069dc230c4c8f7f97520 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 20:07:41 +0200 Subject: [PATCH 158/162] remove bogus file --- ~ | 301 -------------------------------------------------------------- 1 file changed, 301 deletions(-) delete mode 100644 ~ diff --git a/~ b/~ deleted file mode 100644 index 1333ec36d3ad..000000000000 --- a/~ +++ /dev/null @@ -1,301 +0,0 @@ -Pretrained models -================================================ - -Here is the full list of the currently provided pretrained models together with a short presentation of each model. - -For a list that includes community-uploaded models, refer to `https://huggingface.co/models `__. - -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Architecture | Shortcut name | Details of the model | -+===================+============================================================+=======================================================================================================================================+ -| BERT | ``bert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on lower-cased English text. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-uncased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | Trained on lower-cased English text. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased English text. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | Trained on cased English text. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-multilingual-uncased`` | | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on lower-cased text in the top 102 languages with the largest Wikipedias | -| | | (see `details `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-multilingual-cased`` | | (New, **recommended**) 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased text in the top 104 languages with the largest Wikipedias | -| | | (see `details `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-chinese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased Chinese Simplified and Traditional text. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-german-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased German text by Deepset.ai | -| | | (see `details on deepset.ai website `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-uncased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | Trained on lower-cased English text using Whole-Word-Masking | -| | | (see `details `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-cased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | Trained on cased English text using Whole-Word-Masking | -| | | (see `details `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-uncased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | The ``bert-large-uncased-whole-word-masking`` model fine-tuned on SQuAD | -| | | (see details of fine-tuning in the `example section `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-large-cased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters | -| | | | The ``bert-large-cased-whole-word-masking`` model fine-tuned on SQuAD | -| | | (see `details of fine-tuning in the example section `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-cased-finetuned-mrpc`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | The ``bert-base-cased`` model fine-tuned on MRPC | -| | | (see `details of fine-tuning in the example section `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-german-dbmdz-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased German text by DBMDZ | -| | | (see `details on dbmdz repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-german-dbmdz-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on uncased German text by DBMDZ | -| | | (see `details on dbmdz repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. | -| | | | `MeCab `__ is required for tokenization. | -| | | (see `details on cl-tohoku repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. | -| | | | `MeCab `__ is required for tokenization. | -| | | (see `details on cl-tohoku repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on Japanese text. Text is tokenized into characters. | -| | | (see `details on cl-tohoku repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. | -| | | (see `details on cl-tohoku repository `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased Finnish text. | -| | | (see `details on turkunlp.org `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on uncased Finnish text. | -| | | (see `details on turkunlp.org `__). | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | Trained on cased Dutch text. | -| | | (see `details on wietsedv repository `__). | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | OpenAI GPT English model | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| GPT-2 | ``gpt2`` | | 12-layer, 768-hidden, 12-heads, 117M parameters. | -| | | | OpenAI GPT-2 English model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. | -| | | | OpenAI's Medium-sized GPT-2 English model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. | -| | | | OpenAI's Large-sized GPT-2 English model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``gpt2-xl`` | | 48-layer, 1600-hidden, 25-heads, 1558M parameters. | -| | | | OpenAI's XL-sized GPT-2 English model | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. | -| | | | English model trained on wikitext-103 | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| XLNet | ``xlnet-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. | -| | | | XLNet English model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. | -| | | | XLNet Large English model | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads | -| | | | XLM English model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | -| | | | XLM English-German model trained on the concatenation of English and German wikipedia | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads | -| | | | XLM English-French model trained on the concatenation of English and French wikipedia | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads | -| | | | XLM English-Romanian Multi-language model | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads | -| | | | XLM Model pre-trained with MLM on the `15 XNLI languages `__. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-tlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads | -| | | | XLM Model pre-trained with MLM + TLM on the `15 XNLI languages `__. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-clm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads | -| | | | XLM English-French model trained with CLM (Causal Language Modeling) on the concatenation of English and French wikipedia | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | -| | | | XLM English-German model trained with CLM (Causal Language Modeling) on the concatenation of English and German wikipedia | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-17-1280`` | | 16-layer, 1280-hidden, 16-heads | -| | | | XLM model trained with MLM (Masked Language Modeling) on 17 languages. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-100-1280`` | | 16-layer, 1280-hidden, 16-heads | -| | | | XLM model trained with MLM (Masked Language Modeling) on 100 languages. | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters | -| | | | RoBERTa using the BERT-base architecture | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``roberta-large`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | -| | | | RoBERTa using the BERT-large architecture | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``roberta-large-mnli`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | -| | | | ``roberta-large`` fine-tuned on `MNLI `__. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilroberta-base`` | | 6-layer, 768-hidden, 12-heads, 82M parameters | -| | | | The DistilRoBERTa model distilled from the RoBERTa model `roberta-base` checkpoint. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``roberta-base-openai-detector`` | | 12-layer, 768-hidden, 12-heads, 125M parameters | -| | | | ``roberta-base`` fine-tuned by OpenAI on the outputs of the 1.5B-parameter GPT-2 model. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``roberta-large-openai-detector`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | -| | | | ``roberta-large`` fine-tuned by OpenAI on the outputs of the 1.5B-parameter GPT-2 model. | -| | | (see `details `__) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| DistilBERT | ``distilbert-base-uncased`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | -| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilbert-base-uncased-distilled-squad`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | -| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilbert-base-cased`` | | 6-layer, 768-hidden, 12-heads, 65M parameters | -| | | | The DistilBERT model distilled from the BERT model `bert-base-cased` checkpoint | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilbert-base-cased-distilled-squad`` | | 6-layer, 768-hidden, 12-heads, 65M parameters | -| | | | The DistilBERT model distilled from the BERT model `bert-base-cased` checkpoint, with an additional question answering layer. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilgpt2`` | | 6-layer, 768-hidden, 12-heads, 82M parameters | -| | | | The DistilGPT2 model distilled from the GPT2 model `gpt2` checkpoint. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilbert-base-german-cased`` | | 6-layer, 768-hidden, 12-heads, 66M parameters | -| | | | The German DistilBERT model distilled from the German DBMDZ BERT model `bert-base-german-dbmdz-cased` checkpoint. | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``distilbert-base-multilingual-cased`` | | 6-layer, 768-hidden, 12-heads, 134M parameters | -| | | | The multilingual DistilBERT model distilled from the Multilingual BERT model `bert-base-multilingual-cased` checkpoint. | -| | | (see `details `__) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters | -| | | | Salesforce's Large-sized CTRL English model | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| CamemBERT | ``camembert-base`` | | 12-layer, 768-hidden, 12-heads, 110M parameters | -| | | | CamemBERT using the BERT-base architecture | -| | | (see `details `__) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| ALBERT | ``albert-base-v1`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | -| | | | ALBERT base model | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-large-v1`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | -| | | | ALBERT large model | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-xlarge-v1`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | -| | | | ALBERT xlarge model | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-xxlarge-v1`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | -| | | | ALBERT xxlarge model | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-base-v2`` | | 12 repeating layers, 128 embedding, 768-hidden, 12-heads, 11M parameters | -| | | | ALBERT base model with no dropout, additional training data and longer training | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-large-v2`` | | 24 repeating layers, 128 embedding, 1024-hidden, 16-heads, 17M parameters | -| | | | ALBERT large model with no dropout, additional training data and longer training | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-xlarge-v2`` | | 24 repeating layers, 128 embedding, 2048-hidden, 16-heads, 58M parameters | -| | | | ALBERT xlarge model with no dropout, additional training data and longer training | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``albert-xxlarge-v2`` | | 12 repeating layer, 128 embedding, 4096-hidden, 64-heads, 223M parameters | -| | | | ALBERT xxlarge model with no dropout, additional training data and longer training | -| | | (see `details `__) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| T5 | ``t5-small`` | | ~60M parameters with 6-layers, 512-hidden-state, 2048 feed-forward hidden-state, 8-heads, | -| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``t5-base`` | | ~220M parameters with 12-layers, 768-hidden-state, 3072 feed-forward hidden-state, 12-heads, | -| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``t5-large`` | | ~770M parameters with 24-layers, 1024-hidden-state, 4096 feed-forward hidden-state, 16-heads, | -| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``t5-3B`` | | ~2.8B parameters with 24-layers, 1024-hidden-state, 16384 feed-forward hidden-state, 32-heads, | -| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``t5-11B`` | | ~11B parameters with 24-layers, 1024-hidden-state, 65536 feed-forward hidden-state, 128-heads, | -| | | | Trained on English text: the Colossal Clean Crawled Corpus (C4) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| XLM-RoBERTa | ``xlm-roberta-base`` | | ~125M parameters with 12-layers, 768-hidden-state, 3072 feed-forward hidden-state, 8-heads, | -| | | | Trained on on 2.5 TB of newly created clean CommonCrawl data in 100 languages | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-roberta-large`` | | ~355M parameters with 24-layers, 1027-hidden-state, 4096 feed-forward hidden-state, 16-heads, | -| | | | Trained on 2.5 TB of newly created clean CommonCrawl data in 100 languages | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| FlauBERT | ``flaubert-small-cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters | -| | | | FlauBERT small architecture | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``flaubert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters | -| | | | FlauBERT base architecture with uncased vocabulary | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``flaubert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters | -| | | | FlauBERT base architecture with cased vocabulary | -| | | (see `details `__) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``flaubert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters | -| | | | FlauBERT large architecture | -| | | (see `details `__) | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Bart | ``bart-large`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters | -| | | (see `details `_) | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters | -| | | | bart-large base architecture with a classification head, finetuned on MNLI | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) | -| | | | bart-large base architecture finetuned on cnn summarization task | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters | -| | | | bart-large architecture pretrained on cc25 multilingual data , finetuned on WMT english romanian translation. | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| DialoGPT | ``DialoGPT-small`` | | 12-layer, 768-hidden, 12-heads, 124M parameters | -| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``DialoGPT-medium`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters | -| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | -| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | -| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | -| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky | -+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ From 9ce32f099bc93cbeff0a310768bc20a6488387cf Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 4 May 2020 21:57:04 +0200 Subject: [PATCH 159/162] big test refactor --- tests/test_modeling_common.py | 9 +- tests/test_modeling_reformer.py | 1070 +++++++++++++------------------ 2 files changed, 456 insertions(+), 623 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a0a8ec74ade5..8c9c6a9f5a56 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -141,6 +141,7 @@ def test_attention_outputs(self): self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: self.assertListEqual( list(attentions[0].shape[-4:]), @@ -648,8 +649,8 @@ def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] - # max length of input_ids should be < max_length - input_ids = input_ids[..., :10] + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] # iterate over all generative models for model_class in self.all_generative_model_classes: @@ -693,8 +694,8 @@ def test_lm_head_model_random_beam_search_generate(self): torch_device ) - # max length of input_ids should be < max_length - input_ids = input_ids[..., :10] + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 460e698526cc..c79b212a8c59 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -34,646 +34,371 @@ import torch -@require_torch -class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () - test_pruning = False - test_headmasking = False - test_torchscript = False +class ReformerModelTester: + def __init__( + self, + parent, + batch_size=None, + seq_length=None, + is_training=None, + is_decoder=None, + use_input_mask=None, + vocab_size=None, + attention_head_size=None, + hidden_size=None, + num_attention_heads=None, + local_attn_chunk_length=None, + local_num_chunks_before=None, + local_num_chunks_after=None, + num_buckets=None, + num_hashes=1, + lsh_attn_chunk_length=None, + lsh_num_chunks_before=None, + lsh_num_chunks_after=None, + chunk_size_lm_head=None, + chunk_size_feed_forward=None, + feed_forward_size=None, + hidden_act=None, + hidden_dropout_prob=None, + local_attention_probs_dropout_prob=None, + lsh_attention_probs_dropout_prob=None, + max_position_embeddings=None, + initializer_range=None, + axial_norm_std=None, + layer_norm_eps=None, + axial_pos_embds=None, + axial_pos_shape=None, + axial_pos_embds_dim=None, + attn_layers=None, + pad_token_id=None, + eos_token_id=None, + scope=None, + hash_seed=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.is_decoder = is_decoder + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = len(attn_layers) + self.local_attn_chunk_length = local_attn_chunk_length + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before + self.num_hashes = num_hashes + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.scope = scope + self.attn_layers = attn_layers + self.pad_token_id = pad_token_id + self.hash_seed = hash_seed + + attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length + num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after + num_chunks_before = local_num_chunks_before if local_num_chunks_before is not None else lsh_num_chunks_before + + self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0) + self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length + self.chunk_length = attn_chunk_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = ReformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + feed_forward_size=self.feed_forward_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, + lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + is_decoder=self.is_decoder, + axial_pos_embds=self.axial_pos_embds, + axial_pos_shape=self.axial_pos_shape, + axial_pos_embds_dim=self.axial_pos_embds_dim, + local_attn_chunk_length=self.local_attn_chunk_length, + local_num_chunks_after=self.local_num_chunks_after, + local_num_chunks_before=self.local_num_chunks_before, + num_hashes=self.num_hashes, + num_buckets=self.num_buckets, + lsh_attn_chunk_length=self.lsh_attn_chunk_length, + lsh_num_chunks_after=self.lsh_num_chunks_after, + lsh_num_chunks_before=self.lsh_num_chunks_before, + attn_layers=self.attn_layers, + pad_token_id=self.pad_token_id, + hash_seed=self.hash_seed, + ) - class ReformerLocalAttnModelTester(object): - def __init__( - self, - parent, - batch_size=13, - seq_length=32, - is_training=True, - is_decoder=False, - use_input_mask=True, - vocab_size=32, - attention_head_size=16, - hidden_size=32, - num_attention_heads=2, - local_attn_chunk_length=4, - local_num_chunks_before=1, - local_num_chunks_after=0, - chunk_size_lm_head=0, - chunk_size_feed_forward=0, - feed_forward_size=32, - hidden_act="gelu", - hidden_dropout_prob=0.1, - local_attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - initializer_range=0.02, - axial_norm_std=1.0, - layer_norm_eps=1e-12, - axial_pos_embds=True, - axial_pos_shape=[4, 8], - axial_pos_embds_dim=[16, 16], - attn_layers=["local", "local", "local", "local"], - pad_token_id=0, - eos_token_id=2, - scope=None, - hash_seed=0, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.is_decoder = is_decoder - self.use_input_mask = use_input_mask - self.vocab_size = vocab_size - self.attention_head_size = attention_head_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = len(attn_layers) - self.local_attn_chunk_length = local_attn_chunk_length - self.local_num_chunks_after = local_num_chunks_after - self.local_num_chunks_before = local_num_chunks_before - self.hidden_act = hidden_act - self.feed_forward_size = feed_forward_size - self.hidden_dropout_prob = hidden_dropout_prob - self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.axial_pos_embds = axial_pos_embds - self.axial_pos_shape = tuple(axial_pos_shape) - self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) - self.axial_norm_std = axial_norm_std - self.chunk_size_lm_head = chunk_size_lm_head - self.chunk_size_feed_forward = chunk_size_feed_forward - self.scope = scope - self.attn_layers = attn_layers - self.pad_token_id = pad_token_id - self.hash_seed = hash_seed - - self.encoder_seq_length = seq_length // local_attn_chunk_length + ( - self.seq_length % local_attn_chunk_length != 0 - ) - self.key_length = ( - self.local_num_chunks_before + self.local_num_chunks_after + 1 - ) * local_attn_chunk_length - self.chunk_length = local_attn_chunk_length - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - - config = ReformerConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - feed_forward_size=self.feed_forward_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - is_decoder=self.is_decoder, - axial_pos_embds=self.axial_pos_embds, - axial_pos_shape=self.axial_pos_shape, - axial_pos_embds_dim=self.axial_pos_embds_dim, - local_attn_chunk_length=self.local_attn_chunk_length, - local_num_chunks_after=self.local_num_chunks_after, - local_num_chunks_before=self.local_num_chunks_before, - attn_layers=self.attn_layers, - pad_token_id=self.pad_token_id, - hash_seed=self.hash_seed, - ) - - return ( - config, - input_ids, - input_mask, - ) - - def check_loss_output(self, result): - self.parent.assertListEqual(list(result["loss"].size()), []) - - def create_and_check_reformer_model( - self, config, input_ids, input_mask, - ): - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - (sequence_output,) = model(input_ids, attention_mask=input_mask) - (sequence_output,) = model(input_ids) - - result = { - "sequence_output": sequence_output, - } - # 2 * hidden_size because we use reversible resnet layers - self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], - ) - - def create_and_check_reformer_model_with_lm_backward( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] - loss.backward() - - def create_and_check_reformer_with_lm( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) - result = { - "loss": loss, - "prediction_scores": prediction_scores, - } - self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], - ) - self.check_loss_output(result) - - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): - # no special position embeddings - config.axial_pos_embds = False - config.is_decoder = is_decoder - - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - # set all position encodings to zero so that postions don't matter - with torch.no_grad(): - embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) - embedding.weight.requires_grad = False - - half_seq_len = self.seq_length // 2 - roll = self.local_attn_chunk_length - roll = self.local_attn_chunk_length - half_input_ids = input_ids[:, :half_seq_len] - - # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) - input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - - # shifted padded - input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) - attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - - # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) - - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll - ] - - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): - config.is_decoder = is_decoder - layer = ReformerLayer(config).to(torch_device) - layer.train() - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - # get random tensors - hidden_states = floats_tensor(shape) - prev_attn_output = floats_tensor(shape) - - # now the random seeds for attention and feed forward is initialized - # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) - - next_attn_output = layer_outputs.attn_output - next_hidden_states = layer_outputs.hidden_states + return ( + config, + input_ids, + input_mask, + ) - torch.manual_seed(layer.attention_seed) - attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue( - torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) - ) + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) - torch.manual_seed(layer.feed_forward_seed) - feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue( - torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) - ) - - def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] - - config.chunk_size_lm_head = input_ids.shape[-1] - config.chunk_size_feed_forward = input_ids.shape[-1] - - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - - hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) - - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): - model = ReformerModel(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask,) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + def create_and_check_reformer_model( + self, config, input_ids, input_mask, + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + (sequence_output,) = model(input_ids, attention_mask=input_mask) + (sequence_output,) = model(input_ids) - def setUp(self): - self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + result = { + "sequence_output": sequence_output, + } + # 2 * hidden_size because we use reversible resnet layers + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], + ) - def test_config(self): - self.config_tester.run_common_tests() + def create_and_check_reformer_model_with_lm_backward( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] + loss.backward() - def test_reformer_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def create_and_check_reformer_with_lm( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], + ) + self.check_loss_output(result) - def test_reformer_lm_model_backward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs) + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): + # no special position embeddings + config.axial_pos_embds = False + config.is_decoder = is_decoder - def test_reformer_model_attn_masking(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + if self.lsh_attn_chunk_length is not None: + # need to set chunk length equal sequence length to be certain that chunking works + config.lsh_attn_chunk_length = self.seq_length - def test_reformer_with_lm(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + # set all position encodings to zero so that postions don't matter + with torch.no_grad(): + embedding = model.embeddings.position_embeddings.embedding + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) + embedding.weight.requires_grad = False - def test_reformer_layer_training_dropout(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) - self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + half_seq_len = self.seq_length // 2 + roll = self.chunk_length - def test_reformer_chunking_forward_equality(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) + half_input_ids = input_ids[:, :half_seq_len] - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") - def test_reformer_model_fp16_forward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + # normal padded + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) + input_ids_padded = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, + ) + # shifted padded + input_ids_roll = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, + ) + input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) + attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) + + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll] + + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + config.is_decoder = is_decoder + layer = ReformerLayer(config).to(torch_device) + layer.train() + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + # get random tensors + hidden_states = floats_tensor(shape) + prev_attn_output = floats_tensor(shape) + + # now the random seeds for attention and feed forward is initialized + # forward tensors with dropout + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + + next_attn_output = layer_outputs.attn_output + next_hidden_states = layer_outputs.hidden_states + + torch.manual_seed(layer.attention_seed) + attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) + self.parent.assertTrue( + torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) + ) -@require_torch -class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () - test_pruning = False - test_headmasking = False - test_torchscript = False + torch.manual_seed(layer.feed_forward_seed) + feed_forward_hidden_states = layer.feed_forward(next_attn_output) + self.parent.assertTrue( + torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) + ) - class ReformerLSHAttnModelTester(object): - def __init__( - self, - parent, - batch_size=13, - seq_length=13, - use_input_mask=True, - is_training=False, - is_decoder=False, - vocab_size=32, - attention_head_size=16, - hidden_size=64, - num_attention_heads=2, - num_buckets=2, - num_hashes=4, - lsh_attn_chunk_length=4, - lsh_num_chunks_before=2, - lsh_num_chunks_after=3, - chunk_size_lm_head=5, - chunk_size_feed_forward=6, - feed_forward_size=32, - hidden_act="relu", - hidden_dropout_prob=0.1, - lsh_attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - initializer_range=0.02, - axial_norm_std=1.0, - layer_norm_eps=1e-12, - axial_pos_embds=True, - axial_pos_shape=[4, 8], - axial_pos_embds_dim=[16, 48], - attn_layers=["lsh", "lsh", "lsh", "lsh"], - pad_token_id=0, - eos_token_id=2, - scope=None, - hash_seed=0, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.is_decoder = is_decoder - self.use_input_mask = use_input_mask - self.vocab_size = vocab_size - self.attention_head_size = attention_head_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hashes = num_hashes - self.num_hidden_layers = len(attn_layers) - self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets - self.lsh_attn_chunk_length = lsh_attn_chunk_length - self.lsh_num_chunks_after = lsh_num_chunks_after - self.lsh_num_chunks_before = lsh_num_chunks_before - self.hidden_act = hidden_act - self.feed_forward_size = feed_forward_size - self.hidden_dropout_prob = hidden_dropout_prob - self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.chunk_size_lm_head = chunk_size_lm_head - self.chunk_size_feed_forward = chunk_size_feed_forward - self.scope = scope - self.attn_layers = attn_layers - self.hash_seed = hash_seed - self.pad_token_id = pad_token_id - self.axial_pos_embds = axial_pos_embds - self.axial_pos_shape = tuple(axial_pos_shape) - self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) - self.axial_norm_std = axial_norm_std - - self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) - self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length - self.chunk_length = lsh_attn_chunk_length - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - - config = ReformerConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - feed_forward_size=self.feed_forward_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - is_decoder=self.is_decoder, - axial_pos_embds=self.axial_pos_embds, - axial_pos_shape=self.axial_pos_shape, - axial_pos_embds_dim=self.axial_pos_embds_dim, - num_hashes=self.num_hashes, - num_buckets=self.num_buckets, - lsh_attn_chunk_length=self.lsh_attn_chunk_length, - lsh_num_chunks_after=self.lsh_num_chunks_after, - lsh_num_chunks_before=self.lsh_num_chunks_before, - attn_layers=self.attn_layers, - hash_seed=self.hash_seed, - pad_token_id=self.pad_token_id, - ) - - return ( - config, - input_ids, - input_mask, - ) - - def check_loss_output(self, result): - self.parent.assertListEqual(list(result["loss"].size()), []) - - def create_and_check_reformer_model( - self, config, input_ids, input_mask, - ): - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - (sequence_output,) = model(input_ids, attention_mask=input_mask) - (sequence_output,) = model(input_ids) - - result = { - "sequence_output": sequence_output, - } - # 2 * hidden_size because we use reversible resnet layers - self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], - ) - - def create_and_check_reformer_model_with_lm_backward( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] - loss.backward() - - def create_and_check_reformer_with_lm( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) - result = { - "loss": loss, - "prediction_scores": prediction_scores, - } - self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], - ) - self.check_loss_output(result) - - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): - # no special position embeddings - config.axial_pos_embds = False - config.is_decoder = is_decoder + def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] - # need to set chunk length equal sequence length to be certain that chunking works - config.lsh_attn_chunk_length = self.seq_length + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - # set all position encodings to zero so that postions don't matter - - with torch.no_grad(): - embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) - embedding.weight.requires_grad = False - - half_seq_len = self.seq_length // 2 - roll = self.lsh_attn_chunk_length - roll = half_seq_len - half_input_ids = input_ids[:, :half_seq_len] - - # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) - input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - - # shifted padded - input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) - attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll - ] - - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): - config.is_decoder = is_decoder - layer = ReformerLayer(config).to(torch_device) - layer.train() - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - # get random tensors - hidden_states = floats_tensor(shape) - prev_attn_output = floats_tensor(shape) - - # now the random seeds for attention and feed forward is initialized - # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) - - next_attn_output = layer_outputs.attn_output - next_hidden_states = layer_outputs.hidden_states + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() - torch.manual_seed(layer.attention_seed) - attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue( - torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) - ) + hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) + + def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): + if not self.is_training: + return + # disable dropout + config.hidden_dropout_prob = 0 + config.local_attention_probs_dropout_prob = 0 + config.lsh_attention_probs_dropout_prob = 0 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_no_chunk, output_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_no_chunk.backward() + grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_chunk, output_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_chunk.backward() + grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) + ) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) + ) + + def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): + layer = ReformerLayer(config).to(torch_device) + layer.train() + + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + hidden_states = floats_tensor(shape) + attn_output = floats_tensor(shape) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.attention_seed) + seeds.append(layer.attention_seed) + self.parent.assertGreater(len(set(seeds)), 70) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.feed_forward_seed) - feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue( - torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) - ) - - @slow - def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): - layer = ReformerLayer(config).to(torch_device) - layer.train() - - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - hidden_states = floats_tensor(shape) - attn_output = floats_tensor(shape) - - seeds = [] - for _ in range(100): - layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) - attn_output = layer_outputs.attn_output - hidden_states = layer_outputs.hidden_states - torch.manual_seed(layer.attention_seed) - seeds.append(layer.attention_seed) - self.parent.assertGreater(len(set(seeds)), 70) - - seeds = [] - for _ in range(100): - layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) - attn_output = layer_outputs.attn_output - hidden_states = layer_outputs.hidden_states - torch.manual_seed(layer.feed_forward_seed) - seeds.append(layer.feed_forward_seed) - self.parent.assertGreater(len(set(seeds)), 70) - - def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): - torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.train() - model.zero_grad() - loss_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] - loss_no_chunk.backward() - grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] - grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] - - config.chunk_size_lm_head = input_ids.shape[-1] - config.chunk_size_feed_forward = input_ids.shape[-1] - - torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.train() - model.zero_grad() - loss_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] - loss_chunk.backward() - grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] - grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] - self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) - self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) - self.parent.assertTrue( - torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) - ) - self.parent.assertTrue( - torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) - ) - - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): - model = ReformerModel(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) - self.parent.assertFalse(torch.isnan(output).any().item()) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask,) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + seeds.append(layer.feed_forward_seed) + self.parent.assertGreater(len(set(seeds)), 70) - def setUp(self): - self.model_tester = ReformerLSHAttnModelTest.ReformerLSHAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) + self.parent.assertFalse(torch.isnan(output).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask,) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class ReformerTesterMixin: + """ + Reformer Local and Reformer LSH run essentially the same tests + """ def test_config(self): self.config_tester.run_common_tests() @@ -700,6 +425,15 @@ def test_reformer_layer_training_dropout(self): self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + def test_reformer_chunking_forward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) + + def test_reformer_chunking_backward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) + + @slow def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs) @@ -714,6 +448,54 @@ def test_reformer_model_fp16_generate(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + +@require_torch +class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 32, + "is_training": True, + "is_decoder": False, + "use_input_mask": True, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "local_attn_chunk_length": 4, + "local_num_chunks_before": 1, + "local_num_chunks_after": 0, + "chunk_size_lm_head": 0, + "chunk_size_feed_forward": 0, + "feed_forward_size": 32, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "local_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [16, 16], + "attn_layers": ["local", "local", "local", "local"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + @slow def test_model_from_pretrained(self): for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: @@ -721,6 +503,56 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) +@require_torch +class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 13, + "use_input_mask": True, + "is_training": False, + "is_decoder": False, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 64, + "num_attention_heads": 2, + "num_buckets": 2, + "num_hashes": 4, + "lsh_attn_chunk_length": 4, + "lsh_num_chunks_before": 2, + "lsh_num_chunks_after": 3, + "chunk_size_lm_head": 5, + "chunk_size_feed_forward": 6, + "feed_forward_size": 32, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "lsh_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [16, 48], + "attn_layers": ["lsh", "lsh", "lsh", "lsh"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + @require_torch class ReformerIntegrationTests(unittest.TestCase): """ From 67f02c05da0654eee4a754b8823a87481e7eba92 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 22:16:47 +0200 Subject: [PATCH 160/162] add example for chunking --- src/transformers/modeling_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 821f322ac4e0..02ae2240bc33 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2115,6 +2115,18 @@ def apply_chunking_to_forward( input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked Returns: a Tensor with the same shape the foward_fn would have given if applied + + + Examples:: + + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) """ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) From 4e7252a271a13a1ef05aca13af3f402fec0fcea7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 May 2020 22:18:05 +0200 Subject: [PATCH 161/162] fix typo --- src/transformers/configuration_reformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index db674b6f079b..1a0a04777695 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -85,7 +85,7 @@ class ReformerConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. is_decoder (:obj:`bool`, optional, defaults to False): If `is_decoder` is True, a causal mask is used in addition to `attention_mask`. - When using the Reformer for casaul language modeling, `is_decoder` is set to `True`. + When using the Reformer for causal language modeling, `is_decoder` is set to `True`. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. local_chunk_length (:obj:`int`, optional, defaults to 64): From ca4dab39f71f40eb6967cd29fc6ebd29b0bd53d9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 May 2020 10:10:44 +0200 Subject: [PATCH 162/162] add to README --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 37094826edd3..136d243095a4 100644 --- a/README.md +++ b/README.md @@ -163,8 +163,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training 16. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. 18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users). -19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. +19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Ɓukasz Kaiser, Anselm Levskaya. +20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users). +21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).