Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 30, 2024
1 parent d2054b6 commit db0e577
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 23 deletions.
22 changes: 16 additions & 6 deletions dpgen2/exploration/task/lmp/lmp_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,17 @@ def make_lmp_input(
graph_list = ""
for ii in graphs:
graph_list += ii + " "
model_devi_file_name = lmp_pimd_model_devi_name % pimd_bead if pimd_bead is not None else lmp_model_devi_name
model_devi_file_name = (
lmp_pimd_model_devi_name % pimd_bead
if pimd_bead is not None
else lmp_model_devi_name
)
if Version(deepmd_version) < Version("1"):
# 0.x
ret += "pair_style deepmd %s ${THERMO_FREQ} %s\n" % (graph_list, model_devi_file_name)
ret += "pair_style deepmd %s ${THERMO_FREQ} %s\n" % (

Check warning on line 111 in dpgen2/exploration/task/lmp/lmp_input.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/exploration/task/lmp/lmp_input.py#L111

Added line #L111 was not covered by tests
graph_list,
model_devi_file_name,
)
else:
# 1.x
keywords = ""
Expand All @@ -118,9 +125,10 @@ def make_lmp_input(
keywords += "fparam ${ELE_TEMP}"
if ele_temp_a is not None:
keywords += "aparam ${ELE_TEMP}"
ret += (
"pair_style deepmd %s out_freq ${THERMO_FREQ} out_file %s %s\n"
% (graph_list, model_devi_file_name, keywords)
ret += "pair_style deepmd %s out_freq ${THERMO_FREQ} out_file %s %s\n" % (
graph_list,
model_devi_file_name,
keywords,
)
ret += "pair_coeff * *\n"
ret += "\n"
Expand All @@ -129,7 +137,9 @@ def make_lmp_input(
if trj_seperate_files:
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n"
else:
lmp_traj_file_name = lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name
lmp_traj_file_name = (
lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name
)
ret += (
"dump 1 all custom ${DUMP_FREQ} %s id type x y z fx fy fz\n"
% lmp_traj_file_name
Expand Down
34 changes: 22 additions & 12 deletions dpgen2/exploration/task/lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def set_lmp(
self.extra_pair_style_args,
self.pimd_bead,
)
self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq, self.pimd_bead)
self.lmp_template = revise_lmp_input_dump(
self.lmp_template, self.traj_freq, self.pimd_bead
)
if plm_template_fname is not None:
self.plm_template = Path(plm_template_fname).read_text().split("\n")
self.plm_set = True
Expand Down Expand Up @@ -150,28 +152,36 @@ def find_only_one_key(lmp_lines, key):


def revise_lmp_input_model(
lmp_lines, task_model_list, trj_freq, extra_pair_style_args="", pimd_bead=None, deepmd_version="1"
lmp_lines,
task_model_list,
trj_freq,
extra_pair_style_args="",
pimd_bead=None,
deepmd_version="1",
):
idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"])
if extra_pair_style_args:
extra_pair_style_args = " " + extra_pair_style_args
graph_list = " ".join(task_model_list)
model_devi_file_name = lmp_pimd_model_devi_name % pimd_bead if pimd_bead is not None else lmp_model_devi_name
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file %s%s"
% (
graph_list,
trj_freq,
model_devi_file_name,
extra_pair_style_args,
)
model_devi_file_name = (
lmp_pimd_model_devi_name % pimd_bead
if pimd_bead is not None
else lmp_model_devi_name
)
lmp_lines[idx] = "pair_style deepmd %s out_freq %d out_file %s%s" % (
graph_list,
trj_freq,
model_devi_file_name,
extra_pair_style_args,
)
return lmp_lines


def revise_lmp_input_dump(lmp_lines, trj_freq, pimd_bead=None):
idx = find_only_one_key(lmp_lines, ["dump", "dpgen_dump"])
lmp_traj_file_name = lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name
lmp_traj_file_name = (
lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name
)
lmp_lines[idx] = (
f"dump dpgen_dump all custom %d {lmp_traj_file_name} id type x y z"
% trj_freq
Expand Down
21 changes: 16 additions & 5 deletions tests/op/test_run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ class TestMergePIMDFiles(unittest.TestCase):
def test_merge_pimd_files(self):
for i in range(1, 3):
with open("traj.%s.dump" % i, "w") as f:
f.write("""ITEM: TIMESTEP
f.write(
"""ITEM: TIMESTEP
0
ITEM: NUMBER OF ATOMS
3
Expand All @@ -318,13 +319,16 @@ def test_merge_pimd_files(self):
1 8 7.23103 0.814939 4.59892
2 1 7.96453 0.61699 5.19158
3 1 6.43661 0.370311 5.09854
""")
"""
)
for i in range(1, 3):
with open("model_devi.%s.out" % i, "w") as f:
f.write("""# step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f
f.write(
"""# step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f
0 9.023897e-17 3.548771e-17 5.237314e-17 8.196123e-16 1.225653e-16 3.941002e-16
10 1.081667e-16 4.141596e-17 7.534462e-17 9.070597e-16 1.067947e-16 4.153524e-16
""")
"""
)

merge_pimd_files()
self.assertTrue(os.path.exists(lmp_traj_name))
Expand All @@ -335,6 +339,13 @@ def test_merge_pimd_files(self):
assert model_devi.shape[0] == 4

def tearDown(self):
for f in [lmp_traj_name, "traj.1.dump", "traj.2.dump", lmp_model_devi_name, "model_devi.1.out", "model_devi.2.out"]:
for f in [
lmp_traj_name,
"traj.1.dump",
"traj.2.dump",
lmp_model_devi_name,
"model_devi.1.out",
"model_devi.2.out",
]:
if os.path.exists(f):
os.remove(f)

0 comments on commit db0e577

Please sign in to comment.