-
Notifications
You must be signed in to change notification settings - Fork 15
Open
Description
Hi,
I really enjoyed your work, and I was the first to cite it!
(https://arxiv.org/abs/2109.07628)
As I made some simple variants of your method for language models, could you please check this out?
I don't think it is substantial enough for a PR, so I opened an issue instead to request a quick check from the original authors.
Could you please check?
Thank you in advance.
Best,
Adam
# LSTM layer
class SubspaceLSTM(nn.LSTM):
    """nn.LSTM whose effective weights are sampled from a weight subspace.

    Subclasses must implement ``get_weight`` returning a dict that maps
    ``'<param_name>_mixed'`` to the blended tensor for every LSTM
    parameter of every layer (see ``LinesLSTM``).
    """

    def forward(self, x):
        # call get_weight, which samples from the subspace, then use the corresponding weight.
        weight_dict = self.get_weight()
        # Build a throwaway nn.LSTM with the same configuration and load
        # the mixed weights into it.
        # NOTE(review): constructing a fresh module on every forward is
        # expensive, and wrapping each mixed tensor in nn.Parameter creates
        # a new autograd leaf — gradients will NOT propagate back to the
        # endpoint parameters of the subspace. Confirm whether this module
        # is only used for evaluation before training through it.
        mixed_lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first
        )
        for l in range(self.num_layers):
            # Overwrite each registered parameter with its mixed counterpart.
            setattr(mixed_lstm, f'weight_hh_l{l}', nn.Parameter(weight_dict[f'weight_hh_l{l}_mixed']))
            setattr(mixed_lstm, f'weight_ih_l{l}', nn.Parameter(weight_dict[f'weight_ih_l{l}_mixed']))
            if self.bias:
                setattr(mixed_lstm, f'bias_hh_l{l}', nn.Parameter(weight_dict[f'bias_hh_l{l}_mixed']))
                setattr(mixed_lstm, f'bias_ih_l{l}', nn.Parameter(weight_dict[f'bias_ih_l{l}_mixed']))
        # Returns the (output, (h_n, c_n)) tuple produced by nn.LSTM.
        return mixed_lstm(x)
class TwoParamLSTM(SubspaceLSTM):
    """SubspaceLSTM with a second endpoint: for every registered LSTM
    weight/bias a zero-initialised twin is added under ``<name>_1``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        param_bases = ['weight_hh', 'weight_ih']
        if self.bias:
            param_bases += ['bias_hh', 'bias_ih']
        for layer in range(self.num_layers):
            for base in param_bases:
                # Twin has the same shape as the inherited parameter.
                source = getattr(self, f'{base}_l{layer}')
                self.register_parameter(
                    f'{base}_l{layer}_1',
                    nn.Parameter(torch.zeros_like(source)),
                )
class LinesLSTM(TwoParamLSTM):
    """Linear-interpolation subspace: each effective weight is the lerp of
    the two endpoint parameter sets at position ``self.alpha``."""

    def get_weight(self):
        # NOTE: ``self.alpha`` is expected to be set externally before use.
        param_bases = ['weight_hh', 'weight_ih']
        if self.bias:
            param_bases += ['bias_hh', 'bias_ih']
        mixed = dict()
        for layer in range(self.num_layers):
            for base in param_bases:
                endpoint_0 = getattr(self, f'{base}_l{layer}')
                endpoint_1 = getattr(self, f'{base}_l{layer}_1')
                mixed[f'{base}_l{layer}_mixed'] = (
                    (1 - self.alpha) * endpoint_0 + self.alpha * endpoint_1
                )
        return mixed
# Embedding layer
class SubspaceEmbedding(nn.Embedding):
    """nn.Embedding whose effective weight is sampled from a subspace.

    Subclasses must implement ``get_weight`` returning the effective
    embedding matrix (see ``LinesEmbedding``).
    """

    def forward(self, x):
        # Sample the effective embedding matrix from the subspace.
        w = self.get_weight()
        # Mirror nn.Embedding.forward: pass the module's embedding options
        # through to the functional op. The original call dropped them,
        # silently ignoring padding_idx / max_norm / norm_type /
        # scale_grad_by_freq / sparse settings given at construction.
        x = F.embedding(
            x,
            w,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return x
class TwoParamEmbedding(SubspaceEmbedding):
    """SubspaceEmbedding with a second, zero-initialised embedding matrix
    ``weight1`` acting as the other endpoint of the subspace."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the second endpoint with the same shape as ``weight``.
        self.register_parameter(
            'weight1', nn.Parameter(torch.zeros_like(self.weight))
        )
class LinesEmbedding(TwoParamEmbedding):
    """Linear interpolation between ``weight`` and ``weight1``: alpha = 0
    yields pure ``weight``, alpha = 1 yields pure ``weight1``."""

    def get_weight(self):
        # NOTE: ``self.alpha`` is expected to be set externally before use.
        alpha = self.alpha
        mixed = (1 - alpha) * self.weight + alpha * self.weight1
        return mixed
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels