
I1906 solution json #1909

Merged 8 commits on Jan 21, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
## Features

- Added an option to force install compatible versions of jax and jaxlib if already installed using CLI ([#1881](https://github.com/pybamm-team/PyBaMM/pull/1881))
- Allowed pybamm.Solution.save_data() to return a string if filename is None, and added a json to_format option ([#1909](https://github.com/pybamm-team/PyBaMM/pull/1909))

## Bug fixes

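For context, a minimal usage sketch of the behaviour this changelog entry describes. It assumes solution is an already-solved pybamm.Solution whose variables "c" and "d" have been processed; the variable and file names are illustrative only, not taken from the PR.

# solution is assumed to exist (e.g. returned by a pybamm solver) with
# variables already processed via solution.update(["c", "d"])

# saving to a json file works like the existing formats
solution.save_data("solution.json", to_format="json")

# new: omitting the filename returns the serialised data as a string
json_str = solution.save_data(to_format="json")
csv_str = solution.save_data(variables=["c"], to_format="csv")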
46 changes: 42 additions & 4 deletions pybamm/solvers/solution.py
@@ -2,6 +2,7 @@
# Solution class
#
import casadi
import json
import numbers
import numpy as np
import pickle
@@ -10,6 +11,19 @@
from scipy.io import savemat


class NumpyEncoder(json.JSONEncoder):
"""
Numpy serialiser helper class that converts numpy arrays to a list
https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
"""

def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
# won't be called since we only need to convert numpy arrays
return json.JSONEncoder.default(self, obj) # pragma: no cover


class Solution(object):
"""
Class containing the solution of, and various attributes associated with, a PyBaMM
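As a small aside, here is a runnable sketch of what the NumpyEncoder helper added above does: it lets json.dumps serialise a dictionary whose values are numpy arrays by converting each array to a list. The data is made up for illustration, and the import path assumes the class stays module-level in pybamm/solvers/solution.py as shown in this diff.

import json
import numpy as np
from pybamm.solvers.solution import NumpyEncoder

data = {"c": np.linspace(0, 1, 3), "time [s]": np.array([0.0, 0.5, 1.0])}
# each ndarray is converted to a plain list before serialisation
print(json.dumps(data, cls=NumpyEncoder))
# prints: {"c": [0.0, 0.5, 1.0], "time [s]": [0.0, 0.5, 1.0]}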
@@ -536,14 +550,17 @@ def save(self, filename):
with open(filename, "wb") as f:
pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

def save_data(self, filename, variables=None, to_format="pickle", short_names=None):
def save_data(
self, filename=None, variables=None,
to_format="pickle", short_names=None
):
"""
Save solution data only (raw arrays)

Parameters
----------
filename : str
The name of the file to save data to
filename : str, optional
The name of the file to save data to. If None, a string is returned instead (csv and json formats only)
variables : list, optional
List of variables to save. If None, saves all of the variables that have
been created so far
@@ -553,12 +570,19 @@ def save_data(self, variables=None, to_format="pickle", short_names=No
- 'pickle' (default): creates a pickle file with the data dictionary
- 'matlab': creates a .mat file, for loading in matlab
- 'csv': creates a csv file (0D variables only)
- 'json': creates a json file
short_names : dict, optional
Dictionary of shortened names to use when saving. This may be necessary when
saving to MATLAB, since no spaces or special characters are allowed in
MATLAB variable names. Note that not all the variables need to be given
a short name.

Returns
-------
data : str, optional
The data as a string, if to_format is 'csv' or 'json' and filename is None; otherwise None


"""
if variables is None:
# variables not explicitly provided -> save all variables that have been
@@ -588,9 +612,17 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No
data_short_names[name] = var

if to_format == "pickle":
if filename is None:
raise ValueError(
"pickle format must be written to a file"
)
with open(filename, "wb") as f:
pickle.dump(data_short_names, f, pickle.HIGHEST_PROTOCOL)
elif to_format == "matlab":
if filename is None:
raise ValueError(
"matlab format must be written to a file"
)
# Check all the variable names only contain a-z, A-Z or _ or numbers
for name in data_short_names.keys():
# Check the string only contains the following ASCII:
Expand Down Expand Up @@ -625,7 +657,13 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No
)
)
df = pd.DataFrame(data_short_names)
df.to_csv(filename, index=False)
return df.to_csv(filename, index=False)
elif to_format == "json":
if filename is None:
return json.dumps(data_short_names, cls=NumpyEncoder)
else:
with open(filename, "w") as outfile:
json.dump(data_short_names, outfile, cls=NumpyEncoder)
else:
raise ValueError("format '{}' not recognised".format(to_format))

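A further hedged sketch of the other save_data paths changed above: the pickle and matlab formats now raise if no filename is given, and the short_names argument (described in the docstring) maps variable names that MATLAB would reject onto safe ones. As before, solution and the variable names are assumed for illustration.

# pickle (and matlab) must be written to a file; omitting the filename raises
try:
    solution.save_data(to_format="pickle")
except ValueError as error:
    print(error)  # pickle format must be written to a file

# MATLAB variable names cannot contain spaces or special characters,
# so map any offending names to short names when saving to .mat
solution.save_data(
    "solution.mat",
    variables=["c + d"],
    to_format="matlab",
    short_names={"c + d": "c_plus_d"},
)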
33 changes: 33 additions & 0 deletions tests/unit/test_solvers/test_solution.py
@@ -1,6 +1,7 @@
#
# Tests for the Solution class
#
import json
import pybamm
import unittest
import numpy as np
@@ -237,9 +238,13 @@ def test_save(self):
# test save data
with self.assertRaises(ValueError):
solution.save_data("test.pickle")

# set variables first then save
solution.update(["c", "d"])
with self.assertRaisesRegex(ValueError, "pickle"):
solution.save_data(to_format="pickle")
solution.save_data("test.pickle")

data_load = pybamm.load("test.pickle")
np.testing.assert_array_equal(solution.data["c"], data_load["c"])
np.testing.assert_array_equal(solution.data["d"], data_load["d"])
@@ -250,6 +255,9 @@ def test_save(self):
np.testing.assert_array_equal(solution.data["c"], data_load["c"].flatten())
np.testing.assert_array_equal(solution.data["d"], data_load["d"])

with self.assertRaisesRegex(ValueError, "matlab"):
solution.save_data(to_format="matlab")

# to matlab with bad variables name fails
solution.update(["c + d"])
with self.assertRaisesRegex(ValueError, "Invalid character"):
@@ -268,11 +276,36 @@ def test_save(self):
solution.save_data("test.csv", to_format="csv")
# only save "c" and "2c"
solution.save_data("test.csv", ["c", "2c"], to_format="csv")
csv_str = solution.save_data(variables=["c", "2c"], to_format="csv")

# check string is the same as the file
with open('test.csv') as f:
# need to strip \r chars for windows
self.assertEqual(
csv_str.replace('\r', ''), f.read()
)

# read csv
df = pd.read_csv("test.csv")
np.testing.assert_array_almost_equal(df["c"], solution.data["c"])
np.testing.assert_array_almost_equal(df["2c"], solution.data["2c"])

# to json
solution.save_data("test.json", to_format="json")
json_str = solution.save_data(to_format="json")

# check string is the same as the file
with open('test.json') as f:
# need to strip \r chars for windows
self.assertEqual(
json_str.replace('\r', ''), f.read()
)

# check if string has the right values
json_data = json.loads(json_str)
np.testing.assert_array_almost_equal(json_data["c"], solution.data["c"])
Member commented:
why is it only almost equal?

Contributor Author replied:
I was only copying the other lines of this test. But it would need to be almost equal, as the numbers (e.g. 1/3) will be rounded off when they are written to json

np.testing.assert_array_almost_equal(json_data["d"], solution.data["d"])

# raise error if format is unknown
with self.assertRaisesRegex(ValueError, "format 'wrong_format' not recognised"):
solution.save_data("test.csv", to_format="wrong_format")