-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSCTGNet.py
More file actions
340 lines (274 loc) · 16.8 KB
/
SCTGNet.py
File metadata and controls
340 lines (274 loc) · 16.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv,global_mean_pool
from LTLEmbed.model import OSUG_SAGE
class TGOSUGNN(OSUG_SAGE):
    """Trace-generating GNN built on the OSUG_SAGE encoder.

    The encoder embeds the batched formula graph; a single shared-parameter
    SAGEConv decoder step is then applied once per trace position.  One MLP
    head predicts the truth value of every atomic proposition at each step,
    and a second MLP head scores the loop-start position of the lasso trace.
    """

    def __init__(self, device, in_channels, hidden_channels=256, num_layers=10, out_channels=2, embedding_size=128, mlp_hidden_channels=128, max_trace_len=5):
        """
        Args:
            device: device on which generate_trace allocates tensors.
            in_channels: number of node types.
            hidden_channels: hidden embedding dimension.
            num_layers: number of encoder layers.
            out_channels: output dimension of the per-atom classifier.
            embedding_size: dimension of the initial node embedding.
            mlp_hidden_channels: hidden width of both MLP heads.
            max_trace_len: fixed length of the generated trace.
        """
        super().__init__(in_channels, hidden_channels, num_layers, out_channels, embedding_size)
        self.device = device
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.out_channels = out_channels
        self.embedding_size = embedding_size
        self.mlp_hidden_channels = mlp_hidden_channels
        self.max_trace_len = max_trace_len
        # One decoder layer with shared parameters, reused at every trace
        # step so variable trace lengths need no extra weights.
        self.decoder_conv = SAGEConv(self.hidden_channels, self.hidden_channels)
        self.decoder_bn = nn.BatchNorm1d(self.hidden_channels)
        # Head predicting the truth value of each atomic proposition per state.
        self.predictor = self._make_head(self.hidden_channels, self.out_channels)
        # Head scoring the loop-start position from [s_t ; s_T] pairs.
        self.loop_checker = self._make_head(self.hidden_channels * 2, 1)

    def _make_head(self, in_dim, out_dim):
        """Build a two-layer MLP head: Linear -> BatchNorm -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, self.mlp_hidden_channels),
            nn.BatchNorm1d(self.mlp_hidden_channels),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_channels, out_dim),
        )

    def embed_init(self, x):
        """Look up the initial embedding for each node; x: [N, in_channels]."""
        return self.embedding(x)

    def embed_ver_1(self, h, edge_index):
        """First convolution stage: maps embeddings to the hidden width."""
        return F.relu(self.bn_layers_1(self.conv_layers_1(h, edge_index)))

    def embed_ver_2(self, h, edge_index):
        """Remaining convolution stages, each followed by BatchNorm + ReLU."""
        for layer, norm in zip(self.conv_layers_2, self.bn_layers_2):
            h = F.relu(norm(layer(h, edge_index)))
        return h

    def encode(self, x, edge_index):
        """Full encoder: initial embedding, then both convolution stages."""
        return self.embed_ver_2(self.embed_ver_1(self.embed_init(x), edge_index), edge_index)

    def decode(self, h, edge_index):
        """One shared-parameter decoder step: SAGEConv -> BatchNorm -> ReLU."""
        return F.relu(self.decoder_bn(self.decoder_conv(h, edge_index)))

    def read_out(self, h, example_mark):
        """Mean-pool node vectors into one state vector per example."""
        return global_mean_pool(h, example_mark)

    def generate_trace(self, x_batch, edge_index_batch, G2T_index_batch, example_mark_batch, atom_mask_batch, batch_size, max_num_var):
        """Decode a fixed-length trace for every formula graph in the batch.

        Pipeline:
          1. h = encode(x)                      -- multi-layer encoder over the batched graph
          2. for each of max_trace_len steps:
               h    = decode(h)                 -- shared-parameter decoder step
               s_t  = read_out(h)               -- graph-level state vector of step t
               atom = gather atom-node vectors  -- padding indices gather a zero row
               y_t  = P(atom true), masked      -- softmax column 1 of the predictor
          3. loop = softmax over positions of loop_checker([s_t ; s_T])

        Returns:
            state_sequence_batch: (batch_size, max_trace_len, max_num_var)
                atom-truth probabilities, zeroed on padded atom slots.
            loop_batch: (batch_size, max_trace_len, 1) loop-start distribution.
        """
        h = self.encode(x_batch, edge_index_batch)
        # Extra zero row appended before gathering so padding entries of
        # G2T_index_batch map to an all-zero vector.
        pad_row = torch.zeros(1, self.hidden_channels, dtype=torch.float, device=self.device)
        state_vecs = []  # graph-level state vector per trace step
        atom_probs = []  # masked per-atom truth probabilities per trace step
        for _ in range(self.max_trace_len):
            h = self.decode(h, edge_index_batch)
            state_vecs.append(self.read_out(h, example_mark_batch))
            h_atom = torch.cat((h, pad_row), dim=0)[G2T_index_batch]
            logits = self.predictor(h_atom.view(-1, self.hidden_channels))
            p_true = torch.softmax(logits, dim=1)[:, 1].view(batch_size, max_num_var)
            atom_probs.append(p_true * atom_mask_batch)  # zero out padded atom slots
        state_sequence_batch = torch.stack(atom_probs, dim=1)
        # Score every position against the final state, then normalise the
        # scores over trace positions.
        s_last = state_vecs[-1]
        loop_scores = [self.loop_checker(torch.cat((s_t, s_last), dim=1)) for s_t in state_vecs]
        loop_batch = torch.softmax(torch.stack(loop_scores, dim=1), dim=1)
        return state_sequence_batch, loop_batch
class TGOSUGNN2RNN(OSUG_SAGE):
    """Trace-generating model: OSUG_SAGE graph encoder + GRU trace decoders.

    The encoder embeds the formula graph; two 4-layer GRUs then unroll the
    trace: one over per-atom node vectors (feeding the truth-value
    classifier), one over the pooled graph vector (feeding the loop-start
    scorer).
    """
    def __init__(self, device, in_channels, hidden_channels=256, num_layers=10, out_channels=2, embedding_size=128, mlp_hidden_channels=128, max_trace_len=5):
        '''
        Args:
            device: device on which generate_trace allocates tensors.
            in_channels: number of node types.
            hidden_channels: hidden embedding dimension.
            num_layers: number of encoder layers.
            out_channels: output dimension of the per-atom classifier.
            embedding_size: dimension of the initial node embedding.
            mlp_hidden_channels: hidden width of both MLP heads.
            max_trace_len: fixed length of the generated trace.
        '''
        super().__init__(in_channels, hidden_channels, num_layers, out_channels, embedding_size)
        self.device = device
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.out_channels = out_channels
        self.embedding_size = embedding_size
        self.mlp_hidden_channels = mlp_hidden_channels
        self.max_trace_len = max_trace_len
        # Trace decoders: two independent 4-layer GRUs, unrolled for
        # max_trace_len steps in generate_trace.
        self.decoder_state_sequence = nn.GRU(input_size=self.hidden_channels,hidden_size=self.hidden_channels,num_layers=4,batch_first=True)
        self.decoder_loop = nn.GRU(input_size=self.hidden_channels,hidden_size=self.hidden_channels,num_layers=4,batch_first=True)
        # Head predicting the truth value of each atomic proposition per state.
        self.predictor=nn.Sequential(
            nn.Linear(self.hidden_channels, self.mlp_hidden_channels),
            nn.BatchNorm1d(self.mlp_hidden_channels),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_channels, self.out_channels)
        )
        # Head scoring the loop-start position within the trace.
        self.loop_checker=nn.Sequential(
            nn.Linear(self.hidden_channels * 2, self.mlp_hidden_channels),
            nn.BatchNorm1d(self.mlp_hidden_channels),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_channels, 1)
        )
    def embed_init(self, x):
        # Initial node embedding lookup; x shape [N, in_channels].
        return self.embedding(x)
    def embed_ver_1(self, h, edge_index):
        # First convolution stage; h shape [N, embedding_size].
        h = self.conv_layers_1(h, edge_index)
        h = self.bn_layers_1(h)
        h = F.relu(h)
        return h
    def embed_ver_2(self, h, edge_index):
        # Remaining convolution stages; h shape [N, hidden_channels].
        for conv, bn in zip(self.conv_layers_2, self.bn_layers_2):
            h = conv(h, edge_index)
            h = bn(h)
            h = F.relu(h)
        return h
    def encode(self, x, edge_index):
        # Full encoder: embedding lookup followed by both conv stages.
        h = self.embed_init(x)
        h = self.embed_ver_1(h, edge_index)
        h = self.embed_ver_2(h, edge_index)
        return h
    def decode(self, h, s):
        # NOTE(review): dead/broken code -- self.decoder and self.decoder_bn
        # are never defined on this class (only decoder_state_sequence and
        # decoder_loop exist), so calling this raises AttributeError.
        # generate_trace does not use it; it appears to be left over from an
        # earlier GNN-decoder design (compare TGOSUGNN.decode). Confirm before
        # removing.
        h,s = self.decoder(h.view(-1,self.hidden_channels), s.view(-1,self.hidden_channels))
        h = self.decoder_bn(h)
        h = F.relu(h)
        return h
    def read_out(self, h, example_mark):
        # Mean-pool node vectors into one vector per example.
        return global_mean_pool(h, example_mark)
    def generate_trace(self, x_batch, edge_index_batch, G2T_index_batch, example_mark_batch, atom_mask_batch, batch_size, max_num_var):
        '''
        Model logic:
          1. h = encode(x)                 # multi-layer OSUG encoder; h in R^{num_vertex x hidden}
          2. h_atom = get_atom(h)          # atom-node vectors; h_atom in R^{batch x num_atom x hidden}
          3. for t in 1..max_trace_len:
             a. s_t, history = GRU(h_atom, history)  # per-step state; s_t in R^{batch x num_atom x hidden}
             b. y_t = classifier(s_t)      # atom truth values; y_t in (0,1)^{batch x num_atom}
          4. y = [y_1, ..., y_T]           # y in (0,1)^{batch x trace_len x num_atom}
          5. loop = softmax(MLP_2([mean(s_i), mean(s_T)]))  # loop-start distribution over T positions
          6. return y, loop
        '''
        h = self.encode(x_batch, edge_index_batch)
        # Append one zero row so padding indices in G2T_index_batch gather an
        # all-zero vector; h_atom then has one row per (example, atom) slot.
        h_atom = torch.cat((h, torch.zeros(size=(1, h.shape[1]),dtype=torch.float,device=self.device)), dim=0)[G2T_index_batch] # h_atom shape [batch_size * num_atom, hidden]
        # Feed the same atom vector at every time step; the GRU output yields
        # one hidden state per trace position.
        s, _ = self.decoder_state_sequence(h_atom.view(-1,self.hidden_channels).unsqueeze(1).repeat(1, self.max_trace_len, 1)) # s shape [batch_size * num_atom, max_trace_len, hidden]
        # P(atom true) = softmax column 1 of the classifier logits; reshape
        # back to (batch, num_atom, trace_len) then reorder to (batch, trace_len, num_atom).
        state_sequence_batch = torch.softmax(self.predictor(s.reshape(-1,self.hidden_channels)), dim=1)[:,1].view(batch_size,max_num_var,self.max_trace_len).permute(0, 2, 1) # state_sequence_batch shape [batch_size, max_trace_len, num_atom]
        state_sequence_batch = torch.mul(state_sequence_batch, atom_mask_batch.unsqueeze(1)) # zero out padded atom slots
        h_G = self.read_out(h, example_mark_batch) # pooled graph vector, shape [batch_size, hidden]
        # Second GRU unrolls the pooled graph vector into per-step state vectors.
        s, _ = self.decoder_loop(h_G.unsqueeze(1).repeat(1, self.max_trace_len, 1)) # s shape [batch_size, max_trace_len, hidden]
        loop_batch = torch.zeros(size=(batch_size,self.max_trace_len,1),dtype=torch.float,device=self.device) # loop_batch in R^{batch_size x trace_len x 1}
        for i in range(self.max_trace_len):
            # Score position i against the final state s_T.
            loop_batch[:,i]=self.loop_checker(torch.cat((s[:,i],s[:,-1]),dim=1))
        loop_batch = torch.softmax(loop_batch,dim=1)
        return state_sequence_batch, loop_batch
class SCTGOSUGNN(TGOSUGNN):
    """TGOSUGNN variant that additionally predicts formula satisfiability.

    Identical trace decoding to TGOSUGNN, plus an MLP classifier over the
    pooled formula embedding; generate_trace also returns that embedding so
    check_satisfiability can classify it.
    """

    def __init__(self, device, in_channels, hidden_channels=256, num_layers=10, out_channels=2, embedding_size=128, mlp_hidden_channels=128, max_trace_len=5):
        """
        Args:
            device: device on which generate_trace allocates tensors.
            in_channels: number of node types.
            hidden_channels: hidden embedding dimension.
            num_layers: number of encoder layers.
            out_channels: output dimension of the classifier heads.
            embedding_size: dimension of the initial node embedding.
            mlp_hidden_channels: hidden width of the MLP heads.
            max_trace_len: fixed length of the generated trace.
        """
        super().__init__(device, in_channels, hidden_channels, num_layers, out_channels, embedding_size, mlp_hidden_channels, max_trace_len)
        # Satisfiability classifier over the pooled formula embedding.
        self.sc_classifier = nn.Sequential(
            nn.Linear(self.hidden_channels, self.mlp_hidden_channels),
            nn.BatchNorm1d(self.mlp_hidden_channels),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_channels, self.out_channels),
        )

    def generate_trace(self, x_batch, edge_index_batch, G2T_index_batch, example_mark_batch, atom_mask_batch, batch_size, max_num_var):
        """Decode a fixed-length trace exactly as TGOSUGNN does, additionally
        returning the formula embedding pooled from the encoder output.

        Returns:
            state_sequence_batch: (batch_size, max_trace_len, max_num_var)
                atom-truth probabilities, zeroed on padded atom slots.
            loop_batch: (batch_size, max_trace_len, 1) loop-start distribution.
            f_batch: (batch_size, hidden_channels) formula embedding, taken
                from the encoder output before any decoder step.
        """
        h = self.encode(x_batch, edge_index_batch)
        # Formula embedding: pooled encoder output, before decoding begins.
        f_batch = self.read_out(h, example_mark_batch)
        # Extra zero row so padding entries of G2T_index_batch gather zeros.
        pad_row = torch.zeros(1, self.hidden_channels, dtype=torch.float, device=self.device)
        state_vecs = []  # graph-level state vector per trace step
        atom_probs = []  # masked per-atom truth probabilities per trace step
        for _ in range(self.max_trace_len):
            h = self.decode(h, edge_index_batch)
            state_vecs.append(self.read_out(h, example_mark_batch))
            h_atom = torch.cat((h, pad_row), dim=0)[G2T_index_batch]
            logits = self.predictor(h_atom.view(-1, self.hidden_channels))
            p_true = torch.softmax(logits, dim=1)[:, 1].view(batch_size, max_num_var)
            atom_probs.append(p_true * atom_mask_batch)  # zero out padded atom slots
        state_sequence_batch = torch.stack(atom_probs, dim=1)
        # Score every position against the final state, normalise over positions.
        s_last = state_vecs[-1]
        loop_scores = [self.loop_checker(torch.cat((s_t, s_last), dim=1)) for s_t in state_vecs]
        loop_batch = torch.softmax(torch.stack(loop_scores, dim=1), dim=1)
        return state_sequence_batch, loop_batch, f_batch

    def check_satisfiability(self, f_batch):
        """Classify satisfiability from the pooled formula embedding f_batch."""
        return self.sc_classifier(f_batch)