Skip to content

分组注意力机制 #5

@prannt99

Description

@prannt99

学长您好，如果我想在您的基础上对 Value 也分组，那我这么写对吗？代码如下：

class CSA_Layer(nn.Module):
    """Cross-attention layer whose query, key and value are split into groups.

    Query and key are projected from ``position``; value is projected from
    ``feature``. The attended value is blended with the raw ``feature`` using
    two learnable, softmax-normalised scalar weights.

    Args:
        channels: Number of input/output channels of every 1x1 convolution.
        activate_function: Activation module, stored as ``self.act``.
        gp: Number of groups the channel axis is split into.

    Raises:
        ValueError: If ``gp`` is not positive or does not divide ``channels``.
    """

    def __init__(self, channels, activate_function, gp):
        super(CSA_Layer, self).__init__()
        if gp <= 0 or channels % gp != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive gp ({gp})"
            )
        # Raw (pre-softmax) mixing weights for attended output vs. input feature.
        self.w = nn.Parameter(torch.ones(2))
        self.gp = gp
        # Matrix multiplication could be used here to reduce the channel count.
        self.q_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels, 1, bias=False)

        self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
        # NOTE(review): trans_conv, after_norm and act are not used by forward();
        # they are kept so the module's attributes and state_dict keys stay the same.
        self.trans_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = activate_function
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature, position):
        """Attend over ``feature`` using attention weights derived from ``position``.

        Args:
            feature: Tensor of shape ``(bs, ch, nums)``; source of the value.
            position: Tensor of shape ``(bs, ch, nums)``; source of query and key.

        Returns:
            Tensor of shape ``(bs, ch, nums)``:
            ``w1 * attended_value + w2 * feature`` with ``(w1, w2) = softmax(self.w)``.
        """
        mix = torch.softmax(self.w, dim=0)
        w1, w2 = mix[0], mix[1]

        # Use q/k to strengthen v: to enhance the features, q/k should come from
        # the positions; to enhance the positions, q/k should come from the features.
        bs, ch, nums = position.size()
        group_ch = ch // self.gp
        # Split the channel axis into groups: (bs, gp, group_ch, nums).
        x_q = self.q_conv(position).reshape(bs, self.gp, group_ch, nums)
        x_k = self.k_conv(position).reshape(bs, self.gp, group_ch, nums)
        x_v = self.v_conv(feature).reshape(bs, self.gp, group_ch, nums)

        # NOTE(review): the energy below is an outer product computed separately
        # for every channel, so channels inside a group never interact and the
        # result does not depend on gp. Confirm whether a per-group
        # (nums x group_ch) @ (group_ch x nums) energy was intended instead.
        group_outputs = []
        for q_g, k_g, v_g in zip(x_q.unbind(1), x_k.unbind(1), x_v.unbind(1)):
            # q_g, k_g, v_g: (bs, group_ch, nums)
            # (bs, group_ch, nums, 1) @ (bs, group_ch, 1, nums) -> (bs, group_ch, nums, nums)
            energy = torch.matmul(q_g.unsqueeze(-1), k_g.unsqueeze(-2))
            attn = self.softmax(energy)
            # (bs, group_ch, 1, nums) @ (bs, group_ch, nums, nums) -> (bs, group_ch, nums)
            x_r = torch.matmul(v_g.unsqueeze(-2), attn).squeeze(-2)
            group_outputs.append(x_r)

        # A single concatenation keeps the result on the inputs' own device,
        # instead of forcing everything onto a hard-coded "cuda:0".
        output = torch.cat(group_outputs, dim=1)
        x = w1 * output + w2 * feature

        return x

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions