Commit 5795e2b
rm redundant A-star code and backup links for datasets (#89)
heatingma authored Dec 4, 2023
1 parent 14894f9 commit 5795e2b
Showing 3 changed files with 30 additions and 21 deletions.
pygmtools/dataset.py: 20 additions & 12 deletions
@@ -476,7 +476,8 @@ def __init__(self, sets, obj_resize, **ds_dict):
         SPLIT_OFFSET = dataset_cfg.WillowObject.SPLIT_OFFSET
         TRAIN_SAME_AS_TEST = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
         RAND_OUTLIER = dataset_cfg.WillowObject.RAND_OUTLIER
-        URL = 'http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip'
+        URL = ['http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/WILLOW-ObjectClass_dataset.zip']
         if len(ds_dict.keys()) > 0:
             if 'CLASSES' in ds_dict.keys():
                 CLASSES = ds_dict['CLASSES']
@@ -750,6 +751,8 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
         COMB_CLS = dataset_cfg.SPair.COMB_CLS
         SIZE = dataset_cfg.SPair.SIZE
         ROOT_DIR = dataset_cfg.SPair.ROOT_DIR
+        URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/SPair-71k.tar.gz',
+               'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz']
         if len(ds_dict.keys()) > 0:
             if 'TRAIN_DIFF_PARAMS' in ds_dict.keys():
                 TRAIN_DIFF_PARAMS = ds_dict['TRAIN_DIFF_PARAMS']
@@ -775,14 +778,12 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
         )
 
         assert not problem == 'MGM', 'No match found for problem {} in SPair-71k'.format(problem)
+        self.dataset_dir = 'data/SPair-71k'
 
         if not os.path.exists(SPair71k_image_path):
             assert ROOT_DIR == dataset_cfg.SPair.ROOT_DIR, 'you should not change ROOT_DIR unless the data have been manually downloaded'
-            self.download(url='http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz')
-        if not os.path.exists(self.dataset_dir):
-            os.makedirs(self.dataset_dir)
+            self.download(url=URL)
 
-        self.dataset_dir = 'data/SPair-71k'
         self.obj_resize = obj_resize
         self.sets = sets_translation_dict[sets]
         self.ann_files = open(os.path.join(self.SPair71k_layout_path, self.SPair71k_dataset_size, self.sets + ".txt"), "r").read().split("\n")
@@ -815,20 +816,24 @@ def download(self, url=None, retries=5):
         if not os.path.exists(dirs):
             os.makedirs(dirs)
         print('Downloading dataset SPair-71k...')
-        filename = "data/SPair-71k.tgz"
+        filename = "data/SPair-71k.tar.gz"
         download(filename=filename, url=url, to_cache=False)
         try:
             tar = tarfile.open(filename, "r")
         except tarfile.ReadError as err:
             print('Warning: Content error. Retrying...\n', err)
             os.remove(filename)
             return self.download(url, retries - 1)
 
+        self.dataset_dir = 'data/SPair-71k'
+        if not os.path.exists(self.dataset_dir):
+            os.makedirs(self.dataset_dir)
+
         file_names = tar.getnames()
         print('Unzipping files...')
         sleep(0.5)
         for file_name in tqdm(file_names):
-            tar.extract(file_name, "data/")
+            tar.extract(file_name, self.dataset_dir)
         tar.close()
         try:
             os.remove(filename)
@@ -1018,7 +1023,9 @@ def __init__(self, sets, obj_resize, **ds_dict):
         CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES
         ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
         ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
-        URL = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1Po9pRMWXTqKK2ABPpVmkcsOq-6K_2v-B'
+        URL = ['https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/IMC-PT-SparseGM.tar.gz']
+
         if len(ds_dict.keys()) > 0:
             if 'MAX_KPT_NUM' in ds_dict.keys():
                 MAX_KPT_NUM = ds_dict['MAX_KPT_NUM']
@@ -1190,7 +1197,8 @@ class CUB2011:
     def __init__(self, sets, obj_resize, **ds_dict):
         CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT
         ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR
-        URL = 'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
+        URL = ['https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/CUB_200_2011.tar.gz']
         if len(ds_dict.keys()) > 0:
             if 'ROOT_DIR' in ds_dict.keys():
                 ROOT_DIR = ds_dict['ROOT_DIR']
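Each URL field above now holds an ordered list of mirrors (the original host plus a Hugging Face backup) instead of a single link. A minimal sketch of how a downloader might walk such a list, using a generic urllib fetch as a stand-in; pygmtools' real download helper (which also handles caching, MD5 checks, and Google Drive confirmation tokens) lives in pygmtools.utils and may differ:

import urllib.request

def download_with_fallback(urls, filename):
    # Accept a bare string so older single-URL configs keep working.
    if isinstance(urls, str):
        urls = [urls]
    last_err = None
    for url in urls:
        try:
            # Try the current mirror; return on the first success.
            urllib.request.urlretrieve(url, filename)
            return filename
        except Exception as err:  # e.g. HTTPError, URLError, timeout
            print('Download from {} failed ({}); trying next mirror...'.format(url, err))
            last_err = err
    raise RuntimeError('All mirrors failed for {}'.format(filename)) from last_err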
pygmtools/pytorch_astar_modules.py: 1 addition & 2 deletions
@@ -47,8 +47,7 @@ def check_layer_parameter(params):
     return True
 
 
-def node_metric(node1, node2):
-
+def node_metric(node1, node2):
     encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1)
     non_zero = torch.nonzero(encoding)
     for i in range(non_zero.shape[0]):
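For reference, node_metric builds a batched matrix of L1 distances between two sets of node features; a small usage sketch (the shapes below are illustrative assumptions, not taken from this repository):

import torch

node1 = torch.rand(2, 5, 16)  # (batch, n1, feat)
node2 = torch.rand(2, 4, 16)  # (batch, n2, feat)
# Broadcast to (batch, n1, n2, feat), then sum |differences| over the feature dim.
encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1)
print(encoding.shape)  # torch.Size([2, 5, 4]); entry [b, i, j] is the L1 distance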
pygmtools/pytorch_backend.py: 9 additions & 7 deletions
@@ -953,7 +953,6 @@ def _astar(self, data: GraphPair):
         ns_2 = torch.bincount(data.g2.batch)
 
         adj_1 = to_dense_adj(edge_index_1, batch=batch_1, edge_attr=edge_attr_1)
-
         dummy_adj_1 = torch.zeros(adj_1.shape[0], adj_1.shape[1] + 1, adj_1.shape[2] + 1, device=device)
         dummy_adj_1[:, :-1, :-1] = adj_1
         adj_2 = to_dense_adj(edge_index_2, batch=batch_2, edge_attr=edge_attr_2)
@@ -1055,12 +1054,15 @@ def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_no
         return ged
 
     def heuristic_prediction_hun(self, k: torch.Tensor, n1, n2, partial_pmat):
-        k_prime = k.reshape(-1, n1 + 1, n2 + 1)
-        node_costs = torch.empty(k_prime.shape[0])
-        for i in range(k_prime.shape[0]):
-            _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
-        node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
-        self.heuristic_cache['node_cost'] = node_cost_mat
+        if 'node_cost' in self.heuristic_cache:
+            node_cost_mat = self.heuristic_cache['node_cost']
+        else:
+            k_prime = k.reshape(-1, n1 + 1, n2 + 1)
+            node_costs = torch.empty(k_prime.shape[0])
+            for i in range(k_prime.shape[0]):
+                _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
+            node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
+            self.heuristic_cache['node_cost'] = node_cost_mat
 
         graph_1_mask = ~partial_pmat.sum(dim=-1).to(dtype=torch.bool)
         graph_2_mask = ~partial_pmat.sum(dim=-2).to(dtype=torch.bool)
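This change makes heuristic_prediction_hun compute the Hungarian node-cost matrix once and serve later calls from self.heuristic_cache, instead of re-running hungarian_ged on every A* expansion. A minimal sketch of the same compute-once pattern, with a cheap min() standing in for hungarian_ged and a per-graph-pair cache lifetime stated as an assumption rather than verified against the full source:

import torch

class CacheDemo:
    def __init__(self):
        self.heuristic_cache = {}  # assumed to be reset for each new graph pair

    def node_cost_mat(self, k: torch.Tensor, n1: int, n2: int) -> torch.Tensor:
        if 'node_cost' in self.heuristic_cache:
            return self.heuristic_cache['node_cost']  # cache hit: skip the expensive loop
        k_prime = k.reshape(-1, n1 + 1, n2 + 1)
        # Stand-in for hungarian_ged: the cheapest entry of each sub-matrix.
        node_costs = k_prime.flatten(1).min(dim=1).values
        mat = node_costs.reshape(n1 + 1, n2 + 1)
        self.heuristic_cache['node_cost'] = mat
        return mat

demo = CacheDemo()
k = torch.rand((3 * 4) ** 2)          # n1 = 2, n2 = 3 -> (n1+1)*(n2+1) = 12 sub-matrices
first = demo.node_cost_mat(k, 2, 3)   # computes and caches a (3, 4) matrix
second = demo.node_cost_mat(k, 2, 3)  # served from the cache
assert torch.equal(first, second)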
