hgcond_main.py
import os
import argparse
import time
import numpy as np
import torch
from graphsyn import evalue_hgcond, hgcond
from utils_data import get_data
from utils import getsize_mb
#%% Command-line parameters
argparser = argparse.ArgumentParser("HGCond",
                                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument("--dataset", type=str, default='imdb')  # 'dblp', 'imdb', 'acm', 'AMiner', 'freebase'
argparser.add_argument("--cond", type=float, default=0.001)  # condensation rate
argparser.add_argument("--basicmodel", type=str, default='HeteroSGC')  # 'HeteroSGC' or 'HeteroGCN'
argparser.add_argument("--para-init", type=str, default='orth')  # parameter initialization: 'orth' or 'rand'
argparser.add_argument("--feat-init", type=str, default='cluster')  # feature initialization: 'cluster' or 'sample'
args = argparser.parse_args()
dataset = args.dataset
cond_rate = args.cond
basicmodel = args.basicmodel
feat_init = args.feat_init
para_init = args.para_init
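# Example invocation (all flags defined above, with their default-style values):
#   python hgcond_main.py --dataset dblp --cond 0.001 --basicmodel HeteroGCN --feat-init cluster --para-init orth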
#%% Fixed parameters
model_architecture = {'hidden_channels': 64,
                      'num_layers': 3}
model_train = {'epochs': 1000,
               'lr': 0.01,
               'weight_decay': 0.0005}
cond_train = {'epochs_deep': 3,
              'epochs_basic_model': 1,
              'lr': 0.01,
              'lr_basic_model': 0.1}
# Initialization epochs scale with the hidden width: 2 * 64 = 128
cond_train['epochs_initial'] = model_architecture['hidden_channels'] * 2
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#%% Output directory
dir_name = './synthetic_graphs'
os.makedirs(dir_name, exist_ok=True)
#%% Data
data = get_data(name=dataset)
data = data.to(device)
#%% Train
info = (f'Dataset: {dataset}\nCond-rate: {cond_rate}\n'
        f'Feature initialization: {feat_init}\nParameter initialization: {para_init}\n'
        f'Basic model: {basicmodel}\n')
print(info)
print('Original dataset information:')
print(data.info)
print('Start condensing ...')
time_start = time.time()
graphsyner, losses_log = hgcond(data, cond_rate, feat_init, para_init,
                                basicmodel, model_architecture, cond_train)
time_end = time.time()
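# Retrieve the best synthetic graph found during condensation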
y_syn = graphsyner.y_syn
mask_syn = graphsyner.mask_syn
x_syn_dict = graphsyner.best_x_syn_dict
adj_t_syn_dict = graphsyner.best_adj_t_syn_dict
save_path = f'{dir_name}/{dataset}-{cond_rate}.cond'
torch.save((x_syn_dict, adj_t_syn_dict, y_syn, mask_syn), save_path)
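# The saved tuple can be reloaded later in the same order, e.g.:
#   x_syn_dict, adj_t_syn_dict, y_syn, mask_syn = torch.load(save_path)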
print(f'Condensation finished in {time_end - time_start:.2f}s')
print(f'The condensed graph is saved as {save_path}')
print(graphsyner)
origin_storage = getsize_mb([data.x_dict, data.adj_t_dict,
                             data[data.target_node_type].y,
                             data[data.target_node_type].train_mask])
condensed_storage = getsize_mb([x_syn_dict, adj_t_syn_dict, y_syn, mask_syn])
print(f'Origin graph: {origin_storage:.2f} MB  Condensed graph: {condensed_storage:.2f} MB')
#%% Evaluation
print('Train on the synthetic graph and test on the real graph')
# x_syn_dict, adj_t_syn_dict, y_syn, mask_syn = torch.load('./synthetic_graphs/dblp-0.001-cluster-orth-1.cond')
accs, f1_micros, f1_macros = evalue_hgcond(1, data, x_syn_dict, adj_t_syn_dict, y_syn, mask_syn,
                                           basicmodel, model_architecture, model_train)
mean = np.mean(accs) * 100
std = np.std(accs) * 100
print(f'\nAccuracy: {mean:.2f}±{std:.2f}')
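# evalue_hgcond also returns micro/macro F1 scores (arrays returned above);
# a minimal summary sketch following the same convention as the accuracy report:
f1_micro_mean = np.mean(f1_micros) * 100
f1_macro_mean = np.mean(f1_macros) * 100
print(f'F1-micro: {f1_micro_mean:.2f}  F1-macro: {f1_macro_mean:.2f}')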