Skip to content

Commit

Permalink
improve printing protocol for several featurizers using fortran
Browse files Browse the repository at this point in the history
  • Loading branch information
Qi-max committed Oct 19, 2019
1 parent bf15da0 commit f258ac9
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 221 deletions.
33 changes: 15 additions & 18 deletions amlearn/featurize/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class BaseNN(six.with_metaclass(ABCMeta, BaseEstimator, TransformerMixin)):
def __init__(self, cutoff=5, allow_neighbor_limit=300, n_neighbor_limit=80,
type_col='type', coords_cols=None, pbc=None, bds=None,
save=True, backend=None, output_path=None,
output_file_prefix=None):
output_file_prefix=None, print_freq=1000):
self.cutoff = cutoff
self.allow_neighbor_limit = allow_neighbor_limit
self.n_neighbor_limit = n_neighbor_limit
Expand All @@ -48,6 +48,7 @@ def __init__(self, cutoff=5, allow_neighbor_limit=300, n_neighbor_limit=80,
self.backend = backend if backend is not None \
else create_featurizer_backend(output_path=output_path)
self.output_file_prefix = output_file_prefix
self.print_freq = print_freq

def fit_transform(self, X=None, y=None, **fit_params):
return self.transform(X)
Expand All @@ -65,13 +66,13 @@ def __init__(self, small_face_thres=0.05, cutoff=5,
allow_neighbor_limit=300, n_neighbor_limit=80,
type_col='type', coords_cols=None, pbc=None,
bds=None, save=True, backend=None, output_path=None,
output_file_prefix='voro_nn'):
output_file_prefix='voro_nn', print_freq=1000):
super(VoroNN, self).__init__(
cutoff=cutoff, allow_neighbor_limit=allow_neighbor_limit,
n_neighbor_limit=n_neighbor_limit, type_col=type_col,
coords_cols=coords_cols, pbc=pbc, bds=bds, save=save,
backend=backend, output_path=output_path,
output_file_prefix=output_file_prefix)
output_file_prefix=output_file_prefix, print_freq=print_freq)
self.small_face_thres = small_face_thres

def transform(self, X=None):
Expand All @@ -98,21 +99,18 @@ def transform(self, X=None):
neighbor_dist_lists = \
np.zeros((n_atoms, self.n_neighbor_limit), dtype=np.longdouble)

n_edge_max = 0
n_neighbor_max = 0

neighbor_num_list, neighbor_id_lists, neighbor_area_lists, \
neighbor_vol_lists, neighbor_dist_lists, neighbor_edge_lists, \
n_neighbor_max, n_edge_max = \
(neighbor_num_list, neighbor_id_lists, neighbor_area_lists,
neighbor_vol_lists, neighbor_dist_lists, neighbor_edge_lists)= \
voronoi_nn.voronoi(X[self.type_col].values,
X[self.coords_cols].values,
self.cutoff, self.allow_neighbor_limit,
self.small_face_thres, self.pbc, self.bds,
neighbor_num_list, neighbor_id_lists,
neighbor_area_lists, neighbor_vol_lists,
neighbor_dist_lists, neighbor_edge_lists,
n_neighbor_max, n_edge_max, n_atoms=n_atoms,
n_neighbor_limit=self.n_neighbor_limit)
n_atoms=n_atoms,
n_neighbor_limit=self.n_neighbor_limit,
print_freq=self.print_freq)

voro_props = list()
for neighbor_num, neighbor_id_list, \
Expand Down Expand Up @@ -146,13 +144,13 @@ def __init__(self, cutoff=4, allow_neighbor_limit=300,
n_neighbor_limit=80, type_col='type',
coords_cols=None, pbc=None, bds=None,
backend=None, save=True, output_path=None,
output_file_prefix='dist_nn'):
output_file_prefix='dist_nn', print_freq=1000):
super(DistanceNN, self).__init__(
cutoff=cutoff, allow_neighbor_limit=allow_neighbor_limit,
n_neighbor_limit=n_neighbor_limit, type_col=type_col,
coords_cols=coords_cols, pbc=pbc, bds=bds, save=save,
backend=backend, output_path=output_path,
output_file_prefix=output_file_prefix)
output_file_prefix=output_file_prefix, print_freq=print_freq)

def transform(self, X=None):
"""
Expand All @@ -172,15 +170,14 @@ def transform(self, X=None):
neighbor_dist_lists = \
np.zeros((n_atoms, self.n_neighbor_limit), dtype=np.longdouble)

n_neighbor_max = 0
(n_neighbor_max, neighbor_num_list, neighbor_id_lists,
neighbor_dist_lists) = \
(neighbor_num_list, neighbor_id_lists, neighbor_dist_lists) = \
distance_nn.distance_neighbor(
X[self.type_col].values, X[self.coords_cols].values,
self.cutoff, self.allow_neighbor_limit, self.pbc,
self.bds, n_neighbor_max, neighbor_num_list,
self.bds, neighbor_num_list,
neighbor_id_lists, neighbor_dist_lists,
n_atoms=n_atoms, n_neighbor_limit=self.n_neighbor_limit)
n_atoms=n_atoms, n_neighbor_limit=self.n_neighbor_limit,
print_freq=self.print_freq)

dist_props = list()
for neighbor_num, neighbor_id_list, neighbor_dist_list in \
Expand Down
Loading

0 comments on commit f258ac9

Please sign in to comment.