Skip to content

Commit

Permalink
alt np fix
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Oct 1, 2024
1 parent 6949500 commit 03113cb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
14 changes: 4 additions & 10 deletions qcelemental/models/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,10 @@

def generate_caster(dtype):
def cast_to_np(v):
# for driver=properties
if isinstance(v, dict):
vv = {}
for key, val in v.items():
try:
val = np.asarray(val, dtype=dtype)
except ValueError:
raise ValueError(f"Could not cast {val} to NumPy Array!")
vv[key] = val
return vv
if isinstance(v, (float, dict)):
return v
elif isinstance(v, int):
return float(v)

try:
v = np.asarray(v, dtype=dtype)
Expand Down
35 changes: 35 additions & 0 deletions qcelemental/tests/test_model_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import copy
import warnings

import numpy as np
import pydantic
import pytest
Expand Down Expand Up @@ -618,3 +621,35 @@ def test_result_model_deprecations(result_data_fixture, optimization_data_fixtur

with pytest.warns(DeprecationWarning):
qcel.models.v1.Optimization(**optimization_data_fixture)


@pytest.mark.parametrize(
"retres,atprop,rettyp",
[
(15, "mp2_correlation_energy", float),
(15.0, "mp2_correlation_energy", float),
([1.0, -2.5, 3, 0, 0, 0, 0, 0, 0], "return_gradient", np.ndarray),
(np.array([1.0, -2.5, 3, 0, 0, 0, 0, 0, 0]), "return_gradient", np.ndarray),
({"cat1": "tail", "cat2": "whiskers"}, None, str),
({"float1": 4.4, "float2": -9.9}, None, float),
({"list1": [-1.0, 4.4], "list2": [-9.9, 2]}, None, list),
({"arr1": np.array([-1.0, 4.4]), "arr2": np.array([-9.9, 2])}, None, np.ndarray),
],
)
def test_return_result_types(result_data_fixture, retres, atprop, rettyp, request, schema_versions):
AtomicResult = schema_versions.AtomicResult

working_res = copy.deepcopy(result_data_fixture)
working_res["return_result"] = retres
if atprop:
working_res["properties"]["calcinfo_natom"] = 3
working_res["properties"][atprop] = retres
atres = AtomicResult(**working_res)

if isinstance(retres, dict):
for v in atres.return_result.values():
assert isinstance(v, rettyp)
else:
if atprop:
assert isinstance(getattr(atres.properties, atprop), rettyp)
assert isinstance(atres.return_result, rettyp)

0 comments on commit 03113cb

Please sign in to comment.