add new model prophetnet #6187

Closed
qiweizhen wants to merge 4 commits into huggingface:master from qiweizhen:master

Conversation

@qiweizhen
Contributor

@qiweizhen qiweizhen commented Aug 1, 2020

Add new model structure ProphetNet.

Description:

ProphetNet is a new pre-trained language model for sequence-to-sequence learning with a novel self-supervised objective called future n-gram prediction. ProphetNet is able to predict more future tokens with an n-stream decoder. The original implementation is the Fairseq version in the GitHub repo.
xProphetNet has the same model structure but is pretrained on a 100-language Wikipedia dataset as described in xGLUE. xGLUE is a benchmark for cross-lingual NLU and NLG tasks. xProphetNet also serves as a baseline model for the cross-lingual generation tasks in xGLUE, NTG and QG.

Usage:

Take the xGLUE NTG task as an example:
the cross-lingual pretrained model is fine-tuned with English news title generation data, but performs inference on both English and zero-shot-language data.
A quick usage example:

from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, ProphetNetConfig

model = ProphetNetForConditionalGeneration.from_pretrained('microsoft/xprophetnet-large-wiki100-cased-xglue-ntg')
tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased-xglue-ntg')

EN_SENTENCE_TO_QUESTION = "Microsoft Corporation intends to officially end free support for the Windows 7 operating system after January 14, 2020, according to the official portal of the organization. From that day, users of this system will not be able to receive security updates, which could make their computers vulnerable to cyber attacks."
RU_SENTENCE_TO_QUESTION = "Корпорация Microsoft намерена официально прекратить бесплатную поддержку операционной системы Windows 7 после 14 января 2020 года, сообщается на официальном портале организации. С указанного дня пользователи этой системы не смогут получать обновления безопасности, из-за чего их компьютеры могут стать уязвимыми к кибератакам."
ZH_SENTENCE_TO_QUESTION = "根据该组织的官方门户网站,微软公司打算在2020年1月14日之后正式终止对Windows 7操作系统的免费支持。从那时起,该系统的用户将无法接收安全更新,这可能会使他们的计算机容易受到网络攻击。"
inputs = tokenizer([EN_SENTENCE_TO_QUESTION, RU_SENTENCE_TO_QUESTION, ZH_SENTENCE_TO_QUESTION], padding=True, max_length=256, return_tensors='pt')

summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=100, early_stopping=True)
print([tokenizer.decode(g) for g in summary_ids])  

Model will generate news titles like:

['[SEP] Microsoft to end Windows 7 free support after January 14, 2020[SEP][PAD][PAD][PAD][PAD]',
 '[SEP] Microsoft намерена прекратить бесплатную поддержку Windows 7 после 14 января 2020 года[SEP]',
 '[SEP]微软打算终止对Windows 7操作系统的免费支持[SEP][PAD][PAD][PAD][PAD][PAD][PAD]']

Released checkpoints:

pretrained:

microsoft/prophetnet-large-uncased
microsoft/xprophetnet-large-wiki100-cased

fine-tuned:

microsoft/prophetnet-large-uncased-cnndm
microsoft/xprophetnet-large-wiki100-cased-xglue-ntg
microsoft/xprophetnet-large-wiki100-cased-xglue-qg

@JetRunner JetRunner self-assigned this Aug 2, 2020
Contributor

@JetRunner JetRunner left a comment

Thanks @qiweizhen ! It's really exciting to see ProphetNet; it's been a while since we've had a big language model integration like this!

Please add the link and a TL;DR of your paper to the README. Also, please write some unit tests to make sure the model works as expected (of course you can do that after we have done several rounds of reviews re. the modeling and tokenizer themselves). Please consider adding a TensorFlow model once the PyTorch one is shipped. Also, please complete the model cards of your uploaded weights!

Comment on lines +58 to +69
def __new__(cls, **kwargs):
    xprophetnet_tokenizer = False if 'xprophetnet_tokenizer' not in kwargs.keys() else kwargs['xprophetnet_tokenizer']
    if xprophetnet_tokenizer:
        super_class = XLMRobertaTokenizer
    else:
        super_class = BertTokenizer
    cls = type(cls.__name__, (cls, super_class), {})
    if xprophetnet_tokenizer:
        cls.vocab_files_names = VOCAB_FILES_NAMES_CROSS_LINGUAL
    else:
        cls.vocab_files_names = VOCAB_FILES_NAMES_EN
    return super(ProphetNetTokenizer, cls).__new__(cls)
Contributor

I personally think it's too hacky here. However, wrt the vocab_files_names issue, you can simply choose your own name, e.g., prophet.bpe or prophet.tokenizer. Be creative lol!

Then I think you can put vocabulary_type as a configuration option in the config! There should be no problem then.
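A minimal sketch of what such a config-driven dispatch could look like; the option name vocabulary_type and the helper pick_tokenizer_class are illustrative, not the PR's final API:

```python
def pick_tokenizer_class(vocabulary_type: str) -> str:
    # Hypothetical dispatch on a plain config option instead of
    # patching the class hierarchy inside __new__.
    if vocabulary_type == "xlm-roberta":
        return "XLMRobertaTokenizer"
    if vocabulary_type == "bert":
        return "BertTokenizer"
    raise ValueError(f"unknown vocabulary_type: {vocabulary_type}")

print(pick_tokenizer_class("bert"))         # BertTokenizer
print(pick_tokenizer_class("xlm-roberta"))  # XLMRobertaTokenizer
```

With a scheme like this, the tokenizer class stays fixed and only the vocab file handling branches on the config value.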

Contributor Author

Done as you suggested. Vocab files are now named as prophetnet.tokenization

**kwargs
):
if not xprophetnet_tokenizer:
# inherit from BERT tokenizer
Contributor

I would use the word copied from instead of inherit

Contributor Author

Reviewing the code in tokenization_bart and other files, it seems inheriting is also acceptable. So, instead of copying every function and modifying it as self._tokenizer.func, I think the existing code is simpler and avoids introducing bugs when copying the code.

Contributor

Well I think both ways are okay but it's also good to hear from @LysandreJik

)
self.unique_no_split_tokens.append("[X_SEP]")
else:
# inherit from XLM-R tokenizer
Contributor

same

Comment thread src/transformers/modeling_prophetnet.py Outdated
inputs = tokenizer([EN_SENTENCE_TO_QUESTION, RU_SENTENCE_TO_QUESTION, ZH_SENTENCE_TO_QUESTION], padding=True, max_length=256, return_tensors='pt')

# Generate Summary(bos is removed)
summary_ids = model.generate(inputs['input_ids'][:, 1:], num_beams=4, max_length=100, early_stopping=True)
Contributor

My solution is to simply merge this into the model like input = input[:, 1:]. Users don't need to care about details like this.
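The suggestion can be sketched as a tiny helper inside the model; strip_bos is an illustrative name, not the PR's actual code:

```python
import torch

def strip_bos(input_ids: torch.Tensor) -> torch.Tensor:
    # Drop the leading bos token for every sequence in the batch, so the
    # caller can pass tokenizer output to generate() unchanged.
    return input_ids[:, 1:]

batch = torch.tensor([[0, 17, 35, 2], [0, 11, 9, 2]])  # 0 = bos, 2 = eos
print(strip_bos(batch).tolist())  # [[17, 35, 2], [11, 9, 2]]
```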

Contributor Author

Done as you suggested. Removing the bos token is now handled inside the model.

from .modeling_prophetnet import (
ProphetNetModel,
ProphetNetForConditionalGeneration
)
Contributor

I totally understand that ProphetNet is a specialized model for generation but you may also want to add something like ProphetNetForSequenceClassification etc.? BART has them.

Contributor Author

Thank you for this suggestion. I will add ProphetNetForSequenceClassification in the near future, once I find a suitable approach for NLU tasks, for example by comparing the encoder-decoder NLU heads of T5, BART, etc.

Contributor

Agree that it would be nice to also have ProphetNetForSequenceClassification, but I think we could handle that in a new PR. The most important piece for now is probably ProphetNetForConditionalGeneration.

Comment thread src/transformers/modeling_auto.py Outdated
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
(ReformerConfig, ReformerModelWithLMHead),
(ProphetNetConfig, ProphetNetModel),
Contributor

Is this correct? No need for WithLMHeadModel?

Contributor Author

Deleted this line

@julien-c julien-c added the model card Related to pretrained model cards label Aug 5, 2020
Contributor

@JetRunner JetRunner left a comment

@sshleifer Is there anything else needed to do in order to make ProphetNet work with your seq2seq example?

Also @mfuntowicz for pipelines

@qiweizhen
Contributor Author

@sshleifer Is there anything else needed to do in order to make ProphetNet work with your seq2seq example?

Also @mfuntowicz for pipelines

I tried examples/seq2seq/finetune.py and it works with python finetune.py --do_train and --do_predict.

@qiweizhen
Contributor Author

I will try to complete the documentation and unit tests this week.

eps=0.0,
**common_kwargs
):
if "hidden_size" in common_kwargs:
Contributor

Do we have to call the parameter d_model? It would be great if we could only use hidden_size to avoid confusion. While a lot of previous models have d_model we are trying to be more consistent now with the naming.
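One way to keep old fairseq-style code working while making hidden_size the canonical name is a read-only alias; this is a hypothetical sketch, not the config the PR actually ships:

```python
class ProphetNetConfigSketch:
    # Illustrative config: hidden_size is the canonical attribute,
    # d_model is kept only as a backward-compatible alias.
    def __init__(self, hidden_size: int = 1024):
        self.hidden_size = hidden_size

    @property
    def d_model(self) -> int:
        return self.hidden_size

config = ProphetNetConfigSketch(hidden_size=512)
print(config.d_model)  # 512
```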

Comment thread src/transformers/modeling_prophetnet.py Outdated
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.onnx_trace = False

def forward(self, input, use_cache=False, positions=None):
Contributor

We usually name input as input_ids in all other models. Would be great if you could align the naming.

real_positions = positions
else:
real_positions = positions
return super().forward(positions), real_positions
Contributor

It's a bit confusing to me that two tensors are returned -> could you add a comment explaining why this is the case?
I think I would also prefer to split the class up into a function and a class:

  1. A function that calculates the "real_positions"
  2. A very simple nn.Embedding

I'm not really sure that we need such a big class. IMO, one can calculate the real_positions via a function; then we would only need nn.Embedding, which is much more readable, and we wouldn't need _forward at all. max_positions is such a small function that we can just copy-paste it into the code.
But maybe I overlooked something - what do you think?

It's not that big of a deal though...we can also handle this in a refactor later.
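The suggested split could look roughly like this; compute_real_positions is an illustrative name that follows fairseq's make_positions logic (positions count non-padding tokens, offset by the padding index), and the resulting tensor would then be fed to a plain nn.Embedding:

```python
import torch

def compute_real_positions(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Non-padding tokens get consecutive positions starting at padding_idx + 1;
    # padding tokens keep position padding_idx, so they map to the padding
    # embedding slot.
    mask = input_ids.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

input_ids = torch.tensor([[5, 6, 7, 1]])  # 1 = padding_idx
print(compute_real_positions(input_ids, padding_idx=1).tolist())  # [[2, 3, 4, 1]]
```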

Comment thread src/transformers/modeling_prophetnet.py Outdated
def _forward(self, positions):
return super().forward(positions)

def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
Contributor

This is very different from our usual design...an upper-case name is very unexpected for a function (even though the function instantiates a class...). I think a better way here is to do the following at the very top of the code:

try:
    from apex.normalization import FusedLayerNorm
    ProphetNetLayerNorm = FusedLayerNorm
except ImportError:
    ProphetNetLayerNorm = torch.nn.LayerNorm

then the layer norm can be instantiated normally with:

ProphetNetLayerNorm(normalized_shape, eps, elementwise_affine)

At the moment export is always set to False when calling LayerNorm, and the other two parameters are also not used...should we put eps and elementwise_affine in the config? Do we need export? Or could we delete it?

Contributor

Maybe import LayerNorm from another model is a good idea?

Contributor

like Bart, which does exactly this :)

Comment thread src/transformers/modeling_prophetnet.py Outdated
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

def invert_mask(attention_mask):
assert attention_mask.dim() == 2
Contributor

assert message would be nice :-)

Contributor

Yes it can be as easy as assert attention_mask.dim() == 2, "some error message"

Comment thread src/transformers/modeling_prophetnet.py Outdated
embed_dim,
num_heads,
dropout=0.0,
bias=True,
Contributor

maybe change to is_bias so everyone knows it's a flag

Contributor

can we reuse bart? This looks identical.

Comment thread src/transformers/modeling_prophetnet.py Outdated
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
Contributor

this looks like complicated logic in the following lines :D Maybe add a comment?

Comment thread src/transformers/modeling_prophetnet.py Outdated
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

def _shape(self, tensor, dim_0, bsz):
Contributor

why is the function name _shape ? -> looks more like _reshape_for_...

Contributor

and can we make it static for better readability

Contributor

static functions are always less scary IMO :D

Comment thread src/transformers/modeling_prophetnet.py Outdated
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
Contributor

assert message would be great

Comment thread src/transformers/modeling_prophetnet.py Outdated
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
Contributor

assert message would be great

Comment thread src/transformers/modeling_prophetnet.py Outdated

def forward(
self,
query,
Contributor

I think query is the hidden_states here if I'm not mistaken -> we usually call the input embedding representations hidden_states...by query I would think of the query projection, which could be a bit misleading.

Comment thread src/transformers/modeling_prophetnet.py Outdated
v = self._shape(v, -1, bsz)

if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
Contributor

(nit) can we try to avoid single letter variables, makes refactoring always very difficult afterwards -> maybe just stick to key, value and query and change the first key to hidden_states_key

Contributor

Well, I think k, v, q are acceptable, but there's no harm in using longer names.

Comment thread src/transformers/modeling_prophetnet.py Outdated
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}

assert k is not None
Contributor

assert statements would be great

Comment thread src/transformers/modeling_prophetnet.py Outdated
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
Contributor

assert statements would be great

Comment thread src/transformers/modeling_prophetnet.py Outdated
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
Contributor

assert statements would be great

Comment thread src/transformers/modeling_prophetnet.py Outdated

if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
Contributor

(nit) key_padding_mask[:, None, None] is nicer IMO
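The two spellings are equivalent; a quick check of the shapes produced:

```python
import torch

key_padding_mask = torch.tensor([[False, False, True]])  # (bsz, src_len)

# chained unsqueeze as in the PR
reshaped_a = key_padding_mask.unsqueeze(1).unsqueeze(2)
# None-indexing, same result in one expression
reshaped_b = key_padding_mask[:, None, None]

print(reshaped_a.shape, reshaped_b.shape)  # both torch.Size([1, 1, 1, 3])
```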

Comment thread src/transformers/modeling_prophetnet.py Outdated
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
Contributor

-> attn_probs = F.softmax(attn_weights, dim=-1) after softmax we have probs

Comment thread src/transformers/modeling_prophetnet.py Outdated
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)

assert v is not None
Contributor

assert statement missing

Comment thread src/transformers/modeling_prophetnet.py Outdated
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)

def _relative_positions_bucket(self, relative_positions, bidirectional=False):
Contributor

is_bidirectional is maybe better

# input attn_weights [T*head,T,S]
# input real_positions [B,T] or [1,1]

T, B, _ = query.size()
Contributor

not a fan of single letters here...but I guess OK for now

else:
saved_state = None
layer_state = {}

Contributor

If I read the function correctly, the arguments key and value are not used -> can we delete them from the function signature?

Comment thread src/transformers/modeling_prophetnet.py Outdated
output_attentions = False
):

tgt_len, bsz, embed_dim = query.size()
Contributor

can we rename query to hidden_states here as well


if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
Contributor

As a side note, it's always good to try whether expand works before using repeat -> repeat always reallocates new memory, which can become heavy depending on how big the tensors are.
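The difference can be seen directly; expand works here because bias_k has size 1 along the batch dimension (a torch.cat on the expanded view would still copy, so the saving applies to the intermediate tensor only):

```python
import torch

bias_k = torch.randn(1, 1, 8)
bsz = 4

repeated = bias_k.repeat(1, bsz, 1)   # materializes a new (1, 4, 8) tensor
expanded = bias_k.expand(1, bsz, 8)   # stride-0 view, no extra memory

assert torch.equal(repeated, expanded)
# the expanded view still points at the original storage
assert expanded.data_ptr() == bias_k.data_ptr()
```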

Comment thread src/transformers/modeling_prophetnet.py Outdated
self_attn_mask = self_attn_mask.unsqueeze(0)
attn_weights_main = attn_weights_main + self_attn_mask

attn_weights_main = softmax(
Contributor

attn_probs_main would be nicer

Comment thread src/transformers/modeling_prophetnet.py Outdated
attn_result.append(attn_ngram)

# [1+ngram*T, B, C]
attn = torch.cat(attn_result, 0).view(-1, bsz, embed_dim)
Contributor

would be simpler and more readable to just do attn = torch.cat([attn_main, attn_ngram], 0).view(-1, bsz, embed_dim) and delete 3 lines above IMO.

Comment thread src/transformers/modeling_prophetnet.py Outdated
def _get_input_buffer(self, incremental_state):
return {}

def _set_input_buffer(self, incremental_state, buffer):
Contributor

this function does not seem to do anything

Comment thread src/transformers/modeling_prophetnet.py Outdated

def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
Contributor

input_buffer = {} instead or is _get_input_buffer not fully implemented yet?

Comment thread src/transformers/modeling_prophetnet.py Outdated

def forward(
self,
x,
Contributor

hidden_states instead of x

Comment thread src/transformers/modeling_prophetnet.py Outdated
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings + 2 + self.padding_idx,
embed_dim, self.padding_idx)
self.ngram_input_embed = nn.Embedding(self.ngram, embed_dim, None)
self.layers = nn.ModuleList([])
Contributor

can do everything in one line IMO -> no need for self.layers = nn.ModuleList([])

Comment thread src/transformers/modeling_prophetnet.py Outdated
i_buckets_main_stream, i_bucket_relative_stream = \
self.cal_finetune_relative_positions(real_positions)
predicting_stream_pos_embed = self.embed_positions._forward(real_positions + 1)
x = self.embed_tokens(input_ids)
Contributor

x -> hidden_states

Comment thread src/transformers/modeling_prophetnet.py Outdated
next_cache = None
return x_list, next_cache, all_hidden_states, list(all_self_attns)


Contributor

Delete empty lines

Comment thread src/transformers/modeling_prophetnet.py Outdated

padding_idx, vocab_size, dim_size = config.pad_token_id, config.vocab_size, config.d_model
self.embed_tokens = nn.Embedding(vocab_size, dim_size, padding_idx=padding_idx)
nn.init.normal_(self.embed_tokens.weight, mean=0, std=dim_size ** -0.5)
Contributor

what is this doing ? -> normally we handle the init in the ProphetNetPreTrainedModel, such as here:

def _init_weights(self, module):

Comment thread src/transformers/modeling_prophetnet.py Outdated
base_model = ProphetNetModel(config)
self.model = base_model
#self.padding_idx = config.pad_token_id
self.padding_idx = -100
Contributor

why -100 and not config.pad_token_id? -100 is the token we usually use to disable the CE loss, not sure if this is set intentionally to this number here..

Comment thread src/transformers/modeling_prophetnet.py Outdated
output_attentions=None,
**unused,
):
if "lm_labels" in unused:
Contributor

You can delete the whole if "lm_labels" ... -> since lm_labels was never there in the first place we should not allow the user to use it

Comment thread src/transformers/modeling_prophetnet.py Outdated
lprobs,
expend_targets.view(-1),
reduction='sum',
ignore_index=self.padding_idx,
Contributor

ah ok, I see where the -100 comes from. I think -100 is the default value for ignore_index, so you might not have to put it here. Also, can we call it ignore_loss_token_id instead?

if labels is not None:
# fine-tune
expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(self.padding_idx)
for i in range(self.config.ngram):
Contributor

can we try not to use the for loop here? It should be quite easy to get rid of it, I think
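A hedged illustration of how such a per-stream fill loop can be vectorized, assuming each n-gram stream is the labels shifted left by its stream index (the actual semantics live in the PR; the expand_targets_* names are illustrative):

```python
import torch
import torch.nn.functional as F

def expand_targets_loop(labels, ngram, padding_idx):
    # one Python iteration per n-gram stream
    out = labels.new_full((ngram, labels.size(0), labels.size(1)), padding_idx)
    for i in range(ngram):
        out[i, :, : labels.size(1) - i] = labels[:, i:]
    return out

def expand_targets_vectorized(labels, ngram, padding_idx):
    # pad once on the right, then gather every shifted stream in one index op
    bsz, seq_len = labels.shape
    padded = F.pad(labels, (0, ngram - 1), value=padding_idx)             # (B, T+ngram-1)
    idx = torch.arange(seq_len)[None, :] + torch.arange(ngram)[:, None]   # (ngram, T)
    return padded[:, idx].permute(1, 0, 2)                                # (ngram, B, T)

labels = torch.tensor([[4, 5, 6]])
assert torch.equal(expand_targets_loop(labels, 2, 0),
                   expand_targets_vectorized(labels, 2, 0))
```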

return new_key_padding_mask


class EncoderLayer(nn.Module):
Contributor

I think we have to split the class up into more classes. If you take a look at BERT's

class BertLayer(nn.Module):

you can see that the BertLayer is much more modularized.

Note that once we have the layer names defined, such as self.fc1, we cannot really change them afterwards because the trained weights are tied to them. Having a more modularized layer has the advantage that it's much easier and cleaner to add new features or change the logic at a later stage.

Say we want to add another version which has the layer norm before the feed-forward layers. As it is now, we would have to add if statements, which is not very clean. If we instead modularize the feed-forward part directly, such as self.feed_forward = FeedForwardLayer(...), we could later simply add another FeedForwardLayerNew(...) and do: if config...: self.feed_forward = FeedForwardLayer(...) else: self.feed_forward = FeedForwardLayerNew(...). This allows us to simply derive many models from BERT and makes the model in general much more flexible, which is a big advantage if we decide to add new features at a later stage - say the feed-forward layers change. I strongly recommend doing this here as well...if this means that it does not fit with your current weights, we can simply write a conversion script to convert the weights (takes 15 min, I can help you :-))
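A sketch of what the modular feed-forward piece could look like in the BertLayer style; class and attribute names here are illustrative, not the PR's final ones:

```python
import torch
from torch import nn

class ProphetNetFeedForward(nn.Module):
    # Illustrative modular feed-forward block: swapping in a variant later
    # only requires replacing this module, not editing the layer's forward.
    def __init__(self, hidden_size, intermediate_size, dropout=0.1):
        super().__init__()
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = torch.relu(self.intermediate(hidden_states))
        hidden_states = self.dropout(self.output(hidden_states))
        return self.layer_norm(hidden_states + residual)

layer = ProphetNetFeedForward(hidden_size=16, intermediate_size=64)
out = layer(torch.randn(2, 5, 16))
print(out.shape)  # torch.Size([2, 5, 16])
```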

Comment thread src/transformers/modeling_prophetnet.py Outdated
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)

def _relative_positions_bucket(self, relative_positions, bidirectional=False):
Contributor

I'd make this static as well and insert the self... params -> this is a classic function that we should test

Contributor

@patrickvonplaten patrickvonplaten left a comment


Hey @qiweizhen,

Thanks so much for this PR! I read your paper today - congrats for the amazing results! I'm very excited to integrate this into the library.

Sorry that I left so many comments - feel free to only take those into account that you think are necessary - a lot of these comments are nits.

IMO we should focus on these 4 things before merging:

  1. Let's make the layers more flexible to future enhancements/changes by modularizing them more (see my comment on line 358). I think we should use Bert as the role model. Having ProphetNetSelfAttention, ProphetNetAttentionOutput, ProphetNetIntermediate, and ProphetNetOutput would make the code much more flexible.

  2. We should try not to use Python and NumPy functions at all (Python's for loop, e.g.). Having for loops in the forward functions can render the code very slow on CUDA. Also @mfuntowicz here.
    This is not too big of a problem in the beginning as it does not require a breaking change later and can be improved in a second PR though.

  3. More important: it would be great if we can change all functions that do not use self into static functions. It has two big advantages: 1) the reader knows the function does not depend on any model state and is probably quite simple; 2) the functions can very easily be tested -> the model does not have to be instantiated to test these functions, which makes it much easier to find bugs and have good tests.

  4. VERY IMPORTANT - Tests! This model is very complex and I think in order to be able to maintain it, we need integration tests, as is done a lot in Bart, e.g. class BartModelIntegrationTests(unittest.TestCase). The more tests, the better... Without tests, it will be very hard to refactor the model at a later stage and to know whether it works correctly. Also testing model-specific static functions, like _relative_positions_bucket, and single layers is extremely useful! First, people have a much easier time understanding these layers and functions by seeing an input and an output, and second, bugs can be pinned down in much more detail. I started doing this a lot for Longformer now, e.g. def test_diagonalize(self), and think it's paying off a lot.
    If it requires too much work to do all these integration tests, no worries - we can do it later... we should have at least 4-5 integration tests for the pretrained ProphetNetForConditionalGeneration models.
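To make point 1 above concrete, here is a minimal sketch of what a Bert-style modular split could look like. Only the class names ProphetNetSelfAttention, ProphetNetIntermediate, and ProphetNetOutput come from the comment; the layer internals (head counts, activation, layer norm placement) are illustrative placeholders, not the actual ProphetNet code.

```python
import torch
from torch import nn

class ProphetNetSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads)

    def forward(self, hidden_states):
        out, _ = self.attn(hidden_states, hidden_states, hidden_states)
        return out

class ProphetNetIntermediate(nn.Module):
    def __init__(self, hidden_size, ffn_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, ffn_size)

    def forward(self, hidden_states):
        return torch.relu(self.fc(hidden_states))

class ProphetNetOutput(nn.Module):
    def __init__(self, ffn_size, hidden_size):
        super().__init__()
        self.fc = nn.Linear(ffn_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, residual):
        return self.norm(self.fc(hidden_states) + residual)

class ProphetNetEncoderLayer(nn.Module):
    # One encoder layer composed from the small modules above.
    def __init__(self, hidden_size=16, num_heads=2, ffn_size=32):
        super().__init__()
        self.self_attn = ProphetNetSelfAttention(hidden_size, num_heads)
        self.intermediate = ProphetNetIntermediate(hidden_size, ffn_size)
        self.output = ProphetNetOutput(ffn_size, hidden_size)

    def forward(self, hidden_states):
        attn_out = self.self_attn(hidden_states)
        return self.output(self.intermediate(attn_out), attn_out)

x = torch.randn(5, 2, 16)  # [T, B, C], fairseq-style layout
layer = ProphetNetEncoderLayer()
out = layer(x)
```

Each sub-module can then be swapped or tested in isolation, which is the flexibility the comment is after.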
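Points 3 and 4 together could look like the following sketch: a static helper that needs no model instance, plus a small unit test for it. clamp_relative_positions is a hypothetical stand-in for helpers like _relative_positions_bucket, not the real function.

```python
import unittest
import torch

class ProphetNetPositionalUtils:
    # A static function: no self, so it can be tested without
    # instantiating the (large) model.
    @staticmethod
    def clamp_relative_positions(relative_positions, max_distance):
        """Clamp signed relative positions into [-max_distance, max_distance]."""
        return relative_positions.clamp(min=-max_distance, max=max_distance)

class ProphetNetStaticFunctionTest(unittest.TestCase):
    def test_clamp_relative_positions(self):
        rel = torch.tensor([-10, -3, 0, 3, 10])
        out = ProphetNetPositionalUtils.clamp_relative_positions(rel, max_distance=4)
        self.assertEqual(out.tolist(), [-4, -3, 0, 3, 4])
```

Because the helper is static, the test runs in milliseconds and pins any bug to one function rather than to the whole forward pass.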

The other comments are mostly nits. Sorry if all these comments are a bit overwhelming, and if it sounds like too much work for you, don't worry - we can take care of it as well, especially the performance improvements of 2).
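As a hedged illustration of the performance point in 2) (the tensor names here are invented, not taken from the ProphetNet code): a per-position Python loop in a forward pass issues one indexing operation per step, which on CUDA means one kernel launch each, while a broadcasted tensor op does the same work in a single call.

```python
import torch

def scale_positions_loop(hidden_states, scales):
    # Slow path: one indexing op (and one CUDA kernel launch) per position.
    out = torch.empty_like(hidden_states)
    for t in range(hidden_states.size(0)):
        out[t] = hidden_states[t] * scales[t]
    return out

def scale_positions_vectorized(hidden_states, scales):
    # Fast path: a single broadcasted multiply over all positions at once.
    return hidden_states * scales.view(-1, 1, 1)

hidden_states = torch.randn(8, 2, 16)  # [T, B, C]
scales = torch.arange(8, dtype=torch.float32)
```

Both functions compute the same result; only the number of launched operations differs.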

Thanks a lot for adding the model - very excited to have the community fine-tune the model on a bunch of tasks :-)

@JetRunner
Contributor

JetRunner commented Aug 6, 2020

@patrickvonplaten Thanks for your review! I learned a lot, too.

@qiweizhen Please feel free to contact me for discussion via WeChat if you have trouble understanding Patrick's comments or want another person to double-check! Thanks for your great work!

Contributor

@sshleifer sshleifer left a comment

Thanks for the new model!

  • I would add tests and an integration test like in test_modeling_bart.py. (Maybe I missed these.)
  • In general, try to reuse components from Bart, which is also ported from fairseq and, apart from your NGramAttention, I think is very similar. (Not a priority, only if it makes your life easier.)

(I meant to comment, not Approve).

Contributor

@sshleifer sshleifer left a comment

(deleted)

Comment on lines +16 to +17
EN_SENTENCE = "Google left China in 2010"
ZH_SENTENCE = "Google在2010年离开中国"

Long live Microsoft lol!


class EncoderLayer(nn.Module):
"""
Same to Transformer Encoder Layer

Suggested change
- Same to Transformer Encoder Layer
+ Same as Transformer Encoder Layer


# [1+ngram*T, B, C]
attn = torch.cat(attn_result, 0).view(-1, bsz, embed_dim)
attn = torch.cat([attn_main, attn_ngram], 0).view(-1, bsz, embed_dim)

Suggested change
- attn = torch.cat([attn_main, attn_ngram], 0).view(-1, bsz, embed_dim)
+ attn = torch.cat((attn_main, attn_ngram), 0).view(-1, bsz, embed_dim)

residual = x
x, _ = self.encoder_attn(
query=x,
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

I think you can make self.dropout = Dropout(p=self.dropout) and then directly call it here?
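A sketch of that suggestion (class and attribute names here are assumed for illustration): storing an nn.Dropout module once in __init__ lets call sites drop the explicit training=self.training argument, since the module tracks train/eval mode itself.

```python
import torch
from torch import nn

class FeedForwardSketch(nn.Module):
    def __init__(self, hidden_size=16, dropout=0.1):
        super().__init__()
        self.fc = nn.Linear(hidden_size, hidden_size)
        # A module instead of a float probability: no training= argument
        # needed when calling it in forward().
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, hidden_states):
        return self.dropout(self.fc(hidden_states))

layer = FeedForwardSketch().eval()  # eval() propagates to nn.Dropout automatically
x = torch.ones(2, 16)
out = layer(x)
```

In eval mode the dropout module is the identity, so no manual self.training bookkeeping is required.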

Comment on lines +699 to +701
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

Same here for the dropouts

return result


def cal_relative_positions_buckets(num_buckets, max_distance, real_positions):

I'm a little confused by the names _relative_positions_bucket and cal_relative_positions_buckets.

return i_buckets_main_stream, i_bucket_relative_stream

def cal_finetune_relative_positions(self, real_positions):
def cal_and_buffer_finetune_relative_positions(self, real_positions):

Suggested change
- def cal_and_buffer_finetune_relative_positions(self, real_positions):
+ def cal_and_cache_finetune_relative_positions(self, real_positions):

ngram_input_embed = self.ngram_input_embed.weight
if use_cache:
B = x.size(1)
B = hidden_states.size(1)

Don't use capitalized letters for variable names

ngram_mask_matrix = self.buffered_future_mask_ngram(hidden_states)
# TODO in train [(1+ngram)*T, B, C], in inference [T+ngram, B, C]
x = torch.cat([x] + ngram_masks, 0)
hidden_states = torch.cat([hidden_states] + ngram_masks, 0)

Suggested change
- hidden_states = torch.cat([hidden_states] + ngram_masks, 0)
+ hidden_states = torch.cat((hidden_states, *ngram_masks), 0)

@patrickvonplaten patrickvonplaten mentioned this pull request Sep 16, 2020
@qiweizhen qiweizhen closed this Sep 16, 2020

Labels

model card Related to pretrained model cards

5 participants