Skip to content

Commit

Permalink
Add test for mixed voigt/3x3 stress in error table
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Dec 3, 2024
1 parent ae6718c commit 1dac49e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
12 changes: 12 additions & 0 deletions tests/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from ase.atoms import Atoms
from ase.calculators.lj import LennardJones
from ase.stress import voigt_6_to_full_3x3_stress
from pytest import approx

from pprint import pprint
Expand Down Expand Up @@ -103,6 +104,17 @@ def test_err_from_calc(ref_atoms):
assert ref_err_dict['virial/atom/comp']['_ALL_']["count"] == 10 * 6


def test_err_stress_shape(ref_atoms):
ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_')
ref_err_dict, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_')

for at in ref_atoms_calc:
at.info["REF_stress"] = voigt_6_to_full_3x3_stress(at.info["REF_stress"])
ref_err_dict_shape, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_')

assert ref_err_dict == ref_err_dict_shape


def test_error_properties(ref_atoms):
ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_')
# both energy and per atom
Expand Down
8 changes: 3 additions & 5 deletions wfl/fit/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def _reshape_normalize(quant, prop, atoms, per_atom):
quant: 2-d array containing reshaped quantity, with leading dimension 1 for per-config
or len(atoms) for per-atom
"""
# convert scalars or lists into arrays
quant = np.asarray(quant)

# fix shape of stress/virial
if prop.startswith("stress") or prop.startswith("virial"):
Expand Down Expand Up @@ -156,9 +154,9 @@ def _reshape_normalize(quant, prop, atoms, per_atom):
raise ValueError("/atom only possible in config_properties")
data = at.arrays

# grab data
ref_quant = data.get(ref_property_prefix + prop_use)
calc_quant = data.get(calc_property_prefix + prop_use)
# grab data, make a copy so normalization doesn't affect original
ref_quant = np.asarray(data.get(ref_property_prefix + prop_use)).copy()
calc_quant = np.asarray(data.get(calc_property_prefix + prop_use)).copy()
if ref_quant is None or calc_quant is None:
# warn if data is missing by reporting summary at the very end
if prop not in missed_prop_counter:
Expand Down

0 comments on commit 1dac49e

Please sign in to comment.