Skip to content

Commit

Permalink
Add .save_as_python_script method and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ronald-jaepel committed Mar 22, 2024
1 parent 645ccce commit efefceb
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
94 changes: 94 additions & 0 deletions cadet/cadet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ def load_json(self, filename, update=False):
else:
self.root = data

def save_as_python_script(self, filename: str, only_return_pythonic_representation=False):
if not filename.endswith(".py"):
raise Warning(f"The filename given to .save_as_python_script isn't a python file name.")

code_lines_list = [
"import numpy",
"from cadet import Cadet",
"",
"sim = Cadet()",
"root = sim.root",
]

code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root,
current_lines_list=code_lines_list,
prefix="root")

filename_for_reproduced_h5_file = filename.replace(".py", ".h5")
code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'")
code_lines_list.append("sim.save()")

if not only_return_pythonic_representation:
with open(filename, "w") as handle:
handle.writelines([line + "\n" for line in code_lines_list])
else:
return code_lines_list

def append(self, lock=False):
"This can only be used to write new keys to the system, this is faster than having to read the data before writing it"
if self.filename is not None:
Expand Down Expand Up @@ -347,3 +373,71 @@ def recursively_save(h5file, path, dic, func):
raise KeyError(f'Name conflict with upper and lower case entries for key "{path}{key}".')
else:
raise


def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None):
"""
Recursively turn a nested dictionary or addict.Dict into a list of Python code that
can generate the nested dictionary.
:param dictionary:
:param current_lines_list:
:param prefix_list:
:return: list of Python code lines
"""

def merge_to_absolute_key(prefix, key):
"""
Combine key and prefix to "prefix.key" except if there is no prefix, then return key
"""
if prefix is None:
return key
else:
return f"{prefix}.{key}"

def clean_up_key(absolute_key: str):
"""
Remove problematic phrases from key, such as blank "return"
:param absolute_key:
:return:
"""
absolute_key = absolute_key.replace(".return", "['return']")
return absolute_key

def get_pythonic_representation_of_value(value):
"""
Use repr() to get a pythonic representation of the value
and add "np." to "array" and "float64"
"""
value_representation = repr(value)
value_representation = value_representation.replace("array", "numpy.array")
value_representation = value_representation.replace("float64", "numpy.float64")
try:
eval(value_representation)
except NameError as e:
raise ValueError(
f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n"
f"Please report this to the CADET-Python developers.") from e

return value_representation

if current_lines_list is None:
current_lines_list = []

for key in sorted(dictionary.keys()):
value = dictionary[key]

absolute_key = merge_to_absolute_key(prefix, key)

if type(value) in (dict, Dict):
current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key)
else:
value_representation = get_pythonic_representation_of_value(value)

absolute_key = clean_up_key(absolute_key)

current_lines_list.append(f"{absolute_key} = {value_representation}")

return current_lines_list
64 changes: 64 additions & 0 deletions tests/test_save_as_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import tempfile

import numpy as np
import pytest
from addict import Dict

from cadet import Cadet


@pytest.fixture
def temp_cadet_file():
"""
Create a new Cadet object for use in tests.
"""
model = Cadet()

with tempfile.NamedTemporaryFile() as temp:
model.filename = temp
yield model


def test_save_as_python(temp_cadet_file):
"""
Test that the Cadet class raises a KeyError exception when duplicate keys are set on it.
"""
# initialize "sim" variable to be overwritten by the exec lines later
sim = Cadet()

# Populate temp_cadet_file with all tricky cases currently known
temp_cadet_file.root.input.foo = 1
temp_cadet_file.root.input.bar.baryon = np.arange(10)
temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9)
temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64)
temp_cadet_file.root.input["return"].split_foobar = 1

code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True)

# remove code lines that save the file
code_lines = code_lines[:-2]

# populate "sim" variable using the generated code lines
for line in code_lines:
exec(line)

# test that "sim" is equal to "temp_cadet_file"
recursive_equality_check(sim.root, temp_cadet_file.root)


def recursive_equality_check(dict_a: dict, dict_b: dict):
assert dict_a.keys() == dict_b.keys()
for key in dict_a.keys():
value_a = dict_a[key]
value_b = dict_b[key]
if type(value_a) in (dict, Dict):
recursive_equality_check(value_a, value_b)
elif type(value_a) == np.ndarray:
np.testing.assert_array_equal(value_a, value_b)
else:
assert value_a == value_b
return True


if __name__ == "__main__":
pytest.main()

0 comments on commit efefceb

Please sign in to comment.