From 1596a74c58f4b4f2decc52561822fa31b8b7abd5 Mon Sep 17 00:00:00 2001 From: RylieWeaver <123048075+RylieWeaver@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:16:51 -0400 Subject: [PATCH] Forces test (#283) * force tests, which required model arg, and some typo fixing in Lennard Jones * Add PNAPlus since it uses positions as well * formatting --- examples/LennardJones/LJ.json | 3 +++ examples/LennardJones/LJ_data.py | 4 ++-- examples/LennardJones/LennardJones.py | 8 +++++++- tests/test_forces_equivariant.py | 27 +++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 tests/test_forces_equivariant.py diff --git a/examples/LennardJones/LJ.json b/examples/LennardJones/LJ.json index a79c5f41..a6b18f12 100644 --- a/examples/LennardJones/LJ.json +++ b/examples/LennardJones/LJ.json @@ -25,6 +25,8 @@ "int_emb_size": 32, "out_emb_size": 16, "basis_emb_size": 8, + "num_gaussians": 10, + "num_filters": 8, "num_before_skip": 1, "num_after_skip": 1, "envelope_exponent": 5, @@ -55,6 +57,7 @@ "Training": { "num_epoch": 15, "batch_size": 64, + "perc_train": 0.7, "patience": 20, "early_stopping": true, "Optimizer": { diff --git a/examples/LennardJones/LJ_data.py b/examples/LennardJones/LJ_data.py index 6226ff6f..594d6d15 100644 --- a/examples/LennardJones/LJ_data.py +++ b/examples/LennardJones/LJ_data.py @@ -103,9 +103,9 @@ def __init__(self, dirpath, config, dist=False, sampling=None): for file in rx: filepath = os.path.join(dirpath, file) - self.dataset.append(self.transform_inumpyut_to_data_object_base(filepath)) + self.dataset.append(self.transform_input_to_data_object_base(filepath)) - def transform_inumpyut_to_data_object_base(self, filepath): + def transform_input_to_data_object_base(self, filepath): # Using readline() file = open(filepath, "r") diff --git a/examples/LennardJones/LennardJones.py b/examples/LennardJones/LennardJones.py index 045b1d25..2f4e5777 100644 --- a/examples/LennardJones/LennardJones.py +++ b/examples/LennardJones/LennardJones.py @@ -61,6 +61,7 @@ help="preprocess only (no training)", ) parser.add_argument("--inputfile", help="input file", type=str, default="LJ.json") + parser.add_argument("--model_type", help="model type", type=str, default=None) parser.add_argument("--mae", action="store_true", help="do mae calculation") parser.add_argument("--ddstore", action="store_true", help="ddstore dataset") parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None) @@ -98,6 +99,11 @@ # Configurable run choices (JSON file that accompanies this example script). with open(input_filename, "r") as f: config = json.load(f) + config["NeuralNetwork"]["Architecture"]["model_type"] = ( + args.model_type + if args.model_type + else config["NeuralNetwork"]["Architecture"]["model_type"] + ) verbosity = config["Verbosity"]["level"] config["NeuralNetwork"]["Variables_of_interest"][ "graph_feature_names" @@ -159,7 +165,7 @@ ## This is a local split trainset, valset, testset = split_dataset( dataset=total, - perc_train=0.9, + perc_train=config["NeuralNetwork"]["Training"]["perc_train"], stratify_splitting=False, ) print("Local splitting: ", len(total), len(trainset), len(valset), len(testset)) diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py new file mode 100644 index 00000000..4609844c --- /dev/null +++ b/tests/test_forces_equivariant.py @@ -0,0 +1,27 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +import os +import pytest + +import subprocess + + +@pytest.mark.parametrize("example", ["LennardJones"]) +@pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus"]) +@pytest.mark.mpi_skip() +def pytest_examples(example, model_type): + path = os.path.join(os.path.dirname(__file__), "..", "examples", example) + file_path = os.path.join(path, example + ".py") # Assuming different model scripts + return_code = subprocess.call(["python", file_path, "--model_type", model_type]) + + # Check the file ran without error. + assert return_code == 0