Skip to content

Commit

Permalink
Sh/ep device (#360)
Browse files Browse the repository at this point in the history
* remove hardcoded device for environment provider

* added option for environment provider device in scripts

* black update

Co-authored-by: Zeyuan Tang <zeyuan.tang@phys.au.dk>
  • Loading branch information
stefaanhessmann and zyt0y authored Nov 10, 2021
1 parent 6722679 commit d42ee91
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def read(fname):
"tqdm",
"pyyaml",
],
extras_require={
"test": ["pytest", "pytest-console-scripts", "pytest-datadir"]
},
extras_require={"test": ["pytest", "pytest-console-scripts", "pytest-datadir"]},
license="MIT",
description="SchNetPack - Deep Neural Networks for Atomistic Systems",
long_description="""
Expand Down
12 changes: 12 additions & 0 deletions src/schnetpack/utils/script_utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def get_mode_parsers():
help="Number of checkpoints that will be stored (default: %(default)s)",
default=3,
)
train_parser.add_argument(
"--environment_provider_device",
help="Choose device for environment providers. It is recommended to keep CPU. (default: %(default)s)",
choices=["cpu", "cuda"],
default="cpu",
)

# evaluation parser
eval_parser = ArgumentParser(add_help=False)
Expand Down Expand Up @@ -173,6 +179,12 @@ def get_mode_parsers():
eval_parser.add_argument(
"--overwrite", help="Remove previous evaluation files", action="store_true"
)
eval_parser.add_argument(
"--environment_provider_device",
help="Choose device for environment providers. It is recommended to keep CPU. (default: %(default)s)",
choices=["cpu", "cuda"],
default="cpu",
)

return json_parser, train_parser, eval_parser

Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/utils/script_utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_environment_provider(args, device):
return spk.environment.AseEnvironmentProvider(cutoff=args.cutoff)
elif args.environment_provider == "torch":
return spk.environment.TorchEnvironmentProvider(
cutoff=args.cutoff, device="cpu"
cutoff=args.cutoff, device=device
)
else:
raise NotImplementedError
3 changes: 2 additions & 1 deletion src/scripts/spk_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def main(args):
train_args = setup_run(args)

device = torch.device("cuda" if args.cuda else "cpu")
ep_device = torch.device(args.environment_provider_device)

# get dataset
environment_provider = get_environment_provider(train_args, device=device)
environment_provider = get_environment_provider(train_args, device=ep_device)
dataset = get_dataset(train_args, environment_provider=environment_provider)

# get dataloaders
Expand Down

0 comments on commit d42ee91

Please sign in to comment.