
Tokenizers v3.0.0 #3185

Merged
mfuntowicz merged 67 commits into master from tokenizers-v3.0.0 on Apr 6, 2020
Conversation

@mfuntowicz
Member

No description provided.

Member

@LysandreJik LysandreJik left a comment

Overall, this is quite a different test suite than what we have in test_modeling_common and test_tokenization_common, but in a good way, imo.

The use of subTest will greatly help debugging, and splitting between test_xxx methods and assert_xxx methods makes the code cleaner and easier to read.
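To make the pattern concrete, here is a minimal sketch of the test_xxx / assert_xxx split combined with subTest (class and method names are hypothetical, assuming a TOKENIZERS_CLASSES attribute like the one quoted below; this is not the actual test file):

import unittest

class FastTokenizerMatchingTest(unittest.TestCase):
    # Hypothetical illustration: the test_xxx method only iterates over
    # configurations and wraps each one in a subTest, while the actual
    # assertions live in a reusable assert_xxx helper.
    def test_padding(self):
        for name, rust_cls, python_cls, vocab_key in self.TOKENIZERS_CLASSES:
            for pretrained_name in python_cls.pretrained_vocab_files_map[vocab_key].keys():
                with self.subTest("{} ({})".format(name, pretrained_name)):
                    tokenizer_r = rust_cls.from_pretrained(pretrained_name)
                    tokenizer_p = python_cls.from_pretrained(pretrained_name)
                    self.assert_padding(tokenizer_r, tokenizer_p)

    def assert_padding(self, tokenizer_r, tokenizer_p):
        # Assertions comparing the fast (Rust) and slow (Python) tokenizers go here.
        ...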

Comment thread tests/test_tokenization_fast.py Outdated
Comment on lines +32 to +36
TOKENIZERS_CLASSES = frozenset([
    Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file"),
    Tokenizer("DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file"),
    Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file"),
])
Member

frozenset + named tuple: nice!
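For context, the Tokenizer entries above are presumably built from a named tuple roughly like this (a reconstruction based on how the tuple is unpacked in the loop below; the actual field names in the test file may differ):

from collections import namedtuple

# Hypothetical definition matching the (name, rust_cls, python_cls, vocab_key) unpacking.
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key"])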

Comment thread tests/test_tokenization_fast.py Outdated
Comment on lines +43 to +45
for (name, rust_cls, python_cls, vocab_key) in self.TOKENIZERS_CLASSES:
    for pretrained_name in python_cls.pretrained_vocab_files_map[vocab_key].keys():
        with self.subTest("{} ({})".format(name, pretrained_name)):
Member

This seems to me like the optimal organisation. Testing every checkpoint of every tokenizer, organized in subTests, is really thorough.

Does it take a while? Should we mark this as slow or is it fast enough?

Member Author

It takes around 1 min to do everything.

Comment on lines +266 to +265
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
Member

We usually try to use lambdas as little as possible, as they're usually a bit hard to read. cc @thomwolf
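For illustration, a lambda-free equivalent could use a small named helper instead (a sketch only, reusing the variable names from the snippet above):

def count_trailing_pads(input_ids, pad_token_id):
    # Count padding tokens at the end of the sequence.
    count = 0
    for token_id in reversed(input_ids):
        if token_id != pad_token_id:
            break
        count += 1
    return count

padded_len_r = count_trailing_pads(input_r, tokenizer_r.pad_token_id)
padded_len_p = count_trailing_pads(input_p, tokenizer_p.pad_token_id)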

Member

hmm here I think it's fine, no?

Member

In tests I'm more OK with them.

Member Author

👍

Member

@thomwolf thomwolf left a comment

This is looking great, just a few comments on things to check

Comment thread src/transformers/tokenization_utils.py Outdated
Comment thread src/transformers/tokenization_utils.py Outdated
Comment thread src/transformers/tokenization_utils.py Outdated
Comment on lines +1964 to +1965
return_token_type_ids: bool = True,
return_attention_mask: bool = True,
Member

You should rebase on master if you can, because I think this is now Optional[bool] = None since @LysandreJik worked on adapting the output to the models.

Member Author

I'm not sure I fully understand why bool should be Optional? A default value would be more understandable imho and would remove the need for None checking, wdyt?

Member

You should check with @LysandreJik :)

Member

The spirit behind having those values as None by default is the following: if the value is None, then all those returns are set to the tokenizer-specific defaults. These differ from tokenizer to tokenizer, e.g. DistilBERT should by default return input_ids and attention_mask, but not token_type_ids, as the model cannot handle them. This in turn allows the user to do the following:

inputs = tokenizer.encode_plus(values, return_tensors="pt")
model(**inputs)

And this now works with every model.

These values can still be explicitly set to True or False by the user. See #3116 for more information/implementation details.
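For example, a user can still force a field off (or on) regardless of the model. A minimal sketch, assuming tokenizer and values as in the example above:

# None (the default) defers to the tokenizer-specific behaviour described above;
# an explicit boolean overrides it.
inputs = tokenizer.encode_plus(values, return_token_type_ids=False, return_tensors="pt")
assert "token_type_ids" not in inputs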

mfuntowicz and others added 26 commits March 26, 2020 15:00
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
…Fast

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
…ded to the output #3091

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This new structure exposes all the mappings retrieved from Rust.
It also keeps the current behavior with model forward.
Contributor

@n1t0 n1t0 left a comment

A few details here and there, but otherwise looks good to me!

) -> List[Encoding]:
    if sequences is None:
        raise ValueError(
            "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
Contributor

I think these messages could be made more specific for each method. This one should probably just say list/tuple of strings.
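i.e. something along these lines (just a sketch of the suggested wording):

if sequences is None:
    raise ValueError("Input is not valid. Should be a list/tuple of strings.")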

def encode(self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> Encoding:
    if sequence is None:
        raise ValueError(
            "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
Contributor

This one should probably just say string

Comment thread src/transformers/tokenization_utils.py Outdated
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping:
encoding_dict["offset_mapping"].append([e.original_str.offsets(o) for o in e.offsets])
encoding_dict["offset_mapping"] = [o for o in e.offsets]
Contributor

Should it be like above?

Suggested change:
- encoding_dict["offset_mapping"] = [o for o in e.offsets]
+ encoding_dict["offset_mapping"].append(e.offsets)

Comment thread src/transformers/tokenization_utils.py Outdated

if batch_text_or_text_pairs is None:
    raise ValueError(
        "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
Contributor

Same here, I think we can remove the string option

self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
Contributor

Should this be enabled?


# Check for dynamic encoding sequence handling in batch_encode_plus
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
# Rust correctly handles the space before the mask while python doesnt
Contributor

I think Python is right here. There shouldn't be any space before the <mask> token. This means that Roberta on the fast path should probably have an AddedToken('<mask>', lstrip=True)
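For reference, a minimal sketch of what that could look like with the tokenizers library's AddedToken (illustrative only, not the exact change made in this PR):

from tokenizers import AddedToken

# lstrip=True lets the added token absorb the space preceding "<mask>",
# so " <mask>" and "<mask>" tokenize identically on the fast path.
mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)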

Comment thread tests/test_tokenization_gpt2.py Outdated
# Testing tokenization
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
rust_tokens = rust_tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence, add_prefix_space=True)
Contributor

add_prefix_space=True isn't required here I think

mfuntowicz and others added 4 commits March 30, 2020 17:25
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
@mfuntowicz mfuntowicz marked this pull request as ready for review March 31, 2020 13:11
mfuntowicz and others added 13 commits March 31, 2020 16:30
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
…every iteration.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
… for Roberta.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
…utes.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
@codecov-io

codecov-io commented Apr 1, 2020

Codecov Report

Merging #3185 into master will decrease coverage by 0.24%.
The diff coverage is 80.42%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3185      +/-   ##
==========================================
- Coverage   77.79%   77.55%   -0.25%     
==========================================
  Files         100      100              
  Lines       17025    17105      +80     
==========================================
+ Hits        13245    13265      +20     
- Misses       3780     3840      +60
Impacted Files Coverage Δ
src/transformers/tokenization_bert.py 95.33% <ø> (-1.7%) ⬇️
src/transformers/tokenization_roberta.py 94.36% <100%> (-5.64%) ⬇️
src/transformers/pipelines.py 74.51% <100%> (-0.28%) ⬇️
src/transformers/tokenization_transfo_xl.py 40.67% <100%> (-0.43%) ⬇️
src/transformers/tokenization_utils.py 86.18% <78.61%> (-5.81%) ⬇️
... and 1 more

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7420a6a...860cf66. Read the comment docs.

@LysandreJik LysandreJik self-requested a review April 1, 2020 16:15
Member

@LysandreJik LysandreJik left a comment

This isn't easy to review as the diff mixes subtractions from PreTrainedTokenizer and additions from BatchEncoding; other than that, cool! Thanks @mfuntowicz :)

Comment on lines +48 to +52
# Define type aliases
TextInput = str
TextPairInput = Tuple[str, str]
PreTokenizedInput = List[str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
Member

I like this!
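A quick illustration of how these aliases make signatures self-documenting (hypothetical function and union name, restating the alias definitions quoted above):

from typing import List, Tuple, Union

TextInput = str
TextPairInput = Tuple[str, str]
PreTokenizedInput = List[str]
PreTokenizedInputPair = Tuple[List[str], List[str]]

# The union spells out exactly which input shapes a batch method accepts.
BatchInput = Union[
    List[TextInput],
    List[TextPairInput],
    List[PreTokenizedInput],
    List[PreTokenizedInputPair],
]

def batch_encode_plus(batch_text_or_text_pairs: BatchInput) -> dict:
    ...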

Comment thread src/transformers/tokenization_utils.py Outdated
Comment on lines +172 to +175
Find the Offsets of the token containing the character at the specified position
:param sentence: Index of the sentence relative to the batch provided to the tokenizer.
:param char: Char index to get the relative token offsets
:return: (token start, token end)
Member

(Applicable to most other docstrings) We use Google-style docstrings in the library; could we try to use them here as well?
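For instance, the docstring above rewritten in Google style might look roughly like this (the method name is hypothetical, shown only to illustrate the format):

def char_to_token_offsets(self, sentence, char):  # hypothetical name, for illustration
    """Find the offsets of the token containing the character at the specified position.

    Args:
        sentence: Index of the sentence relative to the batch provided to the tokenizer.
        char: Char index to get the relative token offsets for.

    Returns:
        Tuple of (token start, token end).
    """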

@LysandreJik
Member

Really like the new typings!

@mfuntowicz mfuntowicz merged commit 96ab75b into master Apr 6, 2020
@mfuntowicz mfuntowicz deleted the tokenizers-v3.0.0 branch April 6, 2020 22:29

# Filter out features not available on specific models
inputs = self.inputs_for_model(inputs)
# inputs = self.inputs_for_model(inputs)
Member

Let's remember to remove this for good soon @mfuntowicz
