Skip to content

Commit

Permalink
delete redundant parameters for astar (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma authored Aug 8, 2023
1 parent b0bf957 commit b6e4dd1
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 27 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Now the pygmtools is available with the ``numpy`` backend.
The following packages are required, and shall be automatically installed by ``pip``:

```
Python >= 3.7
Python >= 3.8
requests >= 2.25.1
scipy >= 1.4.1
Pillow >= 7.2.0
Expand All @@ -64,6 +64,7 @@ easydict >= 1.7
appdirs >= 1.4.4
tqdm >= 4.64.1
wget >= 3.2
networkx >= 2.8.8
```

## Available Graph Matching Solvers
Expand Down
8 changes: 2 additions & 6 deletions pygmtools/classic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,8 +1062,7 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
return result


def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, dropout=0, beam_width=0,
trust_fact=1, no_pred_size=0, backend=None):
def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, beam_width=0, backend=None):
r"""
A\* (A-star) solver for graph matching (Lawler's QAP).
The **A\*** solver was originally proposed to solve the graph edit distance (GED) problem. It finds the optimal
Expand All @@ -1081,10 +1080,7 @@ def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, dropout=0, beam_
:param channel: (default: None) Channel size of the input layer. If given, it must match the feature dimension (d) of feat1, feat2.
If not given, it will be defined by the feature dimension (d) of feat1, feat2.
Ignored if the network object isgiven (ignored if network!=None)
:param dropout: (default: 0) Dropout probability
:param beam_width: (default: 0) Size of beam-search witdh (0 = no beam).
:param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN).
:param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics.
:param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
:return: :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix
Expand Down Expand Up @@ -1188,7 +1184,7 @@ def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, dropout=0, beam_
if n1 is not None: _check_data_type(n1, 'n1', backend)
if n2 is not None: _check_data_type(n2, 'n2', backend)

args = (feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size)
args = (feat1, feat2, A1, A2, n1, n2, channel, beam_width)
try:
mod = importlib.import_module(f'pygmtools.{backend}_backend')
fn = mod.astar
Expand Down
5 changes: 2 additions & 3 deletions pygmtools/neural_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ def ngm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,


def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=64, filters_2=32, filters_3=16,
tensor_neurons=16, dropout=0, beam_width=0, trust_fact=1, no_pred_size=0,
tensor_neurons=16, beam_width=0, trust_fact=1, no_pred_size=0,
network=None, return_network=False, pretrain='AIDS700nef', backend=None):
r"""
The **GENN-A\*** (Graph Edit Neural Network A\*) solver for graph matching (and graph edit distance)
Expand Down Expand Up @@ -1306,7 +1306,6 @@ def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=6
:param filters_2: (default: 32) Filters (neurons) in 2nd convolution.
:param filters_3: (default: 16) Filters (neurons) in 2nd convolution.
:param tensor_neurons: (default: 16) Neurons in tensor network layer.
:param dropout: (default: 0) Dropout probability
:param beam_width: (default: 0) Size of beam-search witdh (0 = no beam).
:param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN).
:param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics.
Expand Down Expand Up @@ -1472,7 +1471,7 @@ def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=6
if n2 is not None: _check_data_type(n2, 'n2', backend)

args = (feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain)
tensor_neurons, beam_width, trust_fact, no_pred_size, network, pretrain)
try:
mod = importlib.import_module(f'pygmtools.{backend}_backend')
fn = mod.genn_astar
Expand Down
1 change: 0 additions & 1 deletion pygmtools/pytorch_astar_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def default_parameter():
params['filters_2'] = 32
params['filters_3'] = 16
params['tensor_neurons'] = 16
params['dropout'] = 0
params['astar_beam_width'] = 0
params['astar_trust_fact'] = 1
params['astar_no_pred'] = 0
Expand Down
18 changes: 8 additions & 10 deletions pygmtools/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def __init__(self, args):
:param number_of_labels: Number of node labels.
"""
super(GENN, self).__init__()
self.training = False
self.args = args
if self.args['use_net']:
self.number_labels = self.args['channel']
Expand Down Expand Up @@ -909,10 +910,8 @@ def convolutional_pass(self, edge_index, x, edge_weight=None):

features = self.convolution_1(edge_index, x, edge_weight)
features = F.relu(features)
features = F.dropout(features, p=self.args['dropout'], training=self.training)
features = self.convolution_2(edge_index, features, edge_weight)
features = F.relu(features)
features = F.dropout(features, p=self.args['dropout'], training=self.training)
features = self.convolution_3(edge_index, features, edge_weight)
return features

Expand Down Expand Up @@ -1098,17 +1097,17 @@ def hungarian_ged(node_cost_mat: torch.Tensor, n1, n2):
return pred_x, ged_lower_bound


def astar(feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size):
def astar(feat1, feat2, A1, A2, n1, n2, channel, beam_width):
"""
Pytorch implementation of ASTAR
"""
return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, dropout=dropout, beam_width=beam_width,
filters_1=64, filters_2=32, filters_3=16, tensor_neurons=16, trust_fact=trust_fact,
no_pred_size=no_pred_size, pretrain=False, network=None, use_net=False)
return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, beam_width=beam_width,
filters_1=64, filters_2=32, filters_3=16, tensor_neurons=16, trust_fact=1.0,
no_pred_size=0, pretrain=False, network=None, use_net=False)


def astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net):
tensor_neurons, beam_width, trust_fact, no_pred_size, network, pretrain, use_net):
"""
The true implementation of astar and genn_astar functions
"""
Expand Down Expand Up @@ -1139,7 +1138,6 @@ def astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, fi
args['filters_2'] = filters_2
args['filters_3'] = filters_3
args['tensor_neurons'] = tensor_neurons
args['dropout'] = dropout
args['astar_beam_width'] = beam_width
args['astar_trust_fact'] = trust_fact
args['astar_no_pred'] = no_pred_size
Expand Down Expand Up @@ -1534,12 +1532,12 @@ def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau,


def genn_astar(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain):
tensor_neurons, beam_width, trust_fact, no_pred_size, network, pretrain):
"""
Pytorch implementation of GENN-ASTAR
"""
return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net=True)
tensor_neurons, beam_width, trust_fact, no_pred_size, network, pretrain, use_net=True)


#############################################
Expand Down
8 changes: 2 additions & 6 deletions tests/test_classic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,16 +455,12 @@ def test_astar(get_backend):
backends = get_backends(get_backend)
# heuristic_prediction
args1 = (list(range(10, 16, 2)), 10, pygm.astar,{
"beam_width": [0, 1, 2],
"trust_fact": [0.9, 0.95, 1.0],
"no_pred_size": [0, 1],
"beam_width": [0, 1, 2]
}, backends)

# non-batched input
args2 = ([10], 10, pygm.astar,{
"beam_width": [0, 1, 2],
"trust_fact": [0.9, 0.95, 1.0],
"no_pred_size": [0, 1],
"beam_width": [0, 1, 2]
}, backends)

_test_astar(*args1)
Expand Down

0 comments on commit b6e4dd1

Please sign in to comment.