From 39d7950ff19e7e02170259c8328e130553b9fd56 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 5 Nov 2020 16:56:30 +0100 Subject: [PATCH 01/10] First addition of Flax/Jax documentation Signed-off-by: Morgan Funtowicz --- docs/source/model_doc/bert.rst | 7 +++ docs/source/model_doc/roberta.rst | 7 +++ src/transformers/modeling_flax_bert.py | 55 +++++++++------------- src/transformers/modeling_flax_roberta.py | 57 +++++++++-------------- 4 files changed, 59 insertions(+), 67 deletions(-) diff --git a/docs/source/model_doc/bert.rst b/docs/source/model_doc/bert.rst index 277798042c1c..8d1322bb095d 100644 --- a/docs/source/model_doc/bert.rst +++ b/docs/source/model_doc/bert.rst @@ -188,3 +188,10 @@ TFBertForQuestionAnswering .. autoclass:: transformers.TFBertForQuestionAnswering :members: call + + +FlaxBertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertModel + :members: __call__ diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index 36c297df3d2f..9ae5062fcd18 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -146,3 +146,10 @@ TFRobertaForQuestionAnswering .. autoclass:: transformers.TFRobertaForQuestionAnswering :members: call + + +FlaxRobertaModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.FlaxRobertaModel + :members: __call__ diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index 92ab7dcf9a2f..da89889ed7cc 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from .configuration_bert import BertConfig -from .file_utils import add_start_docstrings +from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from .modeling_flax_utils import FlaxPreTrainedModel, gelu from .utils import logging @@ -35,13 +35,19 @@ BERT_START_DOCSTRING = r""" - This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. + Check the superclass documentation for the generic methods the library implements + for all its model (such as downloading, saving and converting weights from PyTorch models) - This model is also a PyTorch `torch.nn.Module `__ - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + This model is also a Flax Linen `flax.nn.Module `__ + subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as + `Just-In-Time (JIT) compilation `__, + `Automatic Differentiation `__, + `Vectorization `__, + and `Parallelization `__. Parameters: config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. 
@@ -52,50 +58,32 @@ BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + input_ids (:obj:`Numpy array` of shape :obj:`({0})`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`~transformers.BertTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + :func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: - 0 corresponds to a `sentence A` token, - 1 corresponds to a `sentence B` token. - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + `What are token type IDs? <../glossary.html#token-type-ids>`__ + position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. 
Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ @@ -291,7 +279,7 @@ class FlaxBertModule(nn.Module): intermediate_size: int @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids): # Embedding embeddings = FlaxBertEmbeddings( @@ -410,7 +398,8 @@ def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): def module(self) -> nn.Module: return self._module - def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None): if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -423,7 +412,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_ return self.model.apply( {"params": self.params}, jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, 
dtype="i4"), jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), ) diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index 551ff8d52561..d907c4f3a59f 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -21,7 +21,7 @@ import jax.numpy as jnp from .configuration_roberta import RobertaConfig -from .file_utils import add_start_docstrings +from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from .modeling_flax_utils import FlaxPreTrainedModel, gelu from .utils import logging @@ -34,13 +34,19 @@ ROBERTA_START_DOCSTRING = r""" - This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. + Check the superclass documentation for the generic methods the library implements + for all its model (such as downloading, saving and converting weights from PyTorch models) - This model is also a PyTorch `torch.nn.Module `__ - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + This model is also a Flax Linen `flax.nn.Module `__ + subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as + `Just-In-Time (JIT) compilation `__, + `Automatic Differentiation `__, + `Vectorization `__, + and `Parallelization `__. 
Parameters: config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the @@ -51,50 +57,32 @@ ROBERTA_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + input_ids (:obj:`Numpy array` of shape :obj:`({0})`): Indices of input sequence tokens in the vocabulary. - Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: - 0 corresponds to a `sentence A` token, - 1 corresponds to a `sentence B` token. - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + `What are token type IDs? <../glossary.html#token-type-ids>`__ + position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. - - `What are position IDs? 
<../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
""" @@ -302,7 +290,7 @@ class FlaxRobertaModule(nn.Module): intermediate_size: int @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + def __call__(self, attention_mask, input_ids, token_type_ids, position_ids): # Embedding embeddings = FlaxRobertaEmbeddings( @@ -421,7 +409,8 @@ def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs): def module(self) -> nn.Module: return self._module - def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None): if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -436,7 +425,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_ return self.model.apply( {"params": self.params}, jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), ) From bc31764b9538e802b1796c7550200b7f10c740d0 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Tue, 10 Nov 2020 10:41:54 +0100 Subject: [PATCH 02/10] make style --- src/transformers/modeling_flax_bert.py | 26 +++++++++++------------ src/transformers/modeling_flax_roberta.py | 26 +++++++++++------------ 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index da89889ed7cc..bd7f2b9a2c16 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -35,19 +35,19 @@ BERT_START_DOCSTRING = r""" - This model inherits from :class:`~transformers.FlaxPreTrainedModel`. 
- Check the superclass documentation for the generic methods the library implements - for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a Flax Linen `flax.nn.Module `__ - subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as - `Just-In-Time (JIT) compilation `__, - `Automatic Differentiation `__, - `Vectorization `__, - and `Parallelization `__. + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading, saving and converting weights from + PyTorch models) + + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation + `__, `Automatic Differentiation + `__, `Vectorization + `__, and `Parallelization + `__. Parameters: config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index d907c4f3a59f..e5d6ea46d83d 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -34,19 +34,19 @@ ROBERTA_START_DOCSTRING = r""" - This model inherits from :class:`~transformers.FlaxPreTrainedModel`. - Check the superclass documentation for the generic methods the library implements - for all its model (such as downloading, saving and converting weights from PyTorch models) - - This model is also a Flax Linen `flax.nn.Module `__ - subclass. 
Use it as a regular Flax Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as - `Just-In-Time (JIT) compilation `__, - `Automatic Differentiation `__, - `Vectorization `__, - and `Parallelization `__. + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading, saving and converting weights from + PyTorch models) + + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation + `__, `Automatic Differentiation + `__, `Vectorization + `__, and `Parallelization + `__. Parameters: config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the From c720c092ac0ddb2a75ca372c38a34e3fc8e20c6b Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 16:21:25 +0100 Subject: [PATCH 03/10] Ensure input order matches between Bert & Roberta Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_flax_roberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index e5d6ea46d83d..7f767e9ae375 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -290,7 +290,7 @@ class FlaxRobertaModule(nn.Module): intermediate_size: int @nn.compact - def __call__(self, attention_mask, input_ids, token_type_ids, position_ids): + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids): # Embedding embeddings = FlaxRobertaEmbeddings( From e9e09354086b3057223f42ef43d17d6ea284a45c Mon Sep 17 00:00:00 2001 From: 
Morgan Funtowicz Date: Wed, 11 Nov 2020 17:27:02 +0100 Subject: [PATCH 04/10] Install dependencies "all" when building doc Signed-off-by: Morgan Funtowicz --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index aa48f0756055..9f62c4edfa70 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -281,7 +281,7 @@ jobs: - v0.4-build_doc-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[tf,torch,sentencepiece,docs] + - run: pip install .[all, docs] - save_cache: key: v0.4-build_doc-{{ checksum "setup.py" }} paths: From 05464cb3717f57e29fb391f6e5175655a70da188 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 17:45:16 +0100 Subject: [PATCH 05/10] wraps build_doc deps with "" Signed-off-by: Morgan Funtowicz --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9f62c4edfa70..d565e7477681 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -281,7 +281,7 @@ jobs: - v0.4-build_doc-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[all, docs] + - run: pip install ."[all, docs]" - save_cache: key: v0.4-build_doc-{{ checksum "setup.py" }} paths: From b44cb92b323f665fba97018f03ca19f3c1b35fe3 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 19:49:21 +0100 Subject: [PATCH 06/10] Addressing @sgugger comments. 
Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_flax_bert.py | 19 ++++++++++++++----- src/transformers/modeling_flax_roberta.py | 10 +++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index bd7f2b9a2c16..133f2354d925 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -58,22 +58,22 @@ BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` of shape :obj:`({0})`): + input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`~transformers.BertTokenizer`. See - :func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for + :meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: @@ -81,7 +81,7 @@ - 1 corresponds to a `sentence B` token. `What are token type IDs? <../glossary.html#token-type-ids>`__ - position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range ``[0, config.max_position_embeddings - 1]``. return_dict (:obj:`bool`, `optional`): @@ -416,3 +416,12 @@ def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), ) + + +class FlaxBertForLMHead(FlaxBertModel): + def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): + super().__init__(config, state, seed, **kwargs) + + def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, return_dict=None, labels=None,): + encoder, pooled = super().__call__(input_ids, attention_mask, token_type_ids, position_ids) + diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index 7f767e9ae375..a7cae82f8ea8 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -57,22 +57,22 @@ ROBERTA_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` of shape :obj:`({0})`): + input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`~transformers.BertTokenizer`. See - :func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: @@ -80,7 +80,7 @@ - 1 corresponds to a `sentence B` token. `What are token type IDs? <../glossary.html#token-type-ids>`__ - position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`): + position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. return_dict (:obj:`bool`, `optional`): From 9e55220503655ab45cefcd2d093e6f1731d01701 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 19:56:10 +0100 Subject: [PATCH 07/10] Use list to highlight JAX features. Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_flax_bert.py | 11 ++++++----- src/transformers/modeling_flax_roberta.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index 133f2354d925..4d60f36a6bc6 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -43,11 +43,12 @@ `__ subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation - `__, `Automatic Differentiation - `__, `Vectorization - `__, and `Parallelization - `__. 
+ Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ Parameters: config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index a7cae82f8ea8..0010e004ab8d 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -42,11 +42,12 @@ `__ subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation - `__, `Automatic Differentiation - `__, `Vectorization - `__, and `Parallelization - `__. + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ Parameters: config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the From 51dedc7f0bd1eaeeda57d5b4351a721e4117cd5c Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 20:16:39 +0100 Subject: [PATCH 08/10] Make style. Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_flax_bert.py | 13 ++++++++++--- src/transformers/modeling_flax_roberta.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index 4d60f36a6bc6..8ee0c7ded3c6 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -44,7 +44,7 @@ Module and refer to the Flax documentation for all matter related to general usage and behavior. 
Finally, this model supports inherent JAX features such as: - + - `Just-In-Time (JIT) compilation `__ - `Automatic Differentiation `__ - `Vectorization `__ @@ -423,6 +423,13 @@ class FlaxBertForLMHead(FlaxBertModel): def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): super().__init__(config, state, seed, **kwargs) - def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, return_dict=None, labels=None,): + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + return_dict=None, + labels=None, + ): encoder, pooled = super().__call__(input_ids, attention_mask, token_type_ids, position_ids) - diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index 0010e004ab8d..8ae15bdc5169 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -43,7 +43,7 @@ Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: - + - `Just-In-Time (JIT) compilation `__ - `Automatic Differentiation `__ - `Vectorization `__ From 417f60b5f08691a2b9bb20775feda6f6d20aacb0 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 11 Nov 2020 20:33:24 +0100 Subject: [PATCH 09/10] Let's not look too much into the future for now. 
Signed-off-by: Morgan Funtowicz --- src/transformers/modeling_flax_bert.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index 8ee0c7ded3c6..da926aca1154 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -419,17 +419,3 @@ def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position ) -class FlaxBertForLMHead(FlaxBertModel): - def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): - super().__init__(config, state, seed, **kwargs) - - def __call__( - self, - input_ids, - attention_mask=None, - token_type_ids=None, - position_ids=None, - return_dict=None, - labels=None, - ): - encoder, pooled = super().__call__(input_ids, attention_mask, token_type_ids, position_ids) From 98a4f5c2d7a52514c788f55beb3b8fee7e7db2dd Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 11 Nov 2020 14:47:56 -0500 Subject: [PATCH 10/10] Style --- src/transformers/modeling_flax_bert.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index da926aca1154..50b499e0837c 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -417,5 +417,3 @@ def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), ) - -