diff --git a/README.md b/README.md
index 37094826edd3..136d243095a4 100644
--- a/README.md
+++ b/README.md
@@ -163,8 +163,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
16. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
-18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
-19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
+19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
+20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
+21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst
index cfd8c50dd6bd..7f8fbefc052c 100644
--- a/docs/source/glossary.rst
+++ b/docs/source/glossary.rst
@@ -143,3 +143,14 @@ positional embeddings. Absolute positional embeddings are selected in the range
``[0, config.max_position_embeddings - 1]``. Some models use other types of positional embeddings, such as sinusoidal position embeddings or relative position embeddings.
+
+
+Feed Forward Chunking
+--------------------------
+
+In transformers, two feed forward layers usually follow the self attention layer in each residual attention block. The intermediate embedding size of the feed forward layers is often bigger than the hidden size of the model (*e.g.* 3072 vs. 768 for ``bert-base-uncased``).
+
+For an input of size ``[batch_size, sequence_length]``, the memory required to store the intermediate feed forward embeddings ``[batch_size, sequence_length, config.intermediate_size]`` can account for a large fraction of the memory use. The authors of `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ noticed that since the computation is independent of the ``sequence_length`` dimension, it is mathematically equivalent to compute the output embeddings of both feed forward layers ``[batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n`` individually and concatenate them afterwards to ``[batch_size, sequence_length, config.hidden_size]`` with ``n = sequence_length``. This trades increased computation time against reduced memory use, but yields a mathematically **equivalent** result.
+
+For models employing the function :func:`~transformers.apply_chunking_to_forward`, the ``chunk_size`` defines the number of output embeddings that are computed in parallel and thus defines the trade-off between memory and time complexity.
+If ``chunk_size`` is set to 0, no feed forward chunking is done.
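+
+As a quick illustration (a minimal sketch in plain PyTorch with made-up layer sizes, not the :func:`~transformers.apply_chunking_to_forward` helper itself), processing the sequence dimension chunk by chunk gives the same output while only materializing a small intermediate tensor at a time:
+
+::
+
+    import torch
+
+    hidden_states = torch.randn(2, 1024, 256)    # [batch_size, sequence_length, hidden_size]
+    dense_in = torch.nn.Linear(256, 512)         # intermediate size larger than hidden size
+    dense_out = torch.nn.Linear(512, 256)
+
+    def feed_forward(x):
+        return dense_out(torch.relu(dense_in(x)))
+
+    # full pass: materializes a [2, 1024, 512] intermediate tensor
+    full_output = feed_forward(hidden_states)
+
+    # chunked pass: process chunk_size positions at a time and concatenate afterwards
+    chunk_size = 128
+    chunks = [feed_forward(chunk) for chunk in hidden_states.split(chunk_size, dim=1)]
+    chunked_output = torch.cat(chunks, dim=1)
+
+    assert torch.allclose(full_output, chunked_output, atol=1e-6)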
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 32c71fba9a58..9d9005c1a8e2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -107,3 +107,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
    model_doc/t5
    model_doc/electra
    model_doc/dialogpt
+   model_doc/reformer

diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst
index 6e3da45bc2df..0c5ef99d21d9 100644
--- a/docs/source/main_classes/model.rst
+++ b/docs/source/main_classes/model.rst
@@ -14,6 +14,12 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
 .. autoclass:: transformers.PreTrainedModel
     :members:

+``Helper Functions``
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: transformers.apply_chunking_to_forward
+
+
 ``TFPreTrainedModel``
 ~~~~~~~~~~~~~~~~~~~~~

diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst
new file mode 100644
index 000000000000..a0d00433dfa4
--- /dev/null
+++ b/docs/source/model_doc/reformer.rst
@@ -0,0 +1,114 @@
+Reformer
+----------------------------------------------------
+**DISCLAIMER:** This model is still a work in progress. If you see something strange,
+file a `Github Issue <https://github.com/huggingface/transformers/issues>`_.
+
+Overview
+~~~~~~~~
+The Reformer model was presented in `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
+Here is the abstract:
+
+*Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.*
+
+The authors' code can be found `here <https://github.com/google/trax>`_.
+
+Axial Positional Encodings
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Axial Positional Encodings were first implemented in Google's `trax library <https://github.com/google/trax>`_ and developed by the authors of this model's paper. In models that process very long input sequences, the conventional position id encodings store an embeddings vector of size :math:`d`, being the ``config.hidden_size``, for every position :math:`1, \ldots, n_s`, with :math:`n_s` being ``config.max_position_embeddings``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix:
+
+.. math::
+    X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right]
+
+which alone has over 500M parameters to store. Axial positional encodings factorize :math:`X_{i,j}` into two matrices:
+
+.. math::
+    X^{1}_{i,j}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j \in \left[1,\ldots, n_s^1\right]
+
+and
+
+.. math::
+    X^{2}_{i,j}, \text{ with } i \in \left[1,\ldots, d^2\right] \text{ and } j \in \left[1,\ldots, n_s^2\right]
+
+with:
+
+.. math::
+    d = d^1 + d^2 \text{ and } n_s = n_s^1 \times n_s^2 .
+
+Therefore the following holds:
+
+.. math::
+    X_{i,j} = \begin{cases}
+        X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\
+        X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor
+    \end{cases}
+
+Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the concatenation of two factorized embedding vectors: :math:`x^1_{k}` and :math:`x^2_{l}`, where the ``config.max_position_embeddings`` dimension :math:`j` is factorized into :math:`k \text{ and } l`.
+This design ensures that each position embedding vector :math:`x_j` is unique.
+
+Using the above example again, axial position encoding with :math:`d^1 = 2^9, d^2 = 2^9, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduce the number of parameters to :math:`2^{18} + 2^{19} \approx 780000` parameters.
+
+In practice, the parameter ``config.axial_pos_embds_dim`` is set to a ``list`` :math:`(d^1, d^2)` whose sum has to equal ``config.hidden_size``, and ``config.axial_pos_shape`` is set to a ``list`` :math:`(n_s^1, n_s^2)` whose product has to equal ``config.max_position_embeddings``, which during training has to equal the ``sequence length`` of the ``input_ids``.
+
+
+
+LSH Self Attention
+~~~~~~~~~~~~~~~~~~~~
+In locality-sensitive hashing (LSH) self attention, the key and query projection weights are tied. Therefore, the key query embedding vectors are also tied.
+LSH self attention uses the locality-sensitive
+hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance <https://arxiv.org/abs/1509.02897>`_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the more "similar" key query embedding vectors (in terms of *cosine similarity*) are to each other, the more likely they are assigned to the same bucket.
+The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` (or directly the ``num_hashes`` argument of the forward function), so that the output of the LSH self attention better approximates the output of the "normal" full self attention.
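+
+As a rough sketch of this bucketing step (illustrative only, with made-up names; the actual implementation lives in ``LSHSelfAttention._hash_vectors`` in ``modeling_reformer.py``), each tied query key vector is projected onto a few random rotations, and the index of the largest component among the projections and their negations is used as the bucket id:
+
+::
+
+    import torch
+
+    def angular_lsh_buckets(vectors, num_buckets, num_hashes):
+        # vectors: [batch_size, num_heads, sequence_length, head_dim]
+        head_dim = vectors.shape[-1]
+        # one random rotation per hashing round; num_buckets is assumed to be even
+        rotations = torch.randn(head_dim, num_hashes, num_buckets // 2)
+        rotated = torch.einsum("bhsd,dnr->bhnsr", vectors, rotations)
+        # concatenating each projection with its negation yields num_buckets scores ...
+        rotated = torch.cat([rotated, -rotated], dim=-1)
+        # ... and the bucket id is the index of the largest score
+        return torch.argmax(rotated, dim=-1)  # [batch_size, num_heads, num_hashes, sequence_length]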
+The buckets are then sorted and chunked into query key embedding vector chunks, each of length ``config.lsh_attn_chunk_length``. For each chunk, the query embedding vectors attend to their key vectors (which are tied to themselves) and to the key embedding vectors of the ``config.lsh_num_chunks_before`` previous neighboring chunks and the ``config.lsh_num_chunks_after`` following neighboring chunks.
+For more information, see the `original paper <https://arxiv.org/abs/2001.04451>`_ or this great `blog post `_.
+
+Note that ``config.num_buckets`` can also be factorized into a ``list`` :math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way, instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})`, they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
+
+It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` is calculated on the fly.
+
+Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
+
+
+Local Self Attention
+~~~~~~~~~~~~~~~~~~~~
+Local self attention is essentially a "normal" self attention layer with
+key, query and value projections, but it is chunked so that in each chunk of length ``config.local_attn_chunk_length`` the query embedding vectors only attend to the key embedding vectors in their chunk and to the key embedding vectors of the ``config.local_num_chunks_before`` previous neighboring chunks and the ``config.local_num_chunks_after`` following neighboring chunks.
+
+Using local self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \text{chunk length})`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
+
+
+Training
+~~~~~~~~~~~~~~~~~~~~
+During training, we must ensure that the sequence length is set to a value that is divisible by the least common multiple of ``config.lsh_attn_chunk_length`` and ``config.local_attn_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient, so the model can easily be trained on sequences as long as 64000 tokens.
+For training, the ``ReformerModelWithLMHead`` should be used as follows:
+
+::
+
+    tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
+    model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
+    input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt')
+    loss = model(input_ids, labels=input_ids)[0]
+
+
+ReformerConfig
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerConfig
+    :members:
+
+
+ReformerTokenizer
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerTokenizer
+    :members:
+
+
+ReformerModel
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerModel
+    :members:
+
+
+ReformerModelWithLMHead
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+..
autoclass:: transformers.ReformerModelWithLMHead + :members: diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 35b92ad219e7..1333ec36d3ad 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,3 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c869fa55d0e0..01b868d726be 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -47,6 +47,7 @@ from .configuration_marian import MarianConfig from .configuration_mmbt import MMBTConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig +from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig @@ -138,6 +139,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast +from .tokenization_reformer import ReformerTokenizer from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast @@ -159,7 +161,7 @@ # Modeling if is_torch_available(): - from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering + from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering, apply_chunking_to_forward from .modeling_auto import ( AutoModel, AutoModelForPreTraining, @@ -190,6 +192,7 @@ BertForQuestionAnswering, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BertLayer, ) from .modeling_openai import ( OpenAIGPTPreTrainedModel, @@ -320,6 +323,14 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ) + from .modeling_reformer import ( + ReformerAttention, + ReformerLayer, + ReformerModel, + ReformerModelWithLMHead, + REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, + ) + # Optimization from .optimization import ( AdamW, diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 60e8e8d05e38..5238427d3b67 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -34,12 +34,18 @@ def gelu_new(x): else: gelu = F.gelu + +def gelu_fast(x): + return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + + ACT2FN = { "relu": F.relu, "swish": swish, "gelu": 
gelu, "tanh": torch.tanh, "gelu_new": gelu_new, + "gelu_fast": gelu_fast, } diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index ece2289c60f8..59272a86cca7 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -29,6 +29,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig +from .configuration_reformer import ReformerConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig @@ -73,6 +74,7 @@ ("camembert", CamembertConfig,), ("xlm-roberta", XLMRobertaConfig,), ("bart", BartConfig,), + ("reformer", ReformerConfig,), ("roberta", RobertaConfig,), ("flaubert", FlaubertConfig,), ("bert", BertConfig,), @@ -130,6 +132,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model) + - contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model) - contains `bert`: :class:`~transformers.BertConfig` (Bert model) - contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model) - contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py new file mode 100644 index 000000000000..1a0a04777695 --- /dev/null +++ b/src/transformers/configuration_reformer.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Reformer model configuration """ + + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json" +} + + +class ReformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.ReformerModel`. + It is used to instantiate an Reformer model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. 
+
+    Args:
+        attention_head_size (:obj:`int`, optional, defaults to 64):
+            Dimensionality of the projected key, query and value vectors.
+        attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]):
+            List of attention layer types in ascending order. Each entry can be either an
+            LSHSelfAttention layer ("lsh") or a LocalSelfAttention layer ("local").
+            For more information on the LSHSelfAttention layer, see `LSH Self Attention <reformer.html#lsh-self-attention>`__ .
+            For more information on the LocalSelfAttention layer, see `Local Self Attention <reformer.html#local-self-attention>`__ .
+        axial_pos_embds (:obj:`bool`, optional, defaults to True):
+            If `True`, use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ .
+        axial_norm_std (:obj:`float`, optional, defaults to 1.0):
+            The standard deviation of the normal_initializer for initializing the weight matrices of the axial positional encodings.
+        axial_pos_shape (:obj:`list(int)`, optional, defaults to `[64, 64]`):
+            The position dims of the axial position encodings.
+            During training the product of the position dims has to equal the sequence length.
+            For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ .
+        axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to `[64, 192]`):
+            The embedding dims of the axial position encodings.
+            The sum of the embedding dims has to equal the hidden size.
+            For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ .
+        chunk_size_lm_head (:obj:`int`, optional, defaults to 0):
+            The chunk size of the final language model feed forward head layer.
+            A chunk size of 0 means that the feed forward layer is not chunked.
+            A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
+            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
+        chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
+            The chunk size of all feed forward layers in the residual attention blocks.
+            A chunk size of 0 means that the feed forward layer is not chunked.
+            A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
+            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
+        eos_token_id (:obj:`int`, optional, defaults to 2):
+            The token id for the end-of-sequence token.
+        feed_forward_size (:obj:`int`, optional, defaults to 512):
+            Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the residual attention block.
+        hash_seed (:obj:`int`, optional, defaults to `None`):
+            Seed that can be used to make locality sensitive hashing in LSHSelfAttention deterministic. This should only be set for testing purposes. For evaluation and training purposes `hash_seed` should be set to `None` to ensure fully random rotations in the locality sensitive hashing scheme.
+        hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
+            The non-linear activation function (function or string) in the feed forward layer in the residual attention block.
+            If string, "gelu", "relu", "swish", "gelu_new" and "gelu_fast" are supported.
+        hidden_dropout_prob (:obj:`float`, optional, defaults to 0.05):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        hidden_size (:obj:`int`, optional, defaults to 256):
+            Dimensionality of the output hidden states of the residual attention blocks.
+        initializer_range (:obj:`float`, optional, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        is_decoder (:obj:`bool`, optional, defaults to False):
+            If `is_decoder` is True, a causal mask is used in addition to `attention_mask`.
+            When using the Reformer for causal language modeling, `is_decoder` is set to `True`.
+        layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        local_attn_chunk_length (:obj:`int`, optional, defaults to 64):
+            Length of chunk which attends to itself in LocalSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
+        local_num_chunks_before (:obj:`int`, optional, defaults to 1):
+            Number of previous neighbouring chunks to attend to in the LocalSelfAttention layer in addition to itself.
+        local_num_chunks_after (:obj:`int`, optional, defaults to 0):
+            Number of following neighbouring chunks to attend to in the LocalSelfAttention layer in addition to itself.
+        local_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.05):
+            The dropout ratio for the attention probabilities in LocalSelfAttention.
+        lsh_attn_chunk_length (:obj:`int`, optional, defaults to 64):
+            Length of chunk which attends to itself in LSHSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
+        lsh_num_chunks_before (:obj:`int`, optional, defaults to 1):
+            Number of previous neighbouring chunks to attend to in the LSHSelfAttention layer in addition to itself.
+        lsh_num_chunks_after (:obj:`int`, optional, defaults to 0):
+            Number of following neighbouring chunks to attend to in the LSHSelfAttention layer in addition to itself.
+        lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.0):
+            The dropout ratio for the attention probabilities in LSHSelfAttention.
+        max_position_embeddings (:obj:`int`, optional, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+        num_attention_heads (:obj:`int`, optional, defaults to 2):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `32`):
+            Number of buckets the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`.
+            The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
+            The number of buckets (or the product of the factors) should approximately equal sequence length / lsh_attn_chunk_length.
+        num_hashes (:obj:`int`, optional, defaults to 1):
+            Number of hashing rounds (e.g. number of random rotations) in the Locality Sensitive Hashing scheme.
+            The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
+ pad_token_id (:obj:`int`, optional, defaults to 0): + The token id for the token. + vocab_size (:obj:`int`, optional, defaults to 320): + Vocabulary size of the Reformer model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`. + + Example:: + + from transformers import ReformerModel, ReformerConfig + + # Initializing a Reformer configuration + configuration = ReformerConfig() + + # Initializing a Reformer model + model = ReformerModel(configuration) + + # Accessing the model configuration + configuration = model.config + + Attributes: + pretrained_config_archive_map (Dict[str, str]): + A dictionary containing all the available pre-trained checkpoints. + """ + pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "reformer" + + def __init__( + self, + attention_head_size=64, + attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"], + axial_norm_std=1.0, + axial_pos_embds=True, + axial_pos_shape=[64, 64], + axial_pos_embds_dim=[64, 192], + chunk_size_lm_head=0, + chunk_size_feed_forward=0, + eos_token_id=2, + feed_forward_size=512, + hash_seed=None, + hidden_act="relu", + hidden_dropout_prob=0.05, + hidden_size=256, + initializer_range=0.02, + is_decoder=False, + layer_norm_eps=1e-12, + local_num_chunks_before=1, + local_num_chunks_after=0, + local_attention_probs_dropout_prob=0.05, + local_attn_chunk_length=64, + lsh_attn_chunk_length=64, + lsh_attention_probs_dropout_prob=0.0, + lsh_num_chunks_before=1, + lsh_num_chunks_after=0, + max_position_embeddings=4096, + num_attention_heads=2, + num_buckets=32, + num_hashes=1, + pad_token_id=0, + vocab_size=320, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs) + + self.hash_seed = hash_seed + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hashes = num_hashes + self.num_hidden_layers = len(attn_layers) + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.local_attn_chunk_length = local_attn_chunk_length + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.attn_layers = attn_layers diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..3d88e3d78d13 --- /dev/null +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -0,0 
+1,211 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Reformer checkpoint.""" + + +import argparse +import logging +import pickle + +import numpy as np +import torch + +from transformers import ReformerConfig, ReformerModelWithLMHead + + +logging.basicConfig(level=logging.INFO) + + +def set_param(torch_layer, weight, bias=None): + # set parameter of one layer + assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer) + torch_layer.weight = torch.nn.Parameter(weight) + if bias is not None: + assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer) + torch_layer.bias = torch.nn.Parameter(bias) + + +def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + np_query_key = np.asarray(weights[0]) + np_value = np.asarray(weights[1]) + np_dense = np.asarray(weights[2]) + + set_param( + torch_layer.self_attention.query_key, + torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) + + +def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size): + # set torch weights for 1-to-1 comparison + np_query = np.asarray(weights[0]) + np_key = np.asarray(weights[1]) + np_value = np.asarray(weights[2]) + np_dense = np.asarray(weights[3]) + + set_param( + torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size), + ) + set_param( + torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1), + ) + + +def set_block_weights_in_torch(weights, torch_block, hidden_size): + # layernorm 1 + layer_norm_1 = weights[0][0][0] + layer_norm_1_weight = np.asarray(layer_norm_1[0]) + layer_norm_1_bias = np.asarray(layer_norm_1[1]) + set_param( + torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias), + ) + + # lsh weights + output + attn_weights = weights[0][1] + if len(attn_weights) < 4: + set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size) + else: + set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size) + + # intermediate weighs + intermediate_weights = weights[2][0][2][2] + + # Chunked Feed Forward + if len(intermediate_weights) == 4: + intermediate_weights = intermediate_weights[2] + + # layernorm 2 + layer_norm_2_weight = np.asarray(intermediate_weights[0][0]) + layer_norm_2_bias = 
np.asarray(intermediate_weights[0][1]) + set_param( + torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias), + ) + + # intermediate dense + inter_dense_weight = np.asarray(intermediate_weights[1][0]) + inter_dense_bias = np.asarray(intermediate_weights[1][1]) + set_param( + torch_block.feed_forward.dense.dense, + torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(inter_dense_bias), + ) + + # intermediate out + out_dense_weight = np.asarray(intermediate_weights[4][0]) + out_dense_bias = np.asarray(intermediate_weights[4][1]) + set_param( + torch_block.feed_forward.output.dense, + torch.tensor(out_dense_weight).transpose(0, 1).contiguous(), + torch.tensor(out_dense_bias), + ) + + +def set_model_weights_in_torch(weights, torch_model, hidden_size): + # reformer model + torch_model_reformer = torch_model.reformer + + # word embeds + word_embeddings = np.asarray(weights[1]) + set_param( + torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings), + ) + + if isinstance(weights[3], tuple): + position_embeddings = torch_model_reformer.embeddings.position_embeddings + for emb_idx in range(len(position_embeddings.weights)): + emb_weights = np.asarray(weights[3][emb_idx][0]) + assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format( + position_embeddings[emb_idx] + ) + position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) + + trax_layer_weights = weights[5] + assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( + trax_layer_weights + ), "HF and trax model do not have the same number of layers" + for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers): + block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] + set_block_weights_in_torch(block_weights, layer, hidden_size) + + # output weights + out_weights = weights[6] + + # output layer norm + layer_norm_out_weight = np.asarray(out_weights[0][0]) + layer_norm_out_bias = np.asarray(out_weights[0][1]) + set_param( + torch_model_reformer.encoder.layer_norm, + torch.tensor(layer_norm_out_weight), + torch.tensor(layer_norm_out_bias), + ) + + # output embeddings + output_embed_weights = np.asarray(out_weights[2][0]) + output_embed_bias = np.asarray(out_weights[2][1]) + set_param( + torch_model.lm_head.decoder, + torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), + torch.tensor(output_embed_bias), + ) + + +def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = ReformerConfig.from_json_file(config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = ReformerModelWithLMHead(config) + + with open(trax_model_pkl_path, "rb") as f: + model_weights = pickle.load(f)["weights"] + + set_model_weights_in_torch(model_weights, model, config.hidden_size) + + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained Reformer model. 
\n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index c54b9c719840..82c1e2532315 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -31,6 +31,7 @@ FlaubertConfig, GPT2Config, OpenAIGPTConfig, + ReformerConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -97,6 +98,7 @@ ) from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel +from .modeling_reformer import ReformerModel, ReformerModelWithLMHead from .modeling_roberta import ( ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, @@ -179,6 +181,7 @@ (XLMConfig, XLMModel), (CTRLConfig, CTRLModel), (ElectraConfig, ElectraModel), + (ReformerConfig, ReformerModel), ] ) @@ -222,6 +225,7 @@ (CTRLConfig, CTRLLMHeadModel), (ElectraConfig, ElectraForMaskedLM), (EncoderDecoderConfig, EncoderDecoderModel), + (ReformerConfig, ReformerModelWithLMHead), ] ) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py new file mode 100644 index 000000000000..193662564109 --- /dev/null +++ b/src/transformers/modeling_reformer.py @@ -0,0 +1,1763 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REFORMER model. 
""" + +import logging +import sys +from collections import namedtuple +from functools import reduce +from operator import mul + +import numpy as np +import torch +from torch import nn +from torch.autograd.function import Function +from torch.nn import CrossEntropyLoss + +from .activations import gelu, gelu_fast, gelu_new, swish +from .configuration_reformer import ReformerConfig +from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, apply_chunking_to_forward + + +logger = logging.getLogger(__name__) + +REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin" +} + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = { + "gelu": gelu, + "relu": torch.nn.functional.relu, + "swish": swish, + "gelu_new": gelu_new, + "gelu_fast": gelu_fast, + "mish": mish, +} + + +# Define named tuples for nn.Modules here +LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) +AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) +ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) +ReformerBackwardOutput = namedtuple( + "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] +) +ReformerEncoderOutput = namedtuple("ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions"]) + + +def _get_least_common_mult_chunk_len(config): + attn_types = config.attn_layers + attn_types_set = set(attn_types) + if len(attn_types_set) == 1 and attn_types[0] == "lsh": + return config.lsh_attn_chunk_length + elif len(attn_types_set) == 1 and attn_types[0] == "local": + return config.local_attn_chunk_length + elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]): + return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) + else: + raise NotImplementedError( + "Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format( + config.attn_layers + ) + ) + + +class AxialPositionEmbeddings(nn.Module): + """Constructs axial position embeddings. Useful for very long input + sequences to save memory and time. 
+ """ + + def __init__(self, config): + super().__init__() + self.axial_pos_shape = config.axial_pos_shape + self.axial_pos_embds_dim = config.axial_pos_embds_dim + self.dropout = config.hidden_dropout_prob + + self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config) + self.weights = nn.ParameterList() + + assert ( + sum(self.axial_pos_embds_dim) == config.hidden_size + ), "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format( + self.axial_pos_embds_dim, config.hidden_size + ) + + # create weights + for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): + # create expanded shapes + ax_shape = [1] * len(self.axial_pos_shape) + ax_shape[axis] = self.axial_pos_shape[axis] + ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,) + + # create tensor and init + self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) + + def forward(self, position_ids): + # broadcast weights to correct shape + batch_size = position_ids.shape[0] + sequence_length = position_ids.shape[1] + + broadcasted_weights = [ + weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights + ] + + if self.training is True: + assert ( + reduce(mul, self.axial_pos_shape) == sequence_length + ), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format( + self.axial_pos_shape, sequence_length + ) + if self.dropout > 0: + weights = torch.cat(broadcasted_weights, dim=-1) + # permute weights so that 2D correctly drops dims 1 and 2 + transposed_weights = weights.transpose(2, 1) + # drop entire matrix of last two dims (prev dims 1 and 2) + dropped_transposed_weights = nn.functional.dropout2d( + transposed_weights, p=self.dropout, training=self.training + ) + dropped_weights = dropped_transposed_weights.transpose(2, 1) + + position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1)) + + else: + position_encodings = torch.cat( + [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], + dim=-1, + ) + + else: + assert ( + reduce(mul, self.axial_pos_shape) >= sequence_length + ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( + self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length, + ) + + # reshape axial encodings and use only until sequence_length + position_encodings = torch.cat(broadcasted_weights, dim=-1) + position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[ + :, :sequence_length + ] + + return position_encodings + + +class PositionEmbeddings(nn.Module): + """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`. + """ + + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + def forward(self, position_ids): + position_embeddings = self.embedding(position_ids) + position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training) + return position_embeddings + + +class ReformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. 
+ """ + + def __init__(self, config): + super().__init__() + self.max_position_embeddings = config.max_position_embeddings + self.dropout = config.hidden_dropout_prob + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = ( + AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) + ) + + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + + seq_length = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + assert ( + position_ids.shape[-1] <= self.max_position_embeddings + ), "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format( + position_ids.shape[-1], self.max_position_embeddings + ) + + # dropout + embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training) + + # add positional embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + + +class EfficientAttentionMixin: + """ + A few utilities for nn.Modules in Reformer, to be used as a mixin. + """ + + def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): + """ Used to implement attention between consecutive chunks. + + Args: + vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] + num_chunks_before: chunks before current chunk to include in attention + num_chunks_after: chunks after current chunk to include in attention + + Returns: + tensor of shape [num_chunks, N * chunk_length, ...], where + N = (1 + num_chunks_before + num_chunks_after). 
+ """ + if num_chunks_before == 0 and num_chunks_after == 0: + return vectors + + slices = [] + for i in range(-num_chunks_before, num_chunks_after + 1): + if i == 0: + slices.append(vectors) + else: + slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) + return torch.cat(slices, dim=3) + + def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size): + """ + splits hidden_size dim into attn_head_size and num_attn_heads + """ + new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) + x = x.view(*new_x_shape) + return x.transpose(2, 1) + + def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size): + """ + merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + x = x.permute(0, 2, 1, 3) + return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) + + def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): + """ + splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims + """ + batch_size = vectors.shape[0] + split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) + + if len(vectors.shape) == 4: + return torch.reshape(vectors, split_dim_shape + (attn_head_size,)) + elif len(vectors.shape) == 3: + return torch.reshape(vectors, split_dim_shape) + else: + raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape))) + + +class LSHSelfAttention(nn.Module, EfficientAttentionMixin): + def __init__(self, config): + super().__init__() + self.chunk_length = config.lsh_attn_chunk_length + self.num_hashes = config.num_hashes + self.num_buckets = config.num_buckets + self.num_chunks_before = config.lsh_num_chunks_before + self.num_chunks_after = config.lsh_num_chunks_after + self.hash_seed = config.hash_seed + self.is_decoder = config.is_decoder + self.max_position_embeddings = config.max_position_embeddings + + self.dropout = config.lsh_attention_probs_dropout_prob + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = config.hidden_size + + # projection matrices + self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + # save mask value here. 
Need fp32 and fp16 mask values + self.register_buffer("self_mask_value_float16", torch.tensor(-1e3)) + self.register_buffer("self_mask_value_float32", torch.tensor(-1e5)) + self.register_buffer("mask_value_float16", torch.tensor(-1e4)) + self.register_buffer("mask_value_float32", torch.tensor(-1e9)) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + buckets=None, + **kwargs + ): + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + # num hashes can optionally be overwritten by user + num_hashes = num_hashes if num_hashes is not None else self.num_hashes + + # project hidden_states to query_key and value + query_key_vectors = self.query_key(hidden_states) + value_vectors = self.value(hidden_states) + + # free memory + del hidden_states + + query_key_vectors = self._split_hidden_size_dim( + query_key_vectors, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) + + assert ( + query_key_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + query_key_vectors.shape[-1], self.attention_head_size + ) + assert ( + value_vectors.shape[-1] == self.attention_head_size + ), "last dim of value_vectors is {} but should be {}.".format( + value_vectors.shape[-1], self.attention_head_size + ) + + # set `num_buckets` on the fly, recommended way to do it + if self.num_buckets is None: + self._set_num_buckets(sequence_length) + + # use cached buckets for backprop only + if buckets is None: + # hash query key vectors into buckets + buckets = self._hash_vectors(query_key_vectors, num_hashes) + + assert ( + int(buckets.shape[-1]) == num_hashes * sequence_length + ), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) + + sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( + sequence_length, buckets, num_hashes + ) + + # make sure bucket idx is not longer then sequence length + sorted_bucket_idx = sorted_bucket_idx % sequence_length + + # cluster query key value vectors according to hashed buckets + query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) + value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) + + query_key_vectors = self._split_seq_length_dim_to( + query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + + if self.chunk_length is None: + assert ( + self.num_chunks_before == 0 and self.num_chunks_after == 0 + ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." 
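+
+        # NOTE: queries and keys share the same projection (shared-QK attention), so the
+        # "key" vectors used below are simply the query_key vectors rescaled by
+        # `_len_and_dim_norm` before the chunked dot-product.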
+ + # scale key vectors + key_vectors = self._len_and_dim_norm(query_key_vectors) + + # get attention probs + out_vectors, logits, attention_probs = self._attend( + query_vectors=query_key_vectors, + key_vectors=key_vectors, + value_vectors=value_vectors, + sorted_bucket_idx=sorted_bucket_idx, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # free memory + del query_key_vectors, key_vectors, value_vectors + + # sort clusters back to correct ordering + out_vectors, logits = ReverseSort.apply( + out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes + ) + + # sum up all hash rounds + if num_hashes > 1: + out_vectors = self._split_seq_length_dim_to( + out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, + ) + logits = self._split_seq_length_dim_to( + logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, + ).unsqueeze(-1) + + probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) + out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) + # free memory + del probs_vectors + + # free memory + del logits + + assert out_vectors.shape == ( + batch_size, + self.num_attention_heads, + sequence_length, + self.attention_head_size, + ), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`." + + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) + + if do_output_attentions is False: + attention_probs = () + + return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) + + def _hash_vectors(self, vectors, num_hashes): + batch_size = vectors.shape[0] + + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + if isinstance(self.num_buckets, int): + assert ( + self.num_buckets % 2 == 0 + ), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format(self.num_buckets) + rotation_size = self.num_buckets + num_buckets = self.num_buckets + else: + # Factorize the hash if self.num_buckets is a list or tuple + rotation_size, num_buckets = 0, 1 + for bucket_factor in self.num_buckets: + assert bucket_factor % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format( + bucket_factor + ) + rotation_size = rotation_size + bucket_factor + num_buckets = num_buckets * bucket_factor + + # remove gradient + vectors = vectors.detach() + + if self.hash_seed is not None: + # for determinism + torch.manual_seed(self.hash_seed) + + rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) + # create a random self.attention_head_size x num_hashes x num_buckets/2 + random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) + + # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 + rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) + + if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: + rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) + buckets = torch.argmax(rotated_vectors, dim=-1) + else: + # Get the buckets for them and combine. 
+ buckets, cur_sum, cur_product = None, 0, 1 + for bucket_factor in self.num_buckets: + rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)] + cur_sum = cur_sum + bucket_factor // 2 + rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1) + + if buckets is None: + buckets = torch.argmax(rotated_vectors_factor, dim=-1) + else: + buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1)) + + cur_product = cur_product * bucket_factor + + # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). + # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(num_hashes, device=vectors.device) + offsets = (offsets * num_buckets).view((1, 1, -1, 1)) + + # expand to batch size and num attention heads + offsets = offsets.expand((batch_size, self.num_attention_heads) + offsets.shape[-2:]) + offset_buckets = (buckets + offsets).flatten(start_dim=2, end_dim=3) + + return offset_buckets + + def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): + # no gradients are needed + with torch.no_grad(): + batch_size = buckets.shape[0] + + # arange and expand + orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device).view(1, 1, -1) + orig_indices = orig_indices.expand(batch_size, self.num_attention_heads, orig_indices.shape[-1]) + + # scale buckets + scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) + + # remove gradient + scaled_buckets = scaled_buckets.detach() + + # Hash-based sort + sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) + + # create simple indices to scatter to, to have undo sort + indices = ( + torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) + .view(1, 1, -1) + .expand(sorted_bucket_idx.shape) + ) + + # get undo sort + undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) + undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) + + return sorted_bucket_idx, undo_sorted_bucket_idx + + def _set_num_buckets(self, sequence_length): + # recommended `num_buckets` from paper + num_buckets = 2 * sequence_length // self.chunk_length + + # factorize `num_buckets` if `num_buckets` becomes too large + num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,) + if num_buckets > 2 * num_buckets_limit: + num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1] + + logger.warning("config.num_buckets is not set. 
Setting config.num_buckets to {}...".format(num_buckets)) + self.num_buckets = num_buckets + + def _attend( + self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, + ): + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + + # get logits and dots + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # free memory + del query_vectors, key_vectors + + query_bucket_idx = self._split_seq_length_dim_to( + sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads + ) + key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) + + # get correct mask values depending on precision + if query_key_dots.dtype == torch.float16: + self_mask_value = self.self_mask_value_float16 + mask_value = self.mask_value_float16 + else: + self_mask_value = self.self_mask_value_float32 + mask_value = self.mask_value_float32 + + mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask) + + if mask is not None: + query_key_dots = torch.where(mask, query_key_dots, mask_value) + + # free memory + del mask + + # Self mask is ALWAYS applied. + # From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): + # " While attention to the future is not allowed, typical implementations of the + # Transformer do allow a position to attend to itself. + # Such behavior is undesirable in a shared-QK formulation because the dot-product + # of a query vector with itself will almost always be greater than the dot product of a + # query vector with a vector at another position. We therefore modify the masking + # to forbid a token from attending to itself, except in situations + # where a token has no other valid attention targets (e.g. the first token in a sequence) " + + self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( + query_bucket_idx.device + ) + + # apply self_mask + query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value) + + # free memory + del self_mask + + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` + attention_probs = torch.exp(query_key_dots - logits) + + # free memory + del query_key_dots + + # dropout + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # attend values + out_vectors = torch.matmul(attention_probs, value_vectors) + + # free memory + del value_vectors + + # merge chunk length + logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + + return out_vectors, logits, attention_probs + + def _compute_attn_mask(self, query_indices, key_indices, attention_mask): + mask = None + + # Causal mask + if self.is_decoder: + mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # Attention mask: chunk, look up correct mask value from key_value_bucket_idx + # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. 
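+        # The padding mask refers to the original token positions, but queries and keys
+        # were re-ordered by the hash-based sort; the mask is therefore gathered with
+        # `query_indices` / `key_indices` below so that each chunked dot-product entry is
+        # masked according to the original positions of its query and key.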
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
+            # expand attn_mask to fit with key_value_bucket_idx shape
+            attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
+            key_attn_mask = torch.gather(attention_mask, -1, key_indices)
+            query_attn_mask = torch.gather(attention_mask, -1, query_indices)
+            # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
+            attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
+            # free memory
+            del query_attn_mask, key_attn_mask, attention_mask
+
+            # multiply by causal mask if necessary
+            if mask is not None:
+                mask = mask * attn_mask
+            else:
+                mask = attn_mask
+
+        return mask
+
+    def _len_and_dim_norm(self, vectors):
+        """
+            length and attention head size dim normalization
+        """
+        vectors = self._len_norm(vectors)
+        vectors = vectors * torch.rsqrt(
+            torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype)
+        )
+        return vectors
+
+    def _len_norm(self, x, epsilon=1e-6):
+        """
+            length normalization
+        """
+        variance = torch.mean(x ** 2, -1, keepdim=True)
+        norm_x = x * torch.rsqrt(variance + epsilon)
+        return norm_x
+
+    def _gather_by_expansion(self, vectors, idxs, num_hashes):
+        """
+            expand dims of idxs and vectors for all hashes and gather
+        """
+        expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
+        vectors = vectors.repeat(1, 1, num_hashes, 1)
+        return torch.gather(vectors, 2, expanded_idxs)
+
+
+class ReverseSort(Function):
+    """
+        After chunked attention is applied on the sorted clusters,
+        the original ordering has to be restored.
+        Since a customized backward function is used for Reformer,
+        the gradients of the output vectors have to be explicitly
+        sorted here.
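+        In the forward pass the outputs are gathered with `undo_sorted_bucket_idx`
+        to restore the original position order; in the backward pass the incoming
+        gradients are gathered with `sorted_bucket_idx` to undo that reordering.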
+ """ + + @staticmethod + def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes): + # save sorted_bucket_idx for backprop + with torch.no_grad(): + ctx.sorted_bucket_idx = sorted_bucket_idx + ctx.num_hashes = num_hashes + + # undo sort to have correct order for next layer + expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) + out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) + logits = torch.gather(logits, 2, undo_sorted_bucket_idx) + return out_vectors, logits + + @staticmethod + def backward(ctx, grad_out_vectors, grad_logits): + # get parameters saved in ctx + sorted_bucket_idx = ctx.sorted_bucket_idx + num_hashes = ctx.num_hashes + + # get real gradient shape + # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes + grad_logits_shape = grad_logits.shape + # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen + grad_out_vectors_shape = grad_out_vectors.shape + + # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices + # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen + grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1))) + # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen + grad_out_vectors = grad_out_vectors.view( + (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:]) + ) + + # reshape and expand + sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1))) + expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) + # reverse sort of forward + grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices) + grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx) + + # reshape into correct shape + grad_logits = torch.reshape(grad_logits, grad_logits_shape) + grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape) + + # return grad and `None` fillers for last 3 forward args + return grad_out_vectors, grad_logits, None, None, None + + +class LocalSelfAttention(nn.Module, EfficientAttentionMixin): + def __init__(self, config): + super().__init__() + + self.num_attention_heads = config.num_attention_heads + self.chunk_length = config.local_attn_chunk_length + self.num_chunks_before = config.local_num_chunks_before + self.num_chunks_after = config.local_num_chunks_after + self.is_decoder = config.is_decoder + self.pad_token_id = config.pad_token_id + + self.attention_head_size = config.attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = config.hidden_size + + # projection matrices + self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) + + self.dropout = config.local_attention_probs_dropout_prob + + # save mask value here + self.register_buffer("mask_value_float16", torch.tensor(-1e4)) + self.register_buffer("mask_value_float32", torch.tensor(-1e9)) + + def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): + sequence_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + # project hidden_states to query, key and value + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = 
self.value(hidden_states) + + # split last dim into `config.num_attention_heads` and `config.attention_head_size` + query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) + key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) + value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) + + assert ( + query_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + query_vectors.shape[-1], self.attention_head_size + ) + assert ( + key_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + key_vectors.shape[-1], self.attention_head_size + ) + assert ( + value_vectors.shape[-1] == self.attention_head_size + ), "last dim of query_key_vectors is {} but should be {}.".format( + value_vectors.shape[-1], self.attention_head_size + ) + + if self.chunk_length is None: + assert ( + self.num_chunks_before == 0 and self.num_chunks_after == 0 + ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." + + # normalize key vectors + key_vectors = key_vectors / torch.sqrt( + torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype) + ) + + # chunk vectors + # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size + query_vectors = self._split_seq_length_dim_to( + query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + key_vectors = self._split_seq_length_dim_to( + key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, + ) + + # chunk indices + indices = torch.arange(sequence_length, device=query_vectors.device).repeat( + batch_size, self.num_attention_heads, 1 + ) + query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) + + # append chunks before and after + key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) + value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) + key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) + + query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) + + # free memory + del query_vectors, key_vectors + + mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape) + + if mask is not None: + # get mask tensor depending on half precision or not + if query_key_dots.dtype == torch.float16: + mask_value = self.mask_value_float16 + else: + mask_value = self.mask_value_float32 + + query_key_dots = torch.where(mask, query_key_dots, mask_value) + + # free memory + del mask + + # softmax + logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) + attention_probs = torch.exp(query_key_dots - logits) + + # free memory + del logits + + # dropout + attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # 
attend values + out_vectors = torch.matmul(attention_probs, value_vectors) + + # free memory + del value_vectors + + # merge chunk length + out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) + + assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) + + out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) + + if do_output_attentions is False: + attention_probs = () + + return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) + + def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape): + mask = None + + # chunk attention mask and look before and after + if attention_mask is not None: + attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) + attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) + + # Causal mask + if self.is_decoder is True: + mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) + + # Attention mask + if attention_mask is not None: + # create attn_mask + attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape) + # multiply by casaul mask if necessary + if mask is not None: + mask = mask * attn_mask + else: + mask = attn_mask + return mask + + +class ReformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + all_head_size = config.num_attention_heads * config.attention_head_size + self.dropout = config.hidden_dropout_prob + + self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ReformerAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.layer_id = layer_id + self.attn_layers = config.attn_layers + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": + self.self_attention = LSHSelfAttention(config) + elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": + self.self_attention = LocalSelfAttention(config) + elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]): + # get correct attn layers + if self.attn_layers[self.layer_id] == "lsh": + self.self_attention = LSHSelfAttention(config) + else: + self.self_attention = LocalSelfAttention(config) + else: + raise NotImplementedError( + "Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {}. 
Select attn layer types from ['lsh', 'local'] only.".format( + self.attn_layers + ) + ) + self.output = ReformerSelfOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + buckets=None, + ): + hidden_states = self.layer_norm(hidden_states) + + # use cached buckets for backprob if buckets not None for LSHSelfAttention + self_attention_outputs = self.self_attention( + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + buckets=buckets, + ) + attention_output = self.output(self_attention_outputs.hidden_states) + + # add buckets if necessary + if hasattr(self_attention_outputs, "buckets"): + buckets = self_attention_outputs.buckets + else: + buckets = None + + return AttentionOutput( + hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, + ) + + +class ReformerFeedForwardDense(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.act_fn(hidden_states) + return hidden_states + + +class ReformerFeedForwardOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ChunkReformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = ReformerFeedForwardDense(config) + self.output = ReformerFeedForwardOutput(config) + + def forward(self, attention_output): + return apply_chunking_to_forward( + self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output, + ) + + def forward_chunk(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dense(hidden_states) + return self.output(hidden_states) + + +class ReformerLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.attention = ReformerAttention(config, layer_id) + # dropout requires to have the same + # seed for forward and backward pass + self.attention_seed = None + self.feed_forward_seed = None + + self.feed_forward = ChunkReformerFeedForward(config) + + def _init_attention_seed(self): + """ + This function sets a new seed for the + attention layer to make dropout deterministic + for both forward calls: 1 normal forward + call and 1 forward call in backward + to recalculate activations. 
+ """ + + # randomize seeds + if next(self.parameters()).device.type == "cuda": + # GPU + device_idx = torch.cuda.current_device() + self.attention_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.attention_seed) + else: + # CPU + self.attention_seed = int(torch.seed() % sys.maxsize) + torch.manual_seed(self.attention_seed) + + def _init_feed_forward_seed(self): + """ + This function sets a new seed for the + feed forward layer to make dropout deterministic + for both forward calls: 1 normal forward + call and 1 forward call in backward + to recalculate activations. + """ + + # randomize seeds + if next(self.parameters()).device.type == "cuda": + # GPU + device_idx = torch.cuda.current_device() + self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() + torch.cuda.manual_seed(self.feed_forward_seed) + else: + # CPU + self.feed_forward_seed = int(torch.seed() % sys.maxsize) + torch.manual_seed(self.feed_forward_seed) + + def forward( + self, + prev_attn_output, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_attentions=False, + ): + with torch.no_grad(): + # every forward pass we sample a different seed + # for dropout and save for forward fn in backward pass + # to have correct dropout + self._init_attention_seed() + attn_outputs = self.attention( + hidden_states=hidden_states, + head_mask=head_mask, + attention_mask=attention_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + ) + attn_output = attn_outputs.hidden_states + + # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # Y_1 = X_1 + f(X_2) + attn_output = prev_attn_output + attn_output + + # free memory + del prev_attn_output + + # every forward pass we sample a different seed + # for dropout and save seed for forward fn in backward + # to have correct dropout + self._init_feed_forward_seed() + # Y_2 = X_2 + g(Y_1) + hidden_states = hidden_states + self.feed_forward(attn_output) + + return ReformerOutput( + attn_output=attn_output, + hidden_states=hidden_states, + attention_probs=attn_outputs.attention_probs, + buckets=attn_outputs.buckets, + ) + + def backward_pass( + self, + next_attn_output, + hidden_states, + grad_attn_output, + grad_hidden_states, + attention_mask=None, + head_mask=None, + buckets=None, + ): + # Implements the backward pass for reversible ResNets. + # A good blog post on how this works can be found here: + # Implementation of RevNet (see Fig. 
6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) + # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + + with torch.enable_grad(): + next_attn_output.requires_grad = True + + # set seed to have correct dropout + torch.manual_seed(self.feed_forward_seed) + # g(Y_1) + res_hidden_states = self.feed_forward(next_attn_output) + res_hidden_states.backward(grad_hidden_states, retain_graph=True) + + with torch.no_grad(): + # X_2 = Y_2 - g(Y_1) + hidden_states = hidden_states - res_hidden_states + del res_hidden_states + + grad_attn_output = grad_attn_output + next_attn_output.grad + next_attn_output.grad = None + + with torch.enable_grad(): + hidden_states.requires_grad = True + + # set seed to have correct dropout + torch.manual_seed(self.attention_seed) + # f(X_2) + # use cached buckets for backprob if buckets not None for LSHSelfAttention + output = self.attention( + hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, + ).hidden_states + output.backward(grad_attn_output, retain_graph=True) + + with torch.no_grad(): + # X_1 = Y_1 - f(X_2) + attn_output = next_attn_output - output + del output, next_attn_output + + grad_hidden_states = grad_hidden_states + hidden_states.grad + hidden_states.grad = None + hidden_states = hidden_states.detach() + + return ReformerBackwardOutput( + attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) + + +class _ReversibleFunction(Function): + """ + To prevent PyTorch from performing the usual backpropagation, + a customized backward function is implemented here. This way + it is made sure that no memory expensive activations are + saved during the forward pass. 
+ This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py + """ + + @staticmethod + def forward( + ctx, + hidden_states, + layers, + attention_mask, + head_mask, + num_hashes, + all_hidden_states, + all_attentions, + do_output_hidden_states, + do_output_attentions, + ): + all_buckets = () + + # split duplicated tensor + hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) + + for layer, layer_head_mask in zip(layers, head_mask): + if do_output_hidden_states is True: + all_hidden_states.append(hidden_states) + + layer_outputs = layer( + prev_attn_output=attn_output, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + num_hashes=num_hashes, + do_output_attentions=do_output_attentions, + ) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + all_buckets = all_buckets + (layer_outputs.buckets,) + + if do_output_attentions: + all_attentions.append(layer_outputs.attention_probs) + + # Add last layer + if do_output_hidden_states is True: + all_hidden_states.append(hidden_states) + + # attach params to ctx for backward + ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) + ctx.layers = layers + ctx.all_buckets = all_buckets + ctx.head_mask = head_mask + ctx.attention_mask = attention_mask + + # Concatenate 2 RevNet outputs + return torch.cat([attn_output, hidden_states], dim=-1) + + @staticmethod + def backward(ctx, grad_hidden_states): + grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) + + # retrieve params from ctx for backward + attn_output, hidden_states = ctx.saved_tensors + + # create tuple + output = ReformerBackwardOutput( + attn_output=attn_output, + hidden_states=hidden_states, + grad_attn_output=grad_attn_output, + grad_hidden_states=grad_hidden_states, + ) + + # free memory + del grad_attn_output, grad_hidden_states, attn_output, hidden_states + + layers = ctx.layers + all_buckets = ctx.all_buckets + head_mask = ctx.head_mask + attention_mask = ctx.attention_mask + + for idx, layer in enumerate(layers[::-1]): + # pop last buckets from stack + buckets = all_buckets[-1] + all_buckets = all_buckets[:-1] + + # backprop + output = layer.backward_pass( + next_attn_output=output.attn_output, + hidden_states=output.hidden_states, + grad_attn_output=output.grad_attn_output, + grad_hidden_states=output.grad_hidden_states, + head_mask=head_mask[len(layers) - idx - 1], + attention_mask=attention_mask, + buckets=buckets, + ) + + assert all_buckets == (), "buckets have to be empty after backpropagation" + grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1) + + # num of return vars has to match num of forward() args + # return gradient for hidden_states arg and None for other args + return grad_hidden_states, None, None, None, None, None, None, None, None + + +class ReformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.dropout = config.hidden_dropout_prob + + self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + num_hashes=None, + do_output_hidden_states=False, + do_output_attentions=False, 
+ ): + # hidden_states and attention lists to be filled if wished + all_hidden_states = [] + all_attentions = [] + + # concat same tensor for reversible ResNet + hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) + hidden_states = _ReversibleFunction.apply( + hidden_states, + self.layers, + attention_mask, + head_mask, + num_hashes, + all_hidden_states, + all_attentions, + do_output_hidden_states, + do_output_attentions, + ) + + # Apply layer norm to concatenated hidden states + hidden_states = self.layer_norm(hidden_states) + + # Apply dropout + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + return ReformerEncoderOutput( + hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions + ) + + +class ReformerOnlyLMHead(nn.Module): + def __init__(self, config): + super().__init__() + # Reformer is using Rev Nets, thus last layer outputs are concatenated and + # Layer Norm is done over 2 * hidden_size + self.seq_len_dim = 1 + self.chunk_size_lm_head = config.chunk_size_lm_head + self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class ReformerPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = ReformerConfig + pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "reformer" + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, AxialPositionEmbeddings): + for weight in module.weights: + torch.nn.init.normal_(weight, std=self.config.axial_norm_std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +REFORMER_START_DOCSTRING = r""" + Reformer was proposed in + `Reformer: The Efficient Transformer`_ + by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. + + .. _`Reformer: The Efficient Transformer`: + https://arxiv.org/abs/2001.04451 + + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +REFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + During training the input_ids sequence_length has to be a multiple of the relevant model's + chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically + padded to be a multiple of the chunk length. + + Indices can be obtained using :class:`transformers.ReformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + num_hashes (:obj:`int`, `optional`, defaults to :obj:`None`): + `num_hashes` is the number of hashing rounds that should be performed during + bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined + in `config.num_hashes`. + For more information, see `num_hashes` in :class:`transformers.ReformerConfig`. +""" + + +@add_start_docstrings( + "The bare Reformer Model transformer outputting raw hidden-states" "without any specific head on top.", + REFORMER_START_DOCSTRING, +) +class ReformerModel(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + assert ( + self.config.num_hidden_layers > 0 + ), "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']" + + self.embeddings = ReformerEmbeddings(config) + self.encoder = ReformerEncoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        num_hashes=None,
+        do_output_hidden_states=False,
+        do_output_attentions=False,
+    ):
+        r"""
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import ReformerModel, ReformerTokenizer
+        import torch
+
+        tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
+        model = ReformerModel.from_pretrained('google/reformer-crime-and-punishment')
+
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+
+        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+        """
+
+        # TODO(PVP): delete when PR to change output_attentions is made
+        do_output_attentions = self.config.output_attentions
+        do_output_hidden_states = self.config.output_hidden_states
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()  # noqa: F841
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]  # noqa: F841
+            device = inputs_embeds.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        assert (
+            len(input_shape) == 2
+        ), "`input_ids` has to be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape)
+
+        # prepare head mask
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True)
+
+        # original sequence length for padding
+        orig_sequence_length = input_shape[-1]
+
+        # if needs padding
+        least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
+        must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
+
+        if must_pad_to_match_chunk_length:
+            padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
+
+            if self.training is True:
+                raise ValueError(
+                    "If training, sequence length {} has to be a multiple of least common multiple chunk_length {}.
Please consider padding the input to a length of {}.".format( + input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length + ) + ) + + # pad input + input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length( + input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + input_shape=input_shape, + padding_length=padding_length, + padded_seq_length=least_common_mult_chunk_length, + device=device, + ) + + embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + head_mask=head_mask, + attention_mask=attention_mask, + num_hashes=num_hashes, + do_output_hidden_states=do_output_hidden_states, + do_output_attentions=do_output_attentions, + ) + sequence_output = encoder_outputs.hidden_states + + # if padding was applied + if must_pad_to_match_chunk_length: + sequence_output = sequence_output[:, :orig_sequence_length] + + outputs = (sequence_output,) + # TODO(PVP): Replace by named tuple after namedtuples are introduced in the library. + if do_output_hidden_states is True: + outputs = outputs + (encoder_outputs.all_hidden_states,) + if do_output_attentions is True: + outputs = outputs + (encoder_outputs.all_attentions,) + return outputs + + def _pad_to_mult_of_chunk_length( + self, + input_ids, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + input_shape=None, + padding_length=None, + padded_seq_length=None, + device=None, + ): + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format( + input_shape[-1], input_shape[-1] + padding_length, padded_seq_length + ) + ) + + padded_input_ids = torch.full( + (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, + ) + + # Extend `attention_mask` + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask, + torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,), + ], + dim=-1, + ) + else: + attention_mask = torch.cat( + [ + torch.ones(input_shape, device=device, dtype=torch.uint8), + torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8), + ], + dim=-1, + ) + + # Extend `input_ids` with padding to match least common multiple chunk_length + if input_ids is not None: + input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) + input_shape = input_ids.size() + + # Pad position ids if given + if position_ids is not None: + padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device) + padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length) + position_ids = torch.cat([position_ids, padded_position_ids], dim=-1) + + # Extend `input_embeds` with padding to match least common multiple chunk_length + if inputs_embeds is not None: + padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) + inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) + input_shape = inputs_embeds.size() + return input_ids, inputs_embeds, attention_mask, position_ids, input_shape + + +@add_start_docstrings("""Reformer Model with a `language modeling` head on top. 
""", REFORMER_START_DOCSTRING) +class ReformerModelWithLMHead(ReformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.reformer = ReformerModel(config) + self.lm_head = ReformerOnlyLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def tie_weights(self): + # word embeddings are not tied in Reformer + pass + + @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + num_hashes=None, + labels=None, + do_output_hidden_states=False, + do_output_attentions=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`. + All labels set to ``-100`` are ignored (masked), the loss is only + computed for labels in ``[0, ..., config.vocab_size]`` + + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided): + Classification loss (cross entropy). + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ + Examples:: + + from transformers import ReformerModel, ReformerTokenizer + import torch + + tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment') + model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=input_ids) + + loss, prediction_scores = outputs[:2] + """ + + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + do_output_hidden_states=do_output_hidden_states, + do_output_attentions=do_output_attentions, + ) + + sequence_output = reformer_outputs[0] + logits = self.lm_head(sequence_output) + outputs = (logits,) + reformer_outputs[1:] + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + outputs = (loss,) + outputs + return outputs # (lm_loss), lm_logits, (hidden_states), (attentions) + + def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + # TODO(PVP): Add smart caching + inputs_dict = {"input_ids": input_ids} + + if "num_hashes" in kwargs: + inputs_dict["num_hashes"] = kwargs["num_hashes"] + + return inputs_dict diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 68a6c4de1c67..02ae2240bc33 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch BERT model.""" +import inspect import logging import os from typing import Callable, Tuple @@ -175,7 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask - def get_head_mask(self, head_mask, num_hidden_layers): + def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False): """ # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -189,6 +189,8 @@ def get_head_mask(self, head_mask, num_hidden_layers): """ if head_mask is not None: head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) else: head_mask = [None] * num_hidden_layers @@ -786,6 +788,7 @@ def generate( attention_mask=None, decoder_start_token_id=None, use_cache=None, + **model_specific_kwargs ): r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. @@ -863,6 +866,9 @@ def generate( use_cache: (`optional`) bool If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`. + model_specific_kwargs: (`optional`) dict + Additional model specific kwargs will be forwarded to the `forward` function of the model. 
+ Return: output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` @@ -1116,6 +1122,7 @@ def generate( encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, ) else: output = self._generate_no_beam_search( @@ -1138,6 +1145,7 @@ def generate( encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, ) return output @@ -1163,6 +1171,7 @@ def _generate_no_beam_search( encoder_outputs, attention_mask, use_cache, + model_specific_kwargs, ): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -1175,7 +1184,7 @@ def _generate_no_beam_search( while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) outputs = self(**model_inputs) @@ -1288,6 +1297,7 @@ def _generate_beam_search( encoder_outputs, attention_mask, use_cache, + model_specific_kwargs, ): """ Generate sequences for each example with beam search. """ @@ -1314,7 +1324,7 @@ def _generate_beam_search( while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) @@ -2087,3 +2097,66 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) + + +def apply_chunking_to_forward( + chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors +) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. + It then applies a layer `forward_fn` to each chunk independently to save memory. + If the `forward_fn` is independent across the `chunk_dim` this function will yield the + same result as not applying it. + + Args: + chunk_size: int - the chunk size of a chunked tensor. 
`num_chunks` = `len(input_tensors[0]) / chunk_size` + chunk_dim: int - the dimension over which the input_tensors should be chunked + forward_fn: fn - the forward fn of the model + input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked + Returns: + a Tensor with the same shape the foward_fn would have given if applied + + + Examples:: + + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + """ + + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) + tensor_shape = input_tensors[0].shape + assert all( + input_tensor.shape == tensor_shape for input_tensor in input_tensors + ), "All input tenors have to be of the same shape" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + assert num_args_in_forward_chunk_fn == len( + input_tensors + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( + num_args_in_forward_chunk_fn, len(input_tensors) + ) + + if chunk_size > 0: + assert ( + input_tensors[0].shape[chunk_dim] % chunk_size == 0 + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( + input_tensors[0][chunk_dim], chunk_size + ) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index f26e4fa0fdf8..9982e31cea9d 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -30,6 +30,7 @@ FlaubertConfig, GPT2Config, OpenAIGPTConfig, + ReformerConfig, RobertaConfig, T5Config, TransfoXLConfig, @@ -49,6 +50,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast +from .tokenization_reformer import ReformerTokenizer from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_t5 import T5Tokenizer from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast @@ -69,6 +71,7 @@ (XLMRobertaConfig, (XLMRobertaTokenizer, None)), (BartConfig, (BartTokenizer, None)), (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), + (ReformerConfig, (ReformerTokenizer, None)), (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)), (BertConfig, (BertTokenizer, BertTokenizerFast)), (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)), diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py new file mode 100644 index 000000000000..4accdcc3cfbf --- /dev/null +++ b/src/transformers/tokenization_reformer.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# 
Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model Reformer.""" + + +import logging +import os +from shutil import copyfile + +from .tokenization_utils import PreTrainedTokenizer + + +logger = logging.getLogger(__name__) + +SPIECE_UNDERLINE = "▁" + + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to file names for serializing Tokenizer instances +#################################################### +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to pretrained vocabulary URL for all the model shortcut names. +#################################################### +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model" + } +} + +#################################################### +# Mapping from model shortcut names to max length of inputs +#################################################### +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "google/reformer-crime-and-punishment": 524288, +} + + +class ReformerTokenizer(PreTrainedTokenizer): + """ + Constructs an Reformer tokenizer. Based on `SentencePiece `__ . + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`string`): + `SentencePiece `__ file (generally has a `.spm` extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (:obj:`string`, `optional`, defaults to ""): + The end of sequence token. + + .. note:: + + When building a sequence using special tokens, this is not the token that is used for the end + of sequence. The token used is the :obj:`sep_token`. + unk_token (:obj:`string`, `optional`, defaults to ""): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`string`, `optional`, defaults to ""): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`): + Additional special tokens used by the tokenizer. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + additional_special_tokens=[], + **kwargs + ): + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use ReformerTokenizer:" + "https://github.com/google/sentencepiece" + "pip install sentencepiece" + ) + raise + + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use ReformerTokenizer: https://github.com/google/sentencepiece" + "pip install sentencepiece" + ) + raise + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text, sample=False): + """ Take as input a string and return a list of strings (tokens) for words/sub-words + """ + if not sample: + pieces = self.sp_model.EncodeAsPieces(text) + else: + pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) + return pieces + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index < self.sp_model.get_piece_size(): + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = self.sp_model.decode_pieces(tokens) + return out_string + + def save_vocabulary(self, save_directory): + """ Save the sentencepiece vocabulary (copy original file) and special tokens file + to a directory. 
+ """ + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) diff --git a/tests/test_activations.py b/tests/test_activations.py index d6cbc1f9e564..79e9eec0184c 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -22,6 +22,8 @@ def test_get_activation(self): get_activation("swish") get_activation("relu") get_activation("tanh") + get_activation("gelu_new") + get_activation("gelu_fast") with self.assertRaises(KeyError): get_activation("bogus") with self.assertRaises(KeyError): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2909e17e9e49..8c9c6a9f5a56 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -125,6 +125,9 @@ def test_attention_outputs(self): encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length) encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes for model_class in self.all_model_classes: config.output_attentions = True @@ -138,10 +141,17 @@ def test_attention_outputs(self): self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) out_len = len(outputs) if self.is_encoder_decoder: @@ -175,10 +185,16 @@ def test_attention_outputs(self): self_attentions = outputs[-1] self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) def test_torchscript(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -465,14 +481,16 @@ def test_hidden_states_output(self): self.assertEqual(model.config.output_attentions, False) self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and 
self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [ - self.model_tester.encoder_seq_length - if hasattr(self.model_tester, "encoder_seq_length") - else self.model_tester.seq_length, - self.model_tester.hidden_size, - ], + list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size], ) def test_resize_tokens_embeddings(self): @@ -485,6 +503,9 @@ def test_resize_tokens_embeddings(self): model = model_class(config) model.to(torch_device) + if self.model_tester.is_training is False: + model.eval() + model_vocab_size = config.vocab_size # Retrieve the embeddings and clone theme model_embed = model.resize_token_embeddings(model_vocab_size) @@ -628,9 +649,13 @@ def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] + # iterate over all generative models for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) + model.eval() if config.bos_token_id is None: # if bos token id is not defined, model needs input_ids @@ -669,8 +694,12 @@ def test_lm_head_model_random_beam_search_generate(self): torch_device ) + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] + for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) + model.eval() if config.bos_token_id is None: # if bos token id is not defined mobel needs input_ids, num_return_sequences = 1 @@ -750,7 +779,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None): def floats_tensor(shape, scale=1.0, rng=None, name=None): - """Creates a random float32 tensor of the shape within the vocab size.""" + """Creates a random float32 tensor""" if rng is None: rng = global_rng diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py new file mode 100644 index 000000000000..c79b212a8c59 --- /dev/null +++ b/tests/test_modeling_reformer.py @@ -0,0 +1,955 @@ +# coding=utf-8 # Copyright 2020 Huggingface +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +from transformers import is_torch_available + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from .utils import require_torch, slow, torch_device + + +if is_torch_available(): + from transformers import ( + ReformerConfig, + ReformerModel, + ReformerModelWithLMHead, + ReformerTokenizer, + ReformerLayer, + REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, + ) + import torch + + +class ReformerModelTester: + def __init__( + self, + parent, + batch_size=None, + seq_length=None, + is_training=None, + is_decoder=None, + use_input_mask=None, + vocab_size=None, + attention_head_size=None, + hidden_size=None, + num_attention_heads=None, + local_attn_chunk_length=None, + local_num_chunks_before=None, + local_num_chunks_after=None, + num_buckets=None, + num_hashes=1, + lsh_attn_chunk_length=None, + lsh_num_chunks_before=None, + lsh_num_chunks_after=None, + chunk_size_lm_head=None, + chunk_size_feed_forward=None, + feed_forward_size=None, + hidden_act=None, + hidden_dropout_prob=None, + local_attention_probs_dropout_prob=None, + lsh_attention_probs_dropout_prob=None, + max_position_embeddings=None, + initializer_range=None, + axial_norm_std=None, + layer_norm_eps=None, + axial_pos_embds=None, + axial_pos_shape=None, + axial_pos_embds_dim=None, + attn_layers=None, + pad_token_id=None, + eos_token_id=None, + scope=None, + hash_seed=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.is_decoder = is_decoder + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = len(attn_layers) + self.local_attn_chunk_length = local_attn_chunk_length + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before + self.num_hashes = num_hashes + self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.scope = scope + self.attn_layers = attn_layers + self.pad_token_id = pad_token_id + self.hash_seed = hash_seed + + attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length + num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after + num_chunks_before = local_num_chunks_before if local_num_chunks_before is not None else lsh_num_chunks_before + + self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length 
!= 0) + self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length + self.chunk_length = attn_chunk_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = ReformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + feed_forward_size=self.feed_forward_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, + lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + is_decoder=self.is_decoder, + axial_pos_embds=self.axial_pos_embds, + axial_pos_shape=self.axial_pos_shape, + axial_pos_embds_dim=self.axial_pos_embds_dim, + local_attn_chunk_length=self.local_attn_chunk_length, + local_num_chunks_after=self.local_num_chunks_after, + local_num_chunks_before=self.local_num_chunks_before, + num_hashes=self.num_hashes, + num_buckets=self.num_buckets, + lsh_attn_chunk_length=self.lsh_attn_chunk_length, + lsh_num_chunks_after=self.lsh_num_chunks_after, + lsh_num_chunks_before=self.lsh_num_chunks_before, + attn_layers=self.attn_layers, + pad_token_id=self.pad_token_id, + hash_seed=self.hash_seed, + ) + + return ( + config, + input_ids, + input_mask, + ) + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) + + def create_and_check_reformer_model( + self, config, input_ids, input_mask, + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + (sequence_output,) = model(input_ids, attention_mask=input_mask) + (sequence_output,) = model(input_ids) + + result = { + "sequence_output": sequence_output, + } + # 2 * hidden_size because we use reversible residual layers + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], + ) + + def create_and_check_reformer_model_with_lm_backward( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] + loss.backward() + + def create_and_check_reformer_with_lm( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], + ) + self.check_loss_output(result) + + def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): + # no special position embeddings + config.axial_pos_embds = False + config.is_decoder = is_decoder + + if self.lsh_attn_chunk_length is not None: + # need to set the chunk length equal to the sequence length to be certain that chunking works + config.lsh_attn_chunk_length = self.seq_length + + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + # set all position encodings to zero so that positions don't matter + with torch.no_grad():
+ embedding = model.embeddings.position_embeddings.embedding + embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) + embedding.weight.requires_grad = False + + half_seq_len = self.seq_length // 2 + roll = self.chunk_length + + half_input_ids = input_ids[:, :half_seq_len] + + # normal padded + attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) + input_ids_padded = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, + ) + + # shifted padded + input_ids_roll = torch.cat( + [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, + ) + input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) + attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) + + output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] + output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll] + + self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) + + def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): + config.is_decoder = is_decoder + layer = ReformerLayer(config).to(torch_device) + layer.train() + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + # get random tensors + hidden_states = floats_tensor(shape) + prev_attn_output = floats_tensor(shape) + + # now the random seeds for attention and feed forward is initialized + # forward tensors with dropout + layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) + + next_attn_output = layer_outputs.attn_output + next_hidden_states = layer_outputs.hidden_states + + torch.manual_seed(layer.attention_seed) + attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) + self.parent.assertTrue( + torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) + ) + + torch.manual_seed(layer.feed_forward_seed) + feed_forward_hidden_states = layer.feed_forward(next_attn_output) + self.parent.assertTrue( + torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) + ) + + def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] + + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 + + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + + hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) + + def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): + if not self.is_training: + return + + # disable dropout + config.hidden_dropout_prob = 0 + config.local_attention_probs_dropout_prob = 0 + config.lsh_attention_probs_dropout_prob = 0 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_no_chunk, output_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_no_chunk.backward() + grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, 
:5] + grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_chunk, output_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_chunk.backward() + grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) + ) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) + ) + + def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): + layer = ReformerLayer(config).to(torch_device) + layer.train() + + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + hidden_states = floats_tensor(shape) + attn_output = floats_tensor(shape) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.attention_seed) + seeds.append(layer.attention_seed) + self.parent.assertGreater(len(set(seeds)), 70) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.feed_forward_seed) + seeds.append(layer.feed_forward_seed) + self.parent.assertGreater(len(set(seeds)), 70) + + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) + self.parent.assertFalse(torch.isnan(output).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask,) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class ReformerTesterMixin: + """ + Reformer Local and Reformer LSH run essentially the same tests + """ + + def test_config(self): + self.config_tester.run_common_tests() + + def test_reformer_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model(*config_and_inputs) 
+ + def test_reformer_lm_model_backward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs) + + def test_reformer_model_attn_masking(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False) + + def test_reformer_with_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs) + + def test_reformer_layer_training_dropout(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) + self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + + def test_reformer_chunking_forward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) + + def test_reformer_chunking_backward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) + + @slow + def test_dropout_random_seed_is_changing(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_reformer_model_fp16_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + + +@require_torch +class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 32, + "is_training": True, + "is_decoder": False, + "use_input_mask": True, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "local_attn_chunk_length": 4, + "local_num_chunks_before": 1, + "local_num_chunks_after": 0, + "chunk_size_lm_head": 0, + "chunk_size_feed_forward": 0, + "feed_forward_size": 32, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "local_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [16, 16], + "attn_layers": ["local", "local", "local", "local"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + 
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + @slow + def test_model_from_pretrained(self): + for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + model = ReformerModelWithLMHead.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 13, + "use_input_mask": True, + "is_training": False, + "is_decoder": False, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 64, + "num_attention_heads": 2, + "num_buckets": 2, + "num_hashes": 4, + "lsh_attn_chunk_length": 4, + "lsh_num_chunks_before": 2, + "lsh_num_chunks_after": 3, + "chunk_size_lm_head": 5, + "chunk_size_feed_forward": 6, + "feed_forward_size": 32, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "lsh_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [16, 48], + "attn_layers": ["lsh", "lsh", "lsh", "lsh"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + +@require_torch +class ReformerIntegrationTests(unittest.TestCase): + """ + These integration tests check the layer activations and gradients against the output of the Hugging Face Reformer model at the time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk lengths, ...). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
+ """ + + def _get_basic_config_and_input(self): + config = { + "vocab_size": 320, + "attention_head_size": 8, + "hidden_size": 16, + "num_attention_heads": 2, + "num_buckets": 2, + "num_hashes": 4, + "lsh_attn_chunk_length": 4, + "local_attn_chunk_length": 4, + "lsh_num_chunks_before": 1, + "lsh_num_chunks_after": 0, + "local_num_chunks_before": 1, + "local_num_chunks_after": 0, + "chunk_size_lm_head": 0, + "chunk_size_feed_forward": 0, + "feed_forward_size": 32, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "lsh_attention_probs_dropout_prob": 0.0, + "local_attention_probs_dropout_prob": 0.0, + "max_position_embeddings": 32, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "sinusoidal_pos_embds": False, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [8, 8], + "hash_seed": 0, + "is_decoder": True, + } + return config + + def _get_hidden_states(self): + return torch.tensor( + [ + [ + [ + 1.90826353e00, + -1.45999730e00, + -6.20405462e-01, + 1.52503433e00, + -3.64464232e-01, + -8.27359235e-01, + 8.39670803e-01, + 2.44492178e-01, + 4.98332758e-01, + 2.69175139e00, + -7.08081422e-03, + 1.04915401e00, + -1.83476661e00, + 7.67220476e-01, + 2.98580543e-01, + 2.84803992e-02, + ], + [ + -2.66374286e-02, + 4.33497576e-01, + 3.10386309e-01, + 5.46039944e-01, + -2.47292666e-04, + -7.52305019e-01, + 2.39162103e-01, + 7.25216186e-01, + -7.58357372e-01, + 4.20635998e-01, + -4.04739919e-02, + 1.59924145e-01, + 2.05135748e00, + -1.15997978e00, + 5.37166397e-01, + 2.62873606e-01, + ], + [ + 1.85247482e-01, + 7.07046037e-01, + -6.77089715e-01, + -2.24209655e00, + -3.75307980e-02, + -8.59380874e-01, + -2.81027884e00, + 1.01276376e00, + -1.69438001e00, + 4.17574660e-01, + -1.49196962e00, + -1.76483717e00, + -1.94566312e-01, + -1.71183858e00, + 7.72903565e-01, + -1.11557056e00, + ], + [ + 9.46069193e-01, + 1.53417623e-01, + -9.58686996e-01, + 1.18126669e-01, + 1.75967724e00, + 1.62194590e00, + -5.74108159e-01, + 6.79920443e-01, + 5.44028163e-01, + 2.05466114e-01, + -3.63045868e-01, + 2.41865062e-01, + 3.20348382e-01, + -9.05611176e-01, + -1.92690727e-01, + -1.19917547e00, + ], + ] + ], + dtype=torch.float32, + device=torch_device, + ) + + def _get_attn_mask(self): + return torch.tensor([[0, 1, 0, 0]], dtype=torch.long, device=torch_device) + + def _get_input_ids_and_mask(self): + mask = torch.tensor( + [ + [1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0], + ], + dtype=torch.long, + device=torch_device, + ) + + input_ids = torch.tensor( + [ + [ + 89, + 279, + 286, + 84, + 194, + 316, + 182, + 28, + 283, + 37, + 169, + 7, + 253, + 267, + 107, + 250, + 44, + 7, + 102, + 62, + 3, + 243, + 171, + 265, + 302, + 48, + 164, + 264, + 148, + 229, + 280, + 150, + ], + [ + 9, + 192, + 66, + 112, + 163, + 83, + 135, + 70, + 224, + 96, + 31, + 80, + 196, + 80, + 63, + 22, + 85, + 100, + 47, + 283, + 0, + 163, + 126, + 143, + 195, + 82, + 53, + 82, + 18, + 27, + 182, + 52, + ], + ], + dtype=torch.long, + device=torch_device, + ) + + return input_ids, mask + + def test_lsh_layer_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh"] + config["is_decoder"] = False + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) + layer.eval() + reformer_output = 
layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_lsh_layer_forward_complex(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh"] + config["num_buckets"] = [2, 4] + attn_mask = self._get_attn_mask() + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) + layer.eval() + reformer_output = layer( + prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask, + ) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_layer_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local"] + config["is_decoder"] = False + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) + layer.eval() + reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_layer_forward_complex(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local"] + attn_mask = self._get_attn_mask() + hidden_states = self._get_hidden_states() + torch.manual_seed(0) + layer = ReformerLayer(ReformerConfig(**config)).to(torch_device) + layer.eval() + reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,) + output_slice = reformer_output.hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_lsh_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"] + config["num_buckets"] = [2, 4] + torch.manual_seed(0) + model = ReformerModel(ReformerConfig(**config)).to(torch_device) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[0, 0, :5] + expected_output_slice = torch.tensor( + [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "local", "local", "local"] + torch.manual_seed(0) + model = ReformerModel(ReformerConfig(**config)).to(torch_device) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[0, 0, :5] + 
expected_output_slice = torch.tensor( + [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_lm_model_forward(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "lsh", "local", "lsh", "local", "lsh"] + config["num_buckets"] = [2, 4] + config["is_decoder"] = False + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) + model.eval() + input_ids, attn_mask = self._get_input_ids_and_mask() + hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] + output_slice = hidden_states[1, -1, :5] + expected_output_slice = torch.tensor( + [0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_local_lm_model_grad(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["local", "local", "local", "local"] + config["hidden_dropout_prob"] = 0.0 + config["local_attention_probs_dropout_prob"] = 0.0 + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) + model.train() + model.zero_grad() + input_ids, _ = self._get_input_ids_and_mask() + loss = model(input_ids=input_ids, labels=input_ids)[0] + + self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3)) + loss.backward() + + # check last grads to cover all probable errors + grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + expected_grad_slice_word = torch.tensor( + [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], dtype=torch.float, device=torch_device, + ) + grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + expected_grad_slice_pos_fac_1 = torch.tensor( + [0.0037, -1.3793, -1.0231, -1.5230, -2.5306], dtype=torch.float, device=torch_device, + ) + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + expected_grad_slice_pos_fac_2 = torch.tensor( + [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3)) + + def test_lsh_lm_model_grad(self): + config = self._get_basic_config_and_input() + config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"] + config["hidden_dropout_prob"] = 0.0 + config["lsh_attention_probs_dropout_prob"] = 0.0 + config["num_buckets"] = [2, 4] + config["num_hashes"] = 6 + torch.manual_seed(0) + model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device) + model.train() + model.zero_grad() + input_ids, _ = self._get_input_ids_and_mask() + loss = model(input_ids=input_ids, labels=input_ids)[0] + + self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3)) + loss.backward() + # check last grads to cover all probable errors + grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + expected_grad_slice_word = torch.tensor( + [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], dtype=torch.float, device=torch_device, + ) +
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + expected_grad_slice_pos_fac_1 = torch.tensor( + [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], dtype=torch.float, device=torch_device, + ) + grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + expected_grad_slice_pos_fac_2 = torch.tensor( + [0.4626, -0.0231, -0.0172, 0.1081, 0.3805], dtype=torch.float, device=torch_device, + ) + self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3)) + self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3)) + + @slow + def test_pretrained_generate_crime_and_punish(self): + model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device) + tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment") + model.eval() + + input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) + output_ids = model.generate( + input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8 + ) + output_text = tokenizer.decode(output_ids[0]) + self.assertEqual( + output_text, + "A few months later state expression in his ideas, at the first entrance. He was positively for an inst", + )
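For quick reference, here is a minimal usage sketch of the classes introduced by this diff. It simply mirrors the slow `test_pretrained_generate_crime_and_punish` integration test above and assumes the `google/reformer-crime-and-punishment` checkpoint and the `sentencepiece` package are available::

    from transformers import ReformerModelWithLMHead, ReformerTokenizer

    # load the SentencePiece tokenizer and the LM-head model added in this diff
    tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
    model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
    model.eval()

    # encode a prompt and generate with beam search, as in the integration test
    input_ids = tokenizer.encode("A few months later", return_tensors="pt")
    output_ids = model.generate(
        input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
    )
    print(tokenizer.decode(output_ids[0]))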