-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathpretrain.py
More file actions
19 lines (18 loc) · 1.12 KB
/
pretrain.py
File metadata and controls
19 lines (18 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import argparse
from pretrain_utils import pretrain_graph_adapter
def build_parser():
    """Create the command-line parser for the graph-adapter pretraining script.

    Every flag keeps the name, type, default and allowed choices of the
    original inline definition. Building the parser in its own function lets
    it be imported and inspected without starting a pretraining run.

    Returns:
        argparse.ArgumentParser: parser with all pretraining options registered.
    """
    parser = argparse.ArgumentParser(description='graph adapter')

    # NOTE(review): presumably an accelerator index -- confirm in pretrain_utils.
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument(
        '--dataset_name',
        type=str,
        default='instagram',
        choices=['arxiv', 'instagram', 'reddit'],
        help='dataset to be used',
    )

    # The options below are forwarded untouched to pretrain_graph_adapter;
    # how each one is consumed is decided there, not in this script.
    parser.add_argument('--max_epoch', type=int, default=50)
    parser.add_argument('--hiddensize_gnn', type=int, default=128)
    parser.add_argument('--hiddensize_fusion', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--learning_ratio', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-3)
    parser.add_argument('--num_warmup_steps', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=64)

    # Plain string literal: the original used an f-string prefix here even
    # though the path contains no placeholders. The value is unchanged.
    parser.add_argument(
        '--lm_head_path',
        type=str,
        default='./pretrain_models/head/lm_head.pkl',
    )
    parser.add_argument(
        '--GNN_type',
        type=str,
        default='SAGE',
        choices=['SAGE', 'GAT', 'MLP'],
    )
    return parser


def main(argv=None):
    """Parse command-line options and hand them to ``pretrain_graph_adapter``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``, which matches
            the behavior of the original script.
    """
    args = build_parser().parse_args(argv)
    pretrain_graph_adapter(args)


if __name__ == "__main__":
    main()