From 9d0c38b10adc4164bd28ea397d28c96ecb8477d2 Mon Sep 17 00:00:00 2001 From: mufeili Date: Mon, 14 Oct 2024 22:03:54 -0400 Subject: [PATCH] update --- models.py | 55 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/models.py b/models.py index 2221dc3..4d26a46 100644 --- a/models.py +++ b/models.py @@ -207,27 +207,44 @@ def forward(self, x, edge_index, batch): x = global_mean_pool(x, batch) x = self.lin(x) return x - -class BidirectionalSAGEConv(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels): - super(BidirectionalSAGEConv, self).__init__() - - self.sage_forward = SAGEConv(in_channels, hidden_channels) - self.sage_forward_hidden = SAGEConv(hidden_channels, out_channels) - self.sage_backward = SAGEConv(in_channels, hidden_channels) - self.sage_backward_hidden = SAGEConv(hidden_channels, out_channels) - self.lin = Linear(out_channels, 1) +class BidirectionalSAGEConv(nn.Module): + def __init__(self, + in_channels, + out_channels): + super().__init__() + + self.conv_forward = SAGEConv(in_channels, out_channels) + self.conv_backward = SAGEConv(in_channels, out_channels) + + def forward(self, x, edge_index, reverse_edge_index): + x1 = self.conv_forward(x, edge_index) + x2 = self.conv_backward(x, reverse_edge_index) + x = (x1 + x2)/2 + x = F.relu(x) + + return x + +class BidirectionalSAGE(nn.Module): + def __init__(self, + in_channels, + hidden_channels, + out_channels, + num_layers): + super().__init__() + + self.convs = nn.ModuleList() + assert num_layers >= 1 + self.convs.append(BidirectionalSAGEConv(in_channels, hidden_channels)) + for _ in range(num_layers - 1): + self.convs.append(BidirectionalSAGEConv(hidden_channels, hidden_channels)) + self.pred = nn.Linear(hidden_channels, out_channels) + def forward(self, x, edge_index, batch): - x_forward = self.sage_forward(x, edge_index) - x_forward = self.sage_forward_hidden(x_forward, edge_index) reverse_edge_index = edge_index[[1, 0], :] - x_backward = self.sage_backward(x, reverse_edge_index) - x_backward = self.sage_backward_hidden(x_backward, reverse_edge_index) - x_bidirectional = (x_forward + x_backward) / 2.0 - - x = F.log_softmax(x_bidirectional, dim=1) + for conv in self.convs: + x = conv(x, edge_index, reverse_edge_index) x = global_mean_pool(x, batch) - x = self.lin(x) + x = self.pred(x) + return x -