Skip to content

Commit

Permalink
change a_star to astar in code (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma authored Oct 5, 2023
1 parent 9aa9bf8 commit 0be2ed6
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 16 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
if [ "${{ matrix.python-version }}" != "3.10" ]; then pip install mindspore==1.10.0; fi
- name: generate astar.so
run: |
python pygmtools/astar/get_astar.py
python pygmtools/c_astar/get_c_astar.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
pip install -r tests/requirements_win_mac.txt
- name: generate astar.so
run: |
python pygmtools/astar/get_astar.py
python pygmtools/c_astar/get_c_astar.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down Expand Up @@ -101,7 +101,7 @@ jobs:
python -m pip install -r tests\requirements_win_mac.txt
- name: generate astar.pyd
run: |
python pygmtools/astar/get_astar.py
python pygmtools/c_astar/get_c_astar.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cdef extern from "priority_queue.hpp":

@cython.boundscheck(False)
@cython.wraparound(False)
def a_star(
def c_astar(
data,
k,
vector[long] ns_1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from Cython.Build import cythonize
import numpy as np
from glob import glob

setup(
name='a-star function',
name='c_astar function',
ext_modules=cythonize(
Extension(
'a_star',
'c_astar',
glob('*.pyx'),
include_dirs=[np.get_include(),"."],
extra_compile_args=["-std=c++11"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import shutil

ori_dir = os.getcwd()
os.chdir('pygmtools/astar')
os.chdir('pygmtools/c_astar')

try:
os.system("python a_star_setup.py build_ext --inplace")
os.system("python c_astar_setup.py build_ext --inplace")
except:
os.system("python3 a_star_setup.py build_ext --inplace")
os.system("python3 c_astar_setup.py build_ext --inplace")

current_dir = os.getcwd()

Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions pygmtools/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .pytorch_astar_modules import GCNConv, AttentionModule, TensorNetworkModule, GraphPair, \
VERY_LARGE_INT, to_dense_adj, to_dense_batch, default_parameter, check_layer_parameter, node_metric
from torch import Tensor
from pygmtools.a_star import a_star
from pygmtools.c_astar import c_astar

#############################################
# Linear Assignment Problem Solvers #
Expand Down Expand Up @@ -930,10 +930,10 @@ def forward(self, data: GraphPair):
data.g1.nodes_num[i], data.g2.nodes_num[i])
num_nodes_1 = data.g1.nodes_num[i] + 1
num_nodes_2 = data.g2.nodes_num[i] + 1
x_pred[i][:num_nodes_1, :num_nodes_2] = self._a_star(cur_data)
x_pred[i][:num_nodes_1, :num_nodes_2] = self._astar(cur_data)
return x_pred[:, :-1, :-1]

def _a_star(self, data: GraphPair):
def _astar(self, data: GraphPair):

if self.args['cuda']:
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -995,7 +995,7 @@ def _a_star(self, data: GraphPair):

self.reset_cache()

x_pred, _ = a_star(
x_pred, _ = c_astar(
data, k, ns_1.cpu().numpy(), ns_2.cpu().numpy(),
self.net_prediction_cache,
self.heuristic_prediction_hun,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_property(prop, project):

class BdistWheelCommand(_bdist_wheel):
def run(self):
os.system("python pygmtools/astar/get_astar.py")
os.system("python pygmtools/c_astar/get_c_astar.py")
super().run()

def get_tag(self):
Expand All @@ -74,7 +74,7 @@ def get_tag(self):
class InstallCommand(_install):
def run(self):
try:
os.system("python pygmtools/astar/get_astar.py")
os.system("python pygmtools/c_astar/get_c_astar.py")
except:
pass
_install.run(self)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _test_classic_solver_on_linear_assignment(num_nodes1, num_nodes2, node_feat_
last_X = pygm.utils.to_numpy(_X)


# The testing function for a_star
# The testing function for astar
def _test_astar(graph_num_nodes, node_feat_dim, solver_func, matrix_params, backends):
if backends[0] != 'pytorch':
backends.insert(0, 'pytorch') # force pytorch as the reference backend
Expand Down

0 comments on commit 0be2ed6

Please sign in to comment.