-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
学长您好,如果我想在您的基础上对Value也分组,那我这么写对吗
class CSA_Layer(nn.Module):
    """Channel-grouped self-attention layer.

    Splits Q/K/V into ``gp`` channel groups, computes per-group attention
    where Q/K come from ``position`` and V comes from ``feature``, then
    fuses the attended result with the original feature via two learnable
    softmax-normalized weights.

    Args:
        channels: number of input/output channels; must be divisible by ``gp``.
        activate_function: activation module stored for later use (not applied
            in ``forward`` as currently written).
        gp: number of channel groups.
    """

    def __init__(self, channels, activate_function, gp):
        super(CSA_Layer, self).__init__()
        # Two learnable scalars, softmax-normalized in forward() to fuse
        # the attention output with the residual feature.
        self.w = nn.Parameter(torch.ones(2))
        self.gp = gp
        # NOTE (original author): matrix multiplication could be used here
        # to reduce the channel count.
        self.q_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.trans_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = activate_function
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature, position):
        """Attend ``feature`` (V) using ``position`` (Q, K), per channel group.

        Args:
            feature: tensor of shape (batch, channels, num_points).
            position: tensor of the same shape as ``feature``.

        Returns:
            Tensor of shape (batch, channels, num_points):
            ``w1 * attention_output + w2 * feature``.
        """
        # Softmax-normalized fusion weights (w1 + w2 == 1).
        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        # QK is built from position to enhance the feature; to enhance the
        # position instead, QK would be built from the feature (author note).
        bs, ch, nums = position.size()
        # (bs, gp, nums, ch//gp)
        x_q = self.q_conv(position).reshape(bs, self.gp, ch // self.gp, nums).permute(0, 1, 3, 2)
        x_q_list = torch.chunk(x_q, dim=1, chunks=self.gp)
        # (bs, gp, ch//gp, nums)
        x_k = self.k_conv(position).reshape(bs, self.gp, ch // self.gp, nums)
        x_k_list = torch.chunk(x_k, dim=1, chunks=self.gp)
        x_v = self.v_conv(feature).reshape(bs, self.gp, ch // self.gp, nums)
        x_v_list = torch.chunk(x_v, dim=1, chunks=self.gp)
        # Fix: follow the input's device instead of a hardcoded "cuda:0",
        # and collect group outputs in a list (single cat at the end) rather
        # than concatenating inside the loop.
        group_outputs = []
        for i in range(self.gp):
            # (bs, ch//gp, nums, 1) @ (bs, ch//gp, 1, nums) -> (bs, ch//gp, nums, nums)
            energy = torch.matmul(x_q_list[i].permute(0, 3, 2, 1), x_k_list[i].permute(0, 2, 1, 3))
            attn = self.softmax(energy)
            # (bs, ch//gp, 1, nums) @ (bs, ch//gp, nums, nums) -> squeeze -> (bs, ch//gp, nums)
            x_r = torch.matmul(x_v_list[i].permute(0, 2, 1, 3), attn).squeeze(2)
            group_outputs.append(x_r)
        output = torch.cat(group_outputs, dim=1)
        # Weighted residual fusion of attention output and original feature.
        x = w1 * output + w2 * feature
        return x
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels