Clean Encoder-Decoder models with Bart/T5-like API and add generate possibility #3383
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #3383      +/-   ##
==========================================
+ Coverage   78.61%   78.90%   +0.29%
==========================================
  Files         106      105       -1
  Lines       17953    17973      +20
==========================================
+ Hits        14114    14182      +68
+ Misses       3839     3791      -48

Continue to review full report at Codecov.
# this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTest
This was quite annoying. As far as I can see, it is not possible to import a unittest TestCase such as BertModelTest from .test_modeling_bert without running all of its tests. At the same time, I think we need to import the TestCase here, since we need to instantiate BertEncoder and BertDecoder models without copy-pasting all of their code. I tried a couple of hacks but didn't find a good solution. Maybe it's possible to only run pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest from Circle CI, in which case only the encoder-decoder model tests are run? Any good ideas for this case? @julien-c @LysandreJik @thomwolf @sshleifer
- Move the ModelTester to the top level of that file.
- In that file, use self.model_tester = BertModelTester(self).
- Only import BertModelTester.

Verified that no extra tests are run. You can merge #4027 if you want to keep your delta small.
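The suggested pattern can be sketched as follows — a minimal, hypothetical stand-in (the real BertModelTester prepares configs and model inputs): the helper logic lives in a plain top-level class, so importing it from another test module does not trigger unittest collection, while the TestCase stays thin.

```python
# Sketch only: BertModelTester here is a toy stand-in for the real helper.
import unittest


class BertModelTester:
    """Plain helper class (not a TestCase): safe to import from other test files."""

    def __init__(self, parent):
        self.parent = parent

    def prepare_config_and_inputs(self):
        # stand-in for the real config/input preparation
        return {"hidden_size": 32}

    def create_and_check_model(self, **config_and_inputs):
        self.parent.assertEqual(config_and_inputs["hidden_size"], 32)


class BertModelTest(unittest.TestCase):
    def setUp(self):
        self.model_tester = BertModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(**config_and_inputs)
```

Because BertModelTester does not subclass TestCase, `from .test_modeling_bert import BertModelTester` pulls in no runnable tests.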
Awesome! I will try this out :-)
Would it be ok to change Bert's testing structure as shown in PR #4027 for you guys @julien-c @thomwolf @LysandreJik ?
thomwolf
left a comment
Nice draft.
See my comments on serialization with proposals below.
It's a bit dense so happy to discuss them if needed.
    class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
    """

def __init__(self, encoder, decoder):
As I detailed below under the from_pretrained method, I think here we should allow both an initialization from (pretrained) sub-parts and from a config, so we can also just reload a saved pretrained model.
Please read the long comment below from_pretrained before this one.
Here is a proposal for the __init__ that should allow both loading:
def __init__(self, config=None, encoder=None, decoder=None):
    assert config is not None or (
        encoder is not None and decoder is not None
    ), "Either a configuration or an encoder and a decoder has to be provided"
    if config is not None:
        self.config = config
        self.encoder = AutoModel.from_config(config.encoder)
        self.decoder = AutoModelWithLMHead.from_config(config.decoder)
    else:
        self.config = EncoderDecoderConfig.from_encoder_decoder_config(encoder.config, decoder.config)
        self.encoder = encoder
        assert (
            self.encoder.get_output_embeddings() is None
        ), "The encoder {} should not have a LM Head. Please use a model without LM Head"
        self.decoder = decoder
This means we need an EncoderDecoderConfig, which will basically be the following:
class EncoderDecoderConfig(PretrainedConfig):
    model_type = "encoder_decoder"

    def __init__(self, **kwargs):
        encoder_config = kwargs.pop('encoder')
        encoder_model_type = encoder_config.pop('model_type')
        decoder_config = kwargs.pop('decoder')
        decoder_model_type = decoder_config.pop('model_type')
        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)

    @classmethod
    def from_encoder_decoder_config(cls, encoder_config, decoder_config):
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default `to_dict()` from `PretrainedConfig`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output['encoder'] = self.encoder.to_dict()
        output['decoder'] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
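For illustration, the nested-config round trip this proposal enables can be sketched without the library. DummyConfig below is a hypothetical stand-in for the AutoConfig-built sub-configs, and the EncoderDecoderConfig here is simplified (no model_type popping) relative to the proposal above:

```python
import copy


class DummyConfig:
    """Hypothetical stand-in for a model-specific PretrainedConfig."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to_dict(self):
        return copy.deepcopy(self.__dict__)


class EncoderDecoderConfig:
    model_type = "encoder_decoder"

    def __init__(self, **kwargs):
        # each sub-config arrives as a plain dict and is rebuilt into a config object
        self.encoder = DummyConfig(**kwargs.pop("encoder"))
        self.decoder = DummyConfig(**kwargs.pop("decoder"))

    @classmethod
    def from_encoder_decoder_config(cls, encoder_config, decoder_config):
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output["encoder"] = self.encoder.to_dict()
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output


enc = DummyConfig(model_type="bert", hidden_size=32)
dec = DummyConfig(model_type="bert", hidden_size=32, is_decoder=True)
cfg = EncoderDecoderConfig.from_encoder_decoder_config(enc, dec)
d = cfg.to_dict()
# the serialized dict rebuilds into an equivalent nested config
cfg2 = EncoderDecoderConfig(**{k: v for k, v in d.items() if k != "model_type"})
```

The point of the pattern: one flat dict on disk, two live config objects in memory.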
sshleifer
left a comment
maybe test coverage for BartEncoder + BartDecoder?
Also we should add the caching logic to this at some point.
UPDATE: I really liked the idea of adding an encoder_decoder config (@thomwolf) and added one for both the encoder-decoder config and model. To understand how to use the encoder-decoder class, please refer to the added tests.
thomwolf
left a comment
Really awesome!
And great tests as well!
LGTM :-)
        encoder_hidden_states = encoder_outputs[0]
    else:
        encoder_outputs = ()

    kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
so neutral kwargs like use_cache only go to encoder? Is that desired/irrelevant?
I think the user will always have to prepend each argument with decoder_ in order to use it for the decoder. So if one were to use a Bert + GPT2 (which is not supported at the moment), one would have to write decoder_use_cache=True. IMO this is ok.
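The routing convention being discussed can be sketched as a simple split: plain kwargs go to the encoder, decoder_-prefixed kwargs are stripped of the prefix and go to the decoder. A minimal illustration (function name and arguments are hypothetical):

```python
def split_encoder_decoder_kwargs(**kwargs):
    """Route plain kwargs to the encoder and decoder_-prefixed ones to the decoder."""
    kwargs_encoder = {
        name: value for name, value in kwargs.items() if not name.startswith("decoder_")
    }
    kwargs_decoder = {
        name[len("decoder_"):]: value
        for name, value in kwargs.items()
        if name.startswith("decoder_")
    }
    return kwargs_encoder, kwargs_decoder


# decoder_use_cache reaches the decoder as use_cache; output_attentions stays with the encoder
enc_kwargs, dec_kwargs = split_encoder_decoder_kwargs(
    output_attentions=True, decoder_use_cache=True
)
```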
def _reorder_cache(self, past, beam_idx):
    # as a default encoder-decoder models do not re-order the past.
    # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
tuple(layer_past.index_select(1, beam_idx) for layer_past in past) would be better IMO, or raise NotImplementedError. Isn't this implementation incorrect?
It should not raise anything because the function is called in generate and should not reorder anything. I think for the moment this is ok.
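For readers unfamiliar with the reordering being debated: during beam search, surviving beams are re-selected each step, so cached per-layer states must be permuted to match. A library-free caricature, with plain lists standing in for the tensors that index_select(1, beam_idx) would reorder:

```python
def reorder_cache(past, beam_idx):
    """Reorder each layer's cached states along the beam dimension.

    `past` is a tuple of per-layer caches; each cache is simplified here
    to a list indexed by beam (real models index a tensor dimension).
    """
    return tuple([layer_past[i] for i in beam_idx] for layer_past in past)


# two layers, three beams; beam 2 survives twice, beam 0 once, beam 1 is dropped
past = (["l0_b0", "l0_b1", "l0_b2"], ["l1_b0", "l1_b1", "l1_b2"])
reordered = reorder_cache(past, beam_idx=[2, 2, 0])
```

A no-op version of this function is only safe while no decoder in the class actually caches beam-indexed state, which is the status quo being described.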
    input_ids_dict = self.prepare_config_and_inputs_bert()
    self.create_and_check_save_and_load(**input_ids_dict)

def test_save_and_load_from_encoder_decoder_pretrained(self):
I don't understand the point of all these helper functions that you use once.
Do you mean the create_and_check... functions? It's done this way everywhere in the tests, and I think it's actually a nice way to keep short, easy-to-read test functions at the bottom and all the logic further up.
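The prepare/create_and_check split described here can be sketched in miniature (names and checks are illustrative stand-ins for the real save-reload-compare logic): helpers carrying the logic sit up top, and each test_* entry point at the bottom is a two-liner.

```python
import unittest


class EncoderDecoderModelTest(unittest.TestCase):
    # --- helpers carrying the logic, reusable across several tests ---
    def prepare_config_and_inputs_bert(self):
        return {"input_ids": [1, 2, 3], "decoder_input_ids": [1, 2]}

    def create_and_check_save_and_load(self, input_ids, decoder_input_ids):
        # stand-in check; the real helper saves, reloads, and compares outputs
        self.assertEqual(len(input_ids), 3)
        self.assertEqual(len(decoder_input_ids), 2)

    # --- short, readable test entry points at the bottom ---
    def test_save_and_load(self):
        input_ids_dict = self.prepare_config_and_inputs_bert()
        self.create_and_check_save_and_load(**input_ids_dict)
```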
Code is cleaned: added type hints, cleaned the docstring, and added an encoder-decoder model page.
LGTM!
Ok, good to merge for me! If it's ok for you @sshleifer, I will use your PR #4027 for the new test proposal.
Yes!
Bert-Bert Encoder-Decoder models can now be used as shown in the test cases in tests/test_modeling_encoder_decoder.py. Tests include:
- input_ids and decoder_input_ids for Bert-Bert
- generate() fn with Bart-Bart

Before merging, a couple of things have to be agreed on as mentioned in the comments below.
UPDATE:
This branch is IMO now fully functional for Bert-2-Bert models.
I will finish the PR (clean the code, make a pretty docstring, etc.) once we have agreed on the issues I mentioned further down. Would be very happy if you can review @thomwolf @LysandreJik @sshleifer @julien-c @yjernite
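For context on what the added generate() possibility amounts to for encoder-decoder models: encode the source once, then feed the decoder its own growing output token by token until EOS. A deliberately toy, library-free caricature (the encode/decode_step functions and token ids are made up, not the real transformers API):

```python
def encode(input_ids):
    # toy encoder: the "hidden states" are just the input ids
    return list(input_ids)


def decode_step(encoder_hidden_states, decoder_input_ids):
    # toy decoder: echo the encoder state at the current position,
    # emitting EOS (0) once the source is exhausted
    pos = len(decoder_input_ids) - 1
    return encoder_hidden_states[pos] if pos < len(encoder_hidden_states) else 0


def generate(input_ids, bos_token_id=101, eos_token_id=0, max_length=10):
    # the encoder runs once; only the decoder runs per generation step
    encoder_hidden_states = encode(input_ids)
    decoder_input_ids = [bos_token_id]
    while len(decoder_input_ids) < max_length:
        next_token = decode_step(encoder_hidden_states, decoder_input_ids)
        decoder_input_ids.append(next_token)
        if next_token == eos_token_id:
            break
    return decoder_input_ids


out = generate([7, 8, 9])
```

The single-encode/many-decode-steps shape is the reason decoder caching (discussed above for GPT2-style decoders) matters for generation speed.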