Clean Encoder-Decoder models with Bart/T5-like API and add generate possibility #3383

Merged
thomwolf merged 22 commits into huggingface:master from patrickvonplaten:clean_encoder_decocer_modeling on Apr 28, 2020
Conversation

Contributor

@patrickvonplaten patrickvonplaten commented Mar 22, 2020

Bert-Bert Encoder-Decoder models can now be used as is shown in the test cases:
tests/test_modeling_encoder_decoder.py.
Tests include:

  • forward input_ids and decoder_input_ids for Bert-Bert
  • backprop using masked language model loss for Bert-Bert
  • backprop using "conventional" language model loss for Bert-Bert
  • using the generate() fn with Bart-Bart
  • saving and loading of Encoder-Decoder models.

Before merging a couple of things have to be agreed on as mentioned in the comments below.

UPDATE:

This branch is IMO now fully functional for Bert-2-Bert models.
I will finish the PR (clean the code, make a pretty docstring, etc...) once we agreed on the issues I mentioned further down. Would be very happy if you can review @thomwolf @LysandreJik @sshleifer @julien-c @yjernite

Contributor

@mgoldey mgoldey left a comment


It seems like this will supersede my own attempted changes in #3402 - do you have an example in the works for this modified class?

Comment thread src/transformers/modeling_encoder_decoder.py Outdated
Comment thread src/transformers/modeling_encoder_decoder.py Outdated
@patrickvonplaten patrickvonplaten force-pushed the clean_encoder_decocer_modeling branch from a223edd to 8b5732a Compare March 30, 2020 12:59

codecov-io commented Mar 30, 2020

Codecov Report

Merging #3383 into master will increase coverage by 0.29%.
The diff coverage is 88.46%.


@@            Coverage Diff             @@
##           master    #3383      +/-   ##
==========================================
+ Coverage   78.61%   78.90%   +0.29%     
==========================================
  Files         106      105       -1     
  Lines       17953    17973      +20     
==========================================
+ Hits        14114    14182      +68     
+ Misses       3839     3791      -48     
Impacted Files Coverage Δ
src/transformers/modeling_bert.py 89.01% <75.00%> (+0.60%) ⬆️
src/transformers/modeling_encoder_decoder.py 84.41% <92.30%> (+63.36%) ⬆️
src/transformers/__init__.py 98.98% <100.00%> (ø)
src/transformers/modeling_utils.py 91.06% <0.00%> (+0.12%) ⬆️
src/transformers/modeling_tf_utils.py 92.76% <0.00%> (+0.16%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 857ccdb...83f3d10.

@patrickvonplaten patrickvonplaten force-pushed the clean_encoder_decocer_modeling branch from d2b75d1 to f92e17a Compare April 9, 2020 16:15
@patrickvonplaten patrickvonplaten force-pushed the clean_encoder_decocer_modeling branch from a3a0539 to 4089d8f Compare April 20, 2020 16:44
Comment thread src/transformers/modeling_encoder_decoder.py Outdated
Comment thread src/transformers/modeling_encoder_decoder.py
Comment thread src/transformers/modeling_encoder_decoder.py
Comment thread src/transformers/utils_encoder_decoder.py

# this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTest
Contributor Author


This was quite annoying. As far as I can see, it is not possible to import a unittest test case such as BertModelTest from .test_modeling_bert without running all of its tests. At the same time, I think we need to import the test case here, since we need to instantiate BertEncoder and BertDecoder models without copy-pasting all of their code. I tried a couple of hacks, but didn't find a good solution. Maybe it's possible to only run pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest from Circle CI, in which case only the encoder-decoder model tests are run? Any good ideas for this case? @julien-c @LysandreJik @thomwolf @sshleifer

Contributor


  1. Move the ModelTester to the top level of that file.

  2. in that file, self.model_tester = BertModelTester(self)

  3. only import BertModelTester

Verified that no extra tests are run.

You can merge #4027 if you want to keep your delta small.
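The pattern in the steps above can be sketched with a toy unittest setup (class and method names here are illustrative, not the real transformers test code):

```python
import unittest

# Plain helper at module top level: NOT a unittest.TestCase subclass,
# so importing it from another test module triggers no test collection.
class BertModelTester:
    def __init__(self, parent):
        self.parent = parent

    def create_and_check_model(self):
        # stand-in for building a model and asserting on its outputs
        self.parent.assertEqual(len([0, 1, 2]), 3)


class EncoderDecoderModelTest(unittest.TestCase):
    def setUp(self):
        # composition instead of inheritance: reuse the tester's logic
        self.model_tester = BertModelTester(self)

    def test_model(self):
        self.model_tester.create_and_check_model()
```

Because only the plain `BertModelTester` class is imported elsewhere, pytest's collector never sees a second copy of the `TestCase` and no tests run twice.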

Contributor Author


Awesome! I will try this out :-)

Contributor Author


Would it be ok to change Bert's testing structure as shown in PR #4027 for you guys @julien-c @thomwolf @LysandreJik ?

Comment thread src/transformers/modeling_bert.py Outdated
@patrickvonplaten patrickvonplaten changed the title [WIP] Update Encoder-Decoder models with Bart/T5-like API and add generate possibility [Require Feedback to continue] Update Encoder-Decoder models with Bart/T5-like API and add generate possibility Apr 20, 2020
Member

@thomwolf thomwolf left a comment


Nice draft.
See my comments on serialization with proposals below.
It's a bit dense so happy to discuss them if needed.

Comment thread src/transformers/modeling_encoder_decoder.py
Comment thread src/transformers/modeling_encoder_decoder.py
Comment thread src/transformers/modeling_encoder_decoder.py Outdated
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
"""

def __init__(self, encoder, decoder):
Member


So, as I detailed below under the from_pretrained method, I think here we should allow both an initialization from (pretrained) sub-parts and from a config, so we can also just reload a saved pretrained model.

Please read the long comment below from_pretrained before this one.

Here is a proposal for the __init__ that should allow both loading paths:

    def __init__(self, config=None, encoder=None, decoder=None):
        assert config is not None or (encoder is not None and decoder is not None), "Either a configuration or an Encoder and a decoder has to be provided"

        if config is not None:
            self.config = config
            self.encoder = AutoModel.from_config(config.encoder)
            self.decoder = AutoModelWithLMHead.from_config(config.decoder)
        else:
            self.config = EncoderDecoderConfig.from_encoder_decoder_config(encoder.config, decoder.config)
            self.encoder = encoder
            assert (
                self.encoder.get_output_embeddings() is None
            ), "The encoder {} should not have a LM Head. Please use a model without LM Head".format(self.encoder)
            self.decoder = decoder
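The dual-path idea above (build from a joint config, or wrap two already-built parts) can be illustrated with a self-contained toy; the `Toy*` classes below are hypothetical stand-ins, not the real transformers classes:

```python
class ToyConfig:
    """Hypothetical stand-in for a joint encoder-decoder config."""
    def __init__(self, encoder=None, decoder=None):
        self.encoder = encoder  # plain dicts standing in for sub-configs
        self.decoder = decoder


class ToyModel:
    """Hypothetical stand-in for an AutoModel; built from a config dict."""
    def __init__(self, config):
        self.config = config


class ToyEncoderDecoder:
    def __init__(self, config=None, encoder=None, decoder=None):
        assert config is not None or (encoder is not None and decoder is not None), \
            "Either a configuration or an encoder and a decoder has to be provided"
        if config is not None:
            # path 1: rebuild both sub-models from the joint config
            self.config = config
            self.encoder = ToyModel(config.encoder)
            self.decoder = ToyModel(config.decoder)
        else:
            # path 2: wrap two already-built sub-models, deriving a joint config
            self.config = ToyConfig(encoder=encoder.config, decoder=decoder.config)
            self.encoder = encoder
            self.decoder = decoder
```

Either entry point leaves the object in the same state: a joint config plus two sub-models, which is what makes save-then-reload possible.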

Member


This means we need an EncoderDecoderConfig, which would basically be the following:

import copy

class EncoderDecoderConfig(PretrainedConfig):
    model_type = "encoder_decoder"

    def __init__(self, **kwargs):
        encoder_config = kwargs.pop('encoder')
        encoder_model_type = encoder_config.pop('model_type')
        decoder_config = kwargs.pop('decoder')
        decoder_model_type = decoder_config.pop('model_type')
        super().__init__(**kwargs)
        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)

    @classmethod
    def from_encoder_decoder_config(cls, encoder_config, decoder_config):
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default `to_dict()` from `PretrainedConfig`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output['encoder'] = self.encoder.to_dict()
        output['decoder'] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
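The nested-serialization idea can be exercised without transformers at all; the `Mini*` classes below are illustrative stand-ins, assuming only that each sub-config knows how to dump itself to a dict:

```python
import copy


class MiniConfig:
    """Flat stand-in for PretrainedConfig (names here are illustrative)."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to_dict(self):
        return copy.deepcopy(self.__dict__)


class MiniEncoderDecoderConfig:
    model_type = "encoder_decoder"

    def __init__(self, **kwargs):
        # nested dicts are re-inflated into config objects on load
        self.encoder = MiniConfig(**kwargs.pop("encoder"))
        self.decoder = MiniConfig(**kwargs.pop("decoder"))

    @classmethod
    def from_encoder_decoder_config(cls, encoder_config, decoder_config):
        return cls(encoder=encoder_config.to_dict(),
                   decoder=decoder_config.to_dict())

    def to_dict(self):
        # serialize recursively so the joint config round-trips (e.g. through JSON)
        output = copy.deepcopy(self.__dict__)
        output["encoder"] = self.encoder.to_dict()
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
```

The key property is the round trip: `to_dict()` produces plain nested dicts, and `__init__(**that_dict)` rebuilds equivalent config objects.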

Contributor

@sshleifer sshleifer left a comment


Maybe add test coverage for BartEncoder + BartDecoder?
Also, we should add the caching logic to this at some point.

Comment thread src/transformers/modeling_bert.py Outdated
Comment thread src/transformers/modeling_encoder_decoder.py Outdated
@patrickvonplaten patrickvonplaten force-pushed the clean_encoder_decocer_modeling branch from 36ff2fd to 16b95a8 Compare April 27, 2020 19:56
Comment thread src/transformers/configuration_encoder_decoder.py Outdated
Comment thread src/transformers/configuration_encoder_decoder.py Outdated
Comment thread src/transformers/modeling_encoder_decoder.py
@patrickvonplaten patrickvonplaten changed the title [Require Feedback to continue] Update Encoder-Decoder models with Bart/T5-like API and add generate possibility Clean Encoder-Decoder models with Bart/T5-like API and add generate possibility Apr 27, 2020
Contributor Author

patrickvonplaten commented Apr 27, 2020

UPDATE:

I really liked the idea of adding an encoder_decoder config (@thomwolf) and having two from_pretrained functions for both the encoder_decoder config and model:

  1. the standard one, which works thanks to inheritance from PretrainedConfig and PretrainedModel
  2. a from_encoder_decoder_pretrained function (@sshleifer)

To understand how to use the encoder-decoder class, please refer to the added tests.

Member

@thomwolf thomwolf left a comment


Really awesome!
And great tests as well!

LGTM :-)

Comment thread src/transformers/modeling_encoder_decoder.py
Comment thread src/transformers/configuration_encoder_decoder.py Outdated
Comment thread src/transformers/configuration_encoder_decoder.py Outdated
Comment thread src/transformers/configuration_encoder_decoder.py Outdated
Comment thread src/transformers/modeling_bert.py Outdated
Comment thread src/transformers/modeling_encoder_decoder.py Outdated
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
Contributor


So neutral kwargs like use_cache only go to the encoder? Is that desired/irrelevant?

Contributor Author


I think the user will always have to prepend each argument with decoder_ in order to use it for the decoder. So if one were to use Bert + GPT2 (which is not supported at the moment), one would have to write decoder_use_cache=True. IMO this is ok.
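The prefix-routing rule can be sketched as a small standalone helper; this is a simplification of what the quoted forward-pass code does with **kwargs, and the function name is hypothetical:

```python
def split_encoder_decoder_kwargs(**kwargs):
    """Route keyword arguments: names starting with 'decoder_' go to the
    decoder (with the prefix stripped); everything else goes to the encoder."""
    kwargs_encoder = {
        name: value for name, value in kwargs.items()
        if not name.startswith("decoder_")
    }
    kwargs_decoder = {
        name[len("decoder_"):]: value for name, value in kwargs.items()
        if name.startswith("decoder_")
    }
    return kwargs_encoder, kwargs_decoder
```

For example, `split_encoder_decoder_kwargs(head_mask=None, decoder_use_cache=True)` sends `head_mask` to the encoder and `use_cache=True` (prefix stripped) to the decoder.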


def _reorder_cache(self, past, beam_idx):
# as a default encoder-decoder models do not re-order the past.
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
Contributor


tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

This would be better IMO, or raise NotImplementedError. Isn't this implementation incorrect?

Contributor Author


It should not raise anything because the function is called in generate and should not reorder anything. I think for the moment this is ok.
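For reference, what the suggested `index_select(1, beam_idx)` reordering would do to each cached layer can be mimicked with plain nested lists, a toy stand-in for torch tensors (this is not the code merged in this PR, which deliberately leaves the past unchanged):

```python
def reorder_layer_past(layer_past, beam_idx):
    # plain-list analogue of layer_past.index_select(1, beam_idx):
    # pick entries along dimension 1 in the order given by beam_idx
    return [[row[i] for i in beam_idx] for row in layer_past]


def reorder_cache(past, beam_idx):
    # reorder every layer's cache so it follows the surviving beams
    return tuple(reorder_layer_past(layer_past, beam_idx) for layer_past in past)
```

Beam search keeps only the best hypotheses at each step, so a decoder that caches per-beam state must permute that cache with the same indices, which is exactly what the column pick above models.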

Comment thread tests/test_modeling_encoder_decoder.py Outdated

input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load(**input_ids_dict)

def test_save_and_load_from_encoder_decoder_pretrained(self):
Contributor


I don't understand the point of all these helper functions that you use once.

Contributor Author


Do you mean the create_and_check... functions? It's done everywhere in the tests, and I think it's actually a nice way to have short, easy-to-read test functions at the bottom and all the logic further up.

@patrickvonplaten
Contributor Author

Code is cleaned: added type hints, cleaned the docstring and added an encoder-decoder model page.
Just need to resolve the issue with importing Bert's model tester. @sshleifer found a solution. If everybody is fine with it, I'll go for it :-)

@sshleifer
Contributor

LGTM!

@thomwolf thomwolf merged commit fa49b9a into huggingface:master Apr 28, 2020
@patrickvonplaten
Contributor Author

Ok, good to merge for me! @sshleifer, if it's ok for you, I will use your PR #4027 for the new test proposal.

@sshleifer
Contributor

Yes!
