Commit ab9bef9: bugfixes

stefanch authored and committed on Feb 11, 2020
1 parent b1952fb

Showing 9 changed files with 194 additions and 92 deletions.
scripts/sgdml_dataset_from_extxyz.py (4 changes: 3 additions & 1 deletion)

@@ -123,7 +123,7 @@ def read_nonstd_ext_xyz(f):
+ ' Dataset \'{}\' already exists.'.format(dataset_file_name)
)

- mols = read(dataset, index=':')
+ mols = read(dataset.name, index=':')

lattice, R, z, E, F = None, None, None, None, None

@@ -163,6 +163,8 @@ def read_nonstd_ext_xyz(f):

if 'Energy' in mols[0].info:
E = np.array([mol.info['Energy'] for mol in mols])
+ if 'energy' in mols[0].info:
+     E = np.array([mol.info['energy'] for mol in mols])
F = np.array([mol.get_forces() for mol in mols])

else: # legacy non-standard XYZ format
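Both fixes above are small but easy to miss: `dataset` is presumably an open file handle (e.g. from `argparse.FileType`), hence `dataset.name` is what `ase.io.read` needs, and extended-XYZ writers disagree on the capitalization of the energy key. A minimal sketch of the resulting read logic, assuming a hypothetical input file `dataset.xyz`:

    import numpy as np
    from ase.io import read

    mols = read('dataset.xyz', index=':')  # one ase.Atoms object per frame

    # Extended-XYZ comment-line fields land in atoms.info; generators differ
    # on 'Energy' vs. 'energy', so both spellings are checked.
    key = 'Energy' if 'Energy' in mols[0].info else 'energy'
    E = np.array([mol.info[key] for mol in mols])
    F = np.array([mol.get_forces() for mol in mols])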
setup.py (3 changes: 2 additions & 1 deletion)

@@ -1,5 +1,6 @@
import os
import re
+ from io import open
from setuptools import setup, find_packages


@@ -14,7 +15,7 @@ def get_property(property, package):
from os import path

this_dir = path.abspath(path.dirname(__file__))
- with open(path.join(this_dir, 'README.md')) as f:
+ with open(path.join(this_dir, 'README.md'), encoding='utf8') as f:
long_description = f.read()

# Scripts
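For context, the two setup.py changes combine into the following pattern; `io.open` backports Python 3's encoding-aware `open()` to Python 2, so a README containing non-ASCII characters reads cleanly under both interpreters:

    from io import open  # no-op on Python 3, encoding-aware open() on Python 2
    from os import path

    this_dir = path.abspath(path.dirname(__file__))
    with open(path.join(this_dir, 'README.md'), encoding='utf8') as f:
        long_description = f.read()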
sgdml/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

- __version__ = '0.4.0.dev2'
+ __version__ = '0.4.1.dev1'

MAX_PRINT_WIDTH = 100
LOG_LEVELNAME_WIDTH = 7 # do not modify
sgdml/cli.py (43 changes: 30 additions & 13 deletions)

@@ -116,10 +116,10 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'):
n_mols, n_atoms, _ = dataset['R'].shape
print(
' {:<18} {} ({:<d} atoms)'.format(
-         'Name:', dataset['name'].astype(str), n_atoms
+         'Name:', ui.unicode_str(dataset['name']), n_atoms
)
)
- print(' {:<18} {}'.format('Theory:', dataset['theory']))
+ print(' {:<18} {}'.format('Theory:', ui.unicode_str(dataset['theory'])))
print(' {:<18} {:,} data points'.format('Size:', n_mols))

ui.print_lattice(dataset['lattice'] if 'lattice' in dataset else None)
@@ -128,7 +128,7 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'):

e_unit = 'unknown unit'
if 'e_unit' in dataset:
- e_unit = dataset['e_unit']
+ e_unit = ui.unicode_str(dataset['e_unit'])

print(' Energies [{}]:'.format(e_unit))
if 'E_min' in dataset and 'E_max' in dataset:
@@ -148,7 +148,9 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'):

f_unit = 'unknown unit'
if 'r_unit' in dataset and 'e_unit' in dataset:
- f_unit = str(dataset['e_unit']) + '/' + str(dataset['r_unit'])
+ f_unit = (
+     ui.unicode_str(dataset['e_unit']) + '/' + ui.unicode_str(dataset['r_unit'])
+ )

print(' Forces [{}]:'.format(f_unit))

@@ -165,7 +167,7 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'):
F_var = dataset['F_var'] if 'F_var' in dataset else np.var(dataset['F'].ravel())
print(' {:<16} {:<.3f}'.format('Variance:', F_var))

- print(' {:<18} {}'.format('Fingerprint:', dataset['md5'].astype(str)))
+ print(' {:<18} {}'.format('Fingerprint:', ui.unicode_str(dataset['md5'])))

idx = np.random.choice(n_mols, 1)[0]
r = dataset['R'][idx, :, :]
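The recurring change in this file swaps `.astype(str)` for `ui.unicode_str(...)` when printing fields loaded from an .npz archive. The helper itself is not part of this diff; a plausible sketch of what it does, assuming the stored values are 0-d numpy arrays that may hold bytes under Python 3:

    import numpy as np

    def unicode_str(x):
        # Hypothetical sketch of sgdml.ui.unicode_str: coerce an .npz scalar
        # (0-d array, bytes, or str) to a native unicode string.
        if isinstance(x, np.ndarray) and x.ndim == 0:
            x = x.item()
        if isinstance(x, bytes):
            return x.decode('utf-8')
        return str(x)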
@@ -228,7 +230,7 @@ def _print_model_properties(model, title_str='Model properties'):

print(ui.white_bold_str(title_str))

- print(' {:<18} {}'.format('Dataset:', model['dataset_name'].astype(str)))
+ print(' {:<18} {}'.format('Dataset:', ui.unicode_str(model['dataset_name'])))

n_atoms = len(model['z'])
print(' {:<18} {:<d}'.format('Atoms:', n_atoms))
@@ -286,11 +288,11 @@ def _print_model_properties(model, title_str='Model properties'):
f_unit = 'unknown unit'
if 'r_unit' in model and 'e_unit' in model:
e_unit = model['e_unit']
- f_unit = str(model['e_unit']) + '/' + str(model['r_unit'])
+ f_unit = ui.unicode_str(model['e_unit']) + '/' + ui.unicode_str(model['r_unit'])

if is_valid:
action_str = 'Validation' if not is_valid else 'Expected test'
- print(' {:<18}'.format('{} errors:'.format(action_str)))
+ print(' {:<18}'.format('{} errors (MAE/RMSE):'.format(action_str)))
if model['use_E']:
print(
' {:<16} {:>.4f}/{:>.4f} [{}]'.format(
@@ -503,6 +505,7 @@ def create( # noqa: C901
use_cprsn=use_cprsn,
use_E=use_E,
use_E_cstr=use_E_cstr,
+     model0=model0,
)

task_file_names = []
@@ -580,6 +583,8 @@ def create( # noqa: C901
if model0 is not None:
model0_path, model0 = model0

+ shutil.copy(model0_path, os.path.join(task_dir, 'm0.npz'))

try:
tmpl_task = gdml_train.create_task(
dataset,
@@ -652,6 +657,13 @@ def cprsn_callback(n_atoms, n_atoms_kept):
ker_progr_callback = partial(ui.progr_bar, disp_str='Assembling kernel matrix...')
solve_callback = partial(ui.progr_toggle, disp_str='Solving linear system... ')

+ def save_progr_callback(
+     unconv_model
+ ):  # saves current (unconverged) model during iterative training
+     unconv_model_file = '_unconv_model.npz'
+     unconv_model_path = os.path.join(task_dir, unconv_model_file)
+     np.savez_compressed(unconv_model_path, **unconv_model)

gdml_train = GDMLTrain(max_processes=max_processes, use_torch=use_torch)
for i, task_file_name in enumerate(task_file_names):
if n_tasks > 1:
@@ -688,6 +700,7 @@ def cprsn_callback(n_atoms, n_atoms_kept):
desc_callback,
ker_progr_callback,
solve_callback,
+     save_progr_callback,
)
except Exception as err:
print()
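The new `save_progr_callback` gives the iterative solver a crash-safe checkpoint: every call overwrites `_unconv_model.npz` in the task directory via `np.savez_compressed`. A sketch of how a caller could pick such a checkpoint back up; only the file name and location come from the diff, the resume logic is illustrative:

    import os
    import numpy as np

    ckpt_path = os.path.join(task_dir, '_unconv_model.npz')
    if os.path.isfile(ckpt_path):
        with np.load(ckpt_path, allow_pickle=True) as data:
            unconv_model = dict(data)  # materialize arrays before the file closes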
@@ -877,11 +890,15 @@ def test(
gdml = GDMLTrain(max_processes=max_processes)

# exclude training and/or test sets from validation set if necessary
- excl_idxs = np.empty((0,), dtype=int)
+ excl_idxs = np.empty((0,), dtype=np.uint)
if dataset['md5'] == model['md5_train']:
- excl_idxs = np.concatenate([excl_idxs, model['idxs_train']])
+ excl_idxs = np.concatenate([excl_idxs, model['idxs_train']]).astype(
+     np.uint
+ )
if dataset['md5'] == model['md5_valid']:
- excl_idxs = np.concatenate([excl_idxs, model['idxs_valid']])
+ excl_idxs = np.concatenate([excl_idxs, model['idxs_valid']]).astype(
+     np.uint
+ )
if len(excl_idxs) == 0:
excl_idxs = None
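The switch from `dtype=int` to `dtype=np.uint` (plus the explicit `.astype`) guards against NumPy's promotion rules silently changing the dtype when the empty placeholder is concatenated with the stored index arrays. A small illustration with made-up indices:

    import numpy as np

    excl_idxs = np.empty((0,), dtype=np.uint)          # uint64 on 64-bit platforms
    idxs_train = np.array([3, 7, 42], dtype=np.int64)  # hypothetical stored indices

    print(np.concatenate([excl_idxs, idxs_train]).dtype)
    # float64: uint64 and int64 promote to float, useless as an index array
    print(np.concatenate([excl_idxs, idxs_train]).astype(np.uint).dtype)
    # uint64, as intended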

@@ -1502,14 +1519,14 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
'--cg',
dest='use_cg',
action='store_true',
- #help='use iterative solver (conjugate gradient) with Nystroem preconditioner',
+ # help='use iterative solver (conjugate gradient) with Nystroem preconditioner',
help=argparse.SUPPRESS
)
group.add_argument(
'--fk',
dest='use_fk',
action='store_true',
- #help='use iterative solver (conjugate gradient) with Nystroem approximation',
+ # help='use iterative solver (conjugate gradient) with Nystroem approximation',
help=argparse.SUPPRESS
)

sgdml/intf/ase_calc.py (3 changes: 2 additions & 1 deletion)

@@ -52,7 +52,7 @@ def __init__(
Path to a sGDML model file
E_to_eV : float, optional
Conversion factor from whatever energy unit is used by the model to eV. By default this parameter is set to convert from kcal/mol.
- use_torch : boolean, optional
+ F_to_eV_Ang : boolean, optional
Conversion factor from whatever length unit is used by the model to Angstrom. By default, the length unit is not converted.
"""

@@ -62,6 +62,7 @@

model = np.load(model_path)
self.gdml_predict = GDMLPredict(model)
+ # self.gdml_predict.prepare_parallel()

self.log.warning(
'Please remember to specify the proper conversion factors, if your model does not use \'kcal/mol\' and \'Ang\' as units.'
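For reference, typical usage of the calculator this file defines; the class name `SGDMLCalculator` and the file names follow the sGDML documentation and are not shown in this diff:

    from ase.io import read
    from sgdml.intf.ase_calc import SGDMLCalculator

    mol = read('geometry.xyz')
    mol.set_calculator(SGDMLCalculator('model.npz'))  # assumes kcal/mol and Ang by default

    print(mol.get_potential_energy())  # eV
    print(mol.get_forces())            # eV/Ang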
sgdml/predict.py (35 changes: 14 additions & 21 deletions)
mode change 100755 → 100644

@@ -248,14 +248,6 @@ def __init__(
else None
)

- # from packaging import version
- # print(version.parse('2.3.2dev0') > version.parse('2.3.1dev1'))
-
- # print(model['code_version'])
- # legacy support
- # if version.parse(model['code_version']) <= version.parse('0.3.5.dev5'):
- #     model['R_desc'] = model['R_desc'].T
-
self.n_train = model['R_desc'].shape[1]
sig = model['sig']

@@ -307,13 +299,15 @@ def __init__(
self.torch_device
)

+ # enable data parallelism
+ n_gpu = torch.cuda.device_count()
+ if n_gpu > 1:
+     self.torch_predict = torch.nn.DataParallel(self.torch_predict)
+     # self.torch_predict.to(self.torch_device)  # needed?
+
is_cuda = next(self.torch_predict.parameters()).is_cuda
if is_cuda:
- self.log.info(
-     'Numbers of CUDA devices found: {:d}'.format(
-         torch.cuda.device_count()
-     )
- )
+ self.log.info('Numbers of CUDA devices found: {:d}'.format(n_gpu))
else:
self.log.warning(
'No CUDA devices found! PyTorch is running on the CPU.'
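This hunk moves multi-GPU wrapping into the constructor, so `torch.nn.DataParallel` is applied once instead of on every `predict()` call (see the matching deletion further down), which would otherwise nest wrapper upon wrapper. A generic sketch of the pattern, not sGDML code:

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):  # stand-in for the wrapped prediction module
        def __init__(self):
            super(TinyNet, self).__init__()
            self.lin = nn.Linear(8, 1)

        def forward(self, x):
            return self.lin(x)

    net = TinyNet()
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)  # wrap exactly once, at construction time
    if torch.cuda.is_available():
        net = net.cuda()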
@@ -632,9 +626,7 @@ def _dummy_predict():

i_done += 1

- gps = (
-     n_bulk * n_reps / (timeit.timeit(_dummy_predict, number=n_reps))
- )
+ gps = n_bulk * n_reps / timeit.timeit(_dummy_predict, number=n_reps)

# print(
# '{:2d}@{:d} {:d} | {:7.2f} gps'.format(
@@ -873,6 +865,12 @@ def _predict_bulk_train(self, R_desc, R_d_desc):
# F = res[1:].reshape(1, -1).dot(r_d_desc)
# return res[1:].reshape(1, -1)

+ if self._num_workers == 1:  # HACK
+     self.log.critical(
+         'Bulk (train) predictions are not possible with just one process (not implemented).'
+     )
+     sys.exit()

if self._bulk_mp is False: # HACK!
self._set_bulk_mp(True)

@@ -987,11 +985,6 @@ def predict(self, r):

Rs = torch.from_numpy(r.reshape(M, -1, 3)).to(self.torch_device)

- # enable data parallelism
- n_gpu = torch.cuda.device_count()
- if n_gpu > 1:
-     self.torch_predict = torch.nn.DataParallel(self.torch_predict)

e_pred, f_pred = self.torch_predict.forward(Rs)

E = e_pred.cpu().numpy()
sgdml/torchtools.py (70 changes: 54 additions & 16 deletions)

@@ -47,7 +47,7 @@ class GDMLTorchPredict(nn.Module):
:class:`torch.nn.Module`. Contains no trainable parameters.
"""

- def __init__(self, model, lat_and_inv=None, batch_size=None, max_memory=1.0):
+ def __init__(self, model, lat_and_inv=None, batch_size=None, max_memory=4.0):
"""
Parameters
----------
@@ -77,25 +77,44 @@ def __init__(self, model, lat_and_inv=None, batch_size=None, max_memory=1.0):
)

self._batch_size = batch_size
+ # self._max_memory = int(2 ** 30 * max_memory) if max_memory is not None else torch.cuda.get_device_properties(0).total_memory
self._max_memory = int(2 ** 30 * max_memory)
self._sig = int(model['sig'])
self._c = float(model['c'])
self._std = float(model.get('std', 1))

desc_siz = model['R_desc'].shape[0]
n_perms, self._n_atoms = model['perms'].shape
- perm_idxs = torch.tensor(model['tril_perms_lin']).view(-1, n_perms).t()
+ perm_idxs = (
+     torch.tensor(model['tril_perms_lin'], device=self._dev)
+     .view(-1, n_perms)
+     .t()
+ )
self._xs_train, self._Jx_alphas = (
nn.Parameter(
xs.repeat(1, n_perms)[:, perm_idxs].reshape(-1, desc_siz),
requires_grad=False,
)
for xs in (
-     torch.tensor(model['R_desc']).t(),
-     torch.tensor(np.array(model['R_d_desc_alpha'])),
+     torch.tensor(model['R_desc'], device=self._dev).t(),
+     torch.tensor(np.array(model['R_d_desc_alpha']), device=self._dev),
)
)

+ # DEBUG
+ # cuda_check = self._xs_train.is_cuda
+ # if cuda_check:
+ #     get_cuda_device = self._xs_train.get_device()
+ #     print('_xs_train')
+ #     print(get_cuda_device)
+
+ # DEBUG
+ # cuda_check = self._Jx_alphas.is_cuda
+ # if cuda_check:
+ #     get_cuda_device = self._Jx_alphas.get_device()
+ #     print('self._Jx_alphas')
+ #     print(get_cuda_device)

self.desc_siz = desc_siz
self.perm_idxs = perm_idxs
self.n_perms = n_perms
@@ -105,7 +124,7 @@ def set_alphas(self, R_d_desc, alphas):
r_dim = R_d_desc.shape[2]
R_d_desc_alpha = np.einsum('kji,ki->kj', R_d_desc, alphas.reshape(-1, r_dim))

- xs = torch.tensor(np.array(R_d_desc_alpha)).to(self._dev)
+ xs = torch.tensor(np.array(R_d_desc_alpha), device=self._dev)  # .to(self._dev)
self._Jx_alphas = nn.Parameter(
xs.repeat(1, self.n_perms)[:, self.perm_idxs].reshape(-1, self.desc_siz),
requires_grad=False,
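The pattern in this hunk, constructing tensors with `device=` instead of building on the CPU and calling `.to()`, avoids an intermediate host allocation and copy. A minimal before/after sketch:

    import numpy as np
    import torch

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    arr = np.random.rand(1000, 1000)

    xs = torch.tensor(arr).to(dev)      # before: CPU tensor first, then a device copy
    xs = torch.tensor(arr, device=dev)  # after: one allocation on the target device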
@@ -174,18 +193,37 @@ def forward(self, Rs, batch_size=None, max_memory=None):
Rs = Rs.double()
batch_size = self._batch_size or self._max_memory // self._memory_per_sample()

- # print('batch_size')
- # print(batch_size)
- # print(Rs.shape)
+ if torch.cuda.is_available():
+     batch_size *= torch.cuda.device_count()
+
+ try:
+     Es, Fs = zip(*map(self._forward, DataLoader(Rs, batch_size=batch_size)))

+ except RuntimeError as e:
+     if 'out of memory' in str(e):
+
+         print('NOTE: ran out of memory, but retrying!')
+
+         if batch_size > 2:
+
+             import gc
+             gc.collect()
+
+             torch.cuda.empty_cache()
+
+             # reverse batch multiplication
+             if torch.cuda.is_available():
+                 batch_size /= torch.cuda.device_count()

- # if torch.cuda.is_available():
- #     t = torch.cuda.get_device_properties(0).total_memory
- #     c = torch.cuda.memory_cached(0)
- #     a = torch.cuda.memory_allocated(0)
- #     f = c-a # free inside cache
- #     print('free memory')
- #     print(f)

+             batch_size = int(batch_size / 0.1)
+             self._batch_size = batch_size

- Es, Fs = zip(*map(self._forward, DataLoader(Rs, batch_size=batch_size)))
+             return self.forward(Rs, batch_size=self._batch_size)
+         else:
+             print('ERROR: ran out of memory, FAILED!')
+             sys.exit()
+     else:
+         raise e
+
+ # Es, Fs = zip(*map(self._forward, DataLoader(Rs, batch_size=batch_size)))
return torch.cat(Es).to(dtype), torch.cat(Fs).to(dtype)
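The except branch above implements a catch-and-retry loop for CUDA out-of-memory errors: free cached blocks, shrink the batch, recurse. A generalized, self-contained sketch of that strategy (halving the batch instead of the diff's rescaling, purely illustrative):

    import gc
    import torch

    def run_with_oom_retry(fn, batch_size, min_batch_size=2):
        while True:
            try:
                return fn(batch_size)
            except RuntimeError as e:
                if 'out of memory' not in str(e) or batch_size <= min_batch_size:
                    raise
                gc.collect()
                torch.cuda.empty_cache()  # return cached blocks to the allocator
                batch_size //= 2          # shrink the batch and try again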
