
Simple variants for nn.Embedding and nn.LSTM #4

@vaseline555

Description


Hi,
I really enjoyed your work, and I believe I am the first to cite it!
(https://arxiv.org/abs/2109.07628)

I made some simple variants of your method for language models: an nn.Embedding and an nn.LSTM whose weights are mixed along a line between two endpoints, w(alpha) = (1 - alpha) * w_0 + alpha * w_1. I don't think this is substantial enough for a PR, so I am opening an issue instead to ask the original authors for a quick sanity check.

Could you please take a look at the code below?
Thank you in advance.

Best,
Adam


import torch
import torch.nn as nn
import torch.nn.functional as F


# LSTM layer
class SubspaceLSTM(nn.LSTM):
    def forward(self, x):
        # get_weight (implemented by a subclass) samples a weight from the subspace;
        # the sampled tensors are then plugged into a throwaway nn.LSTM for the forward.
        weight_dict = self.get_weight()
        mixed_lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bias=self.bias,
            batch_first=self.batch_first,
        )  # unidirectional only: reverse-direction weights are not mixed here
        for l in range(self.num_layers):
            names = [f'weight_hh_l{l}', f'weight_ih_l{l}']
            if self.bias:
                names += [f'bias_hh_l{l}', f'bias_ih_l{l}']
            for name in names:
                # Drop the freshly initialised nn.Parameter and assign the mixed
                # tensor directly: re-wrapping it in nn.Parameter would create a
                # new leaf and cut the graph, so no gradient would ever reach the
                # two endpoint weights.
                del mixed_lstm._parameters[name]
                setattr(mixed_lstm, name, weight_dict[f'{name}_mixed'])
        return mixed_lstm(x)

class TwoParamLSTM(SubspaceLSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # register a second endpoint (suffix `_1`) for every LSTM weight/bias;
        # initialised to zeros here, so re-initialise it as needed before training
        for l in range(self.num_layers):
            setattr(self, f'weight_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_hh_l{l}'))))
            setattr(self, f'weight_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_ih_l{l}'))))
            if self.bias:
                setattr(self, f'bias_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_hh_l{l}'))))
                setattr(self, f'bias_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_ih_l{l}'))))
                
class LinesLSTM(TwoParamLSTM):
    def get_weight(self):
        # linear interpolation between the two endpoints:
        # w(alpha) = (1 - alpha) * w_0 + alpha * w_1,
        # where self.alpha is expected to be set externally (e.g. sampled per batch)
        weight_dict = dict()
        for l in range(self.num_layers):
            weight_dict[f'weight_hh_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'weight_hh_l{l}') + self.alpha * getattr(self, f'weight_hh_l{l}_1')
            weight_dict[f'weight_ih_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'weight_ih_l{l}') + self.alpha * getattr(self, f'weight_ih_l{l}_1')
            if self.bias:
                weight_dict[f'bias_hh_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'bias_hh_l{l}') + self.alpha * getattr(self, f'bias_hh_l{l}_1')
                weight_dict[f'bias_ih_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'bias_ih_l{l}') + self.alpha * getattr(self, f'bias_ih_l{l}_1')
        return weight_dict
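
For reference, here is a minimal usage sketch (the shapes are made up, and I assume alpha is assigned on the module before each forward, e.g. sampled uniformly per batch):

# hypothetical sizes, just to sanity-check shapes
lstm = LinesLSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
lstm.alpha = 0.3                          # set externally before each forward
x = torch.randn(8, 16, 32)                # (batch, seq_len, input_size)
out, (h_n, c_n) = lstm(x)
print(out.shape)                          # torch.Size([8, 16, 64])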

# Embedding layer
class SubspaceEmbedding(nn.Embedding):
    def forward(self, x):
        # sample a weight from the subspace and embed with it,
        # forwarding the usual nn.Embedding options
        w = self.get_weight()
        return F.embedding(
            x,
            w,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

class TwoParamEmbedding(SubspaceEmbedding):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # second endpoint of the line; initialised to zeros, re-initialise as needed
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
                
class LinesEmbedding(TwoParamEmbedding):
    def get_weight(self):
        # w(alpha) = (1 - alpha) * w_0 + alpha * w_1, with alpha set externally
        return (1 - self.alpha) * self.weight + self.alpha * self.weight1
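
The same sketch for the embedding side; checking that both endpoints receive gradients is a quick way to confirm the mixing stays differentiable (again, sizes are made up):

# hypothetical sizes
emb = LinesEmbedding(num_embeddings=100, embedding_dim=16)
emb.alpha = 0.5                           # set externally, as for the LSTM
tokens = torch.randint(0, 100, (8, 16))
out = emb(tokens)                         # (8, 16, 16)
out.sum().backward()
print(emb.weight.grad is not None, emb.weight1.grad is not None)  # True True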
