
Commit 92d9fd5
remove debugging lines
wzever committed Aug 13, 2023
1 parent 468b9d2 commit 92d9fd5
Showing 3 changed files with 3 additions and 14 deletions.
13 changes: 1 addition & 12 deletions pygmtools/mindspore_backend.py
@@ -223,19 +223,14 @@ def rrwm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
    """
    mindspore implementation of RRWM algorithm.
    """
-   import time
-   lp1 = sk = 0
-   t1 = time.time()
    batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
    # rescale the values in K
    d = K.sum(axis=2, keepdims=True)
    dmax = d.max(axis=1, keepdims=True)
    K = K / (dmax + d.min() * 1e-5)
    v = v0
-   t2 = time.time()
    for i in range(max_iter):
        # random walk
-       t3 = time.time()
        v = mindspore.ops.BatchMatMul()(K, v)
        last_v = v
        n = mindspore.ops.norm(v, axis=1, p=1, keep_dims=True)
@@ -244,20 +239,14 @@ def rrwm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
        # reweighted jump
        s = v.view(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
        s = beta * s / s.max(axis=1, keepdims=True).max(axis=2, keepdims=True)
-       t4 = time.time()
-       lp1 += (t4 - t3)
-       # print(n1, n2)
        v = alpha * sinkhorn(s, n1, n2, max_iter=sk_iter, batched_operation=True).swapaxes(1, 2).reshape(batch_num, n1n2, 1) + \
            (1 - alpha) * v
-       t5 = time.time()
-       # print(s.shape)
-       sk += (t5 - t4)
        n = mindspore.ops.norm(v, axis=1, p=1, keep_dims=True)
        v = mindspore.ops.matmul(v, 1 / n)

        if (v - last_v).sum().sqrt() < 1e-5:
            break
-   # print(f'pre:{t2-t1:.4f}, lp1:{lp1:.4f}, sk:{sk:.4f}')

    return v.view(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
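For orientation, the loop above is the classic reweighted-random-walk update. Below is a minimal single-pair NumPy sketch of the same iteration; sinkhorn_normalize and rrwm_sketch are illustrative names rather than pygmtools API, and the batched mindspore code above additionally masks the per-graph sizes n1, n2:

import numpy as np

def sinkhorn_normalize(S, n_iter=10):
    # Alternate row/column normalization toward a doubly stochastic matrix.
    for _ in range(n_iter):
        S = S / S.sum(axis=1, keepdims=True)
        S = S / S.sum(axis=0, keepdims=True)
    return S

def rrwm_sketch(K, n, alpha=0.2, beta=30.0, max_iter=50, tol=1e-5):
    # K: (n*n, n*n) non-negative affinity matrix of one graph pair;
    # returns an (n, n) soft assignment matrix.
    v = np.full((n * n, 1), 1.0 / (n * n))   # uniform initial assignment
    d = K.sum(axis=1, keepdims=True)         # rescale K as in the code above
    K = K / (d.max() + d.min() * 1e-5)
    for _ in range(max_iter):
        last_v = v
        v = K @ v                             # random-walk step
        v = v / v.sum()                       # L1 normalization (v stays non-negative)
        S = beta * v.reshape(n, n) / v.max()  # temperature scaling before Sinkhorn
        # Reweighted jump: blend the Sinkhorn projection with the plain walk.
        v = alpha * sinkhorn_normalize(np.exp(S)).reshape(-1, 1) + (1 - alpha) * v
        v = v / v.sum()
        if np.sqrt(np.abs(v - last_v).sum()) < tol:
            break
    return v.reshape(n, n)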


2 changes: 1 addition & 1 deletion pygmtools/paddle_backend.py
@@ -309,7 +309,7 @@ def _check_and_init_gm(K, n1, n2, n1max, n2max, x0):
        n1max = paddle.max(n1)
    if n2max is None:
        n2max = paddle.max(n2)

    assert n1max * n2max == n1n2, 'the input size of K does not match with n1max * n2max!'

    # initialize x0 (also v0)
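The assert shown above guards the shape convention of the affinity matrix: for a batch of graph pairs padded to n1max and n2max nodes, K must be of shape (batch, n1max * n2max, n1max * n2max). A small illustration (variable names are ours, not pygmtools API):

import numpy as np

batch_size, n1max, n2max = 4, 5, 6
n1n2 = n1max * n2max
K = np.zeros((batch_size, n1n2, n1n2))  # vectorized pairwise affinities
assert n1max * n2max == K.shape[1], 'the input size of K does not match with n1max * n2max!'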
2 changes: 1 addition & 1 deletion pygmtools/pytorch_backend.py
@@ -20,7 +20,7 @@
from .pytorch_astar_modules import GCNConv, AttentionModule, TensorNetworkModule, GraphPair, \
    VERY_LARGE_INT, to_dense_adj, to_dense_batch, default_parameter, check_layer_parameter, node_metric
from torch import Tensor
-# from pygmtools.a_star import a_star
+from pygmtools.a_star import a_star

#############################################
# Linear Assignment Problem Solvers #
