diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d3d21fc..7aa49fec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,14 +8,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.5 + rev: v0.4.3 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-yaml - id: end-of-file-fixer @@ -24,7 +24,7 @@ repos: exclude: ^tests - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.9.0 + rev: v1.10.0 hooks: - id: mypy @@ -37,7 +37,7 @@ repos: additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11 - repo: https://github.com/MarcoGorelli/cython-lint - rev: v0.16.0 + rev: v0.16.2 hooks: - id: cython-lint args: [--no-pycodestyle] @@ -49,7 +49,7 @@ repos: - id: blacken-docs - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.39.0 + rev: v0.40.0 hooks: - id: markdownlint # MD013: line too long diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 2c17d926..0a3518d0 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -210,9 +210,9 @@ GEM jekyll-feed (~> 0.9) jekyll-seo-tag (~> 2.1) minitest (5.19.0) - nokogiri (1.16.2-arm64-darwin) + nokogiri (1.16.5-arm64-darwin) racc (~> 1.4) - nokogiri (1.16.2-x86_64-linux) + nokogiri (1.16.5-x86_64-linux) racc (~> 1.4) octokit (4.25.1) faraday (>= 1, < 3) diff --git a/docs/changelog.md b/docs/changelog.md index 243ca7a6..3dd9abd9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,8 @@ # Change log +## 2024.5.15 +- Reimplemented support for pickle in MSONAble. (@matthewcarbone) + ## 2024.4.17 - Revert changes to json.py for now. 
diff --git a/docs/monty.functools.md b/docs/monty.functools.md index 15bc0fca..93fd41b6 100644 --- a/docs/monty.functools.md +++ b/docs/monty.functools.md @@ -73,13 +73,10 @@ becomes The decorated main accepts two new arguments: > prof_file: Name of the output file with profiling data - > ```none > If not given, a temporary file is created. > ``` - > sortby: Profiling data are sorted according to this value. - > ```none > default is “time”. See sort_stats. > ``` diff --git a/docs/monty.os.md b/docs/monty.os.md index 8d85d1dc..96184b2e 100644 --- a/docs/monty.os.md +++ b/docs/monty.os.md @@ -15,7 +15,6 @@ performing some tasks, and returns to the original working directory afterwards. E.g., > with cd(“/my/path/”): - > ```none > do_something() > ``` diff --git a/docs/monty.re.md b/docs/monty.re.md index 7082093f..b7e419ef 100644 --- a/docs/monty.re.md +++ b/docs/monty.re.md @@ -25,11 +25,9 @@ A powerful regular expression version of grep. * **Returns** > {key1: [[[matches…], lineno], [[matches…], lineno], - > ```none > [[matches…], lineno], …], > ``` - > key2: …} For reverse reads, the lineno is given as a -ve number. Please note diff --git a/monty/__init__.py b/monty/__init__.py index 02eb8a81..a96daff4 100644 --- a/monty/__init__.py +++ b/monty/__init__.py @@ -9,7 +9,7 @@ __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2014, The Materials Virtual Lab" -__version__ = "2024.4.17" +__version__ = "2024.5.15" __maintainer__ = "Shyue Ping Ong" __email__ = "ongsp@ucsd.edu" __date__ = "Oct 12 2020" diff --git a/monty/functools.py b/monty/functools.py index d5f53819..bc9b9953 100644 --- a/monty/functools.py +++ b/monty/functools.py @@ -84,7 +84,7 @@ def __init__(self, func: Callable) -> None: func: Function to decorate. 
""" self.__func = func - wraps(self.__func)(self) + wraps(self.__func)(self) # type: ignore def __get__(self, inst: Any, inst_cls) -> Any: if inst is None: diff --git a/monty/json.py b/monty/json.py index 7513fc24..759758b1 100644 --- a/monty/json.py +++ b/monty/json.py @@ -8,6 +8,7 @@ import json import os import pathlib +import pickle import traceback import types from collections import OrderedDict, defaultdict @@ -16,7 +17,8 @@ from importlib import import_module from inspect import getfullargspec from pathlib import Path -from uuid import UUID +from typing import Any, Dict +from uuid import UUID, uuid4 try: import numpy as np @@ -94,8 +96,14 @@ def _check_type(obj, type_str) -> bool: Checks whether obj is an instance of the type defined by type_str. This removes the need to explicitly import type_str. Handles subclasses like isinstance does. E.g.:: - class A: pass - class B(A): pass + class A: + pass + + + class B(A): + pass + + a, b = A(), B() assert isinstance(a, A) assert isinstance(b, B) @@ -166,7 +174,10 @@ def as_dict(self) -> dict: """ A JSON serializable dict representation of an object. """ - d = {"@module": self.__class__.__module__, "@class": self.__class__.__name__} + d = { + "@module": self.__class__.__module__, + "@class": self.__class__.__name__, + } try: parent_module = self.__class__.__module__.split(".", maxsplit=1)[0] @@ -357,6 +368,175 @@ def __modify_schema__(cls, field_schema): custom_schema = cls._generic_json_schema() field_schema.update(custom_schema) + def _get_partial_json(self, json_kwargs, pickle_kwargs): + """Used with the save method. 
Gets the json representation of a class + with the unserializable components substituted for hash references.""" + + if pickle_kwargs is None: + pickle_kwargs = {} + if json_kwargs is None: + json_kwargs = {} + encoder = MontyEncoder(allow_unserializable_objects=True, **json_kwargs) + encoded = encoder.encode(self) + return encoder, encoded, json_kwargs, pickle_kwargs + + def get_partial_json(self, json_kwargs=None, pickle_kwargs=None): + """ + Parameters + ---------- + json_kwargs : dict + Keyword arguments to pass to the serializer. + pickle_kwargs : dict + Keyword arguments to pass to pickle.dump. + + Returns + ------- + str, dict + The json encoding of the class and the name-object map if one is + required, otherwise None. + """ + + encoder, encoded, json_kwargs, pickle_kwargs = self._get_partial_json( + json_kwargs, pickle_kwargs + ) + name_object_map = encoder._name_object_map + if len(name_object_map) == 0: + name_object_map = None + return encoded, name_object_map, json_kwargs, pickle_kwargs + + def save( + self, + json_path, + mkdir=True, + json_kwargs=None, + pickle_kwargs=None, + strict=True, + ): + """Utility that uses the standard tools of MSONable to convert the + class to json format, but also save it to disk. In addition, this + method intelligently uses pickle to individually pickle class objects + that are not serializable, saving them separately. This maximizes the + readability of the saved class information while allowing _any_ + class to be at least partially serializable to disk. + + For a fully MSONable class, only a class.json file will be saved to + the location {save_dir}/class.json. For a partially MSONable class, + additional information will be saved to the save directory at + {save_dir}. This includes a pickled object for each attribute that + cannot be serialized. + + Parameters + ---------- + file_path : os.PathLike + The file to which to save the json object. 
A pickled object of + the same name but different extension might also be saved if the + class is not entirely MSONable. + mkdir : bool + If True, makes the provided directory, including all parent + directories. + json_kwargs : dict + Keyword arguments to pass to the serializer. + pickle_kwargs : dict + Keyword arguments to pass to pickle.dump. + strict : bool + If True, will not allow you to overwrite existing files. + """ + + json_path = Path(json_path) + save_dir = json_path.parent + + encoded, name_object_map, json_kwargs, pickle_kwargs = self.get_partial_json( + json_kwargs, pickle_kwargs + ) + + if mkdir: + save_dir.mkdir(exist_ok=True, parents=True) + + # Define the pickle path + pickle_path = save_dir / f"{json_path.stem}.pkl" + + # Check if the files exist and the strict parameter is True + if strict and json_path.exists(): + raise FileExistsError(f"strict is true and file {json_path} exists") + if strict and pickle_path.exists(): + raise FileExistsError(f"strict is true and file {pickle_path} exists") + + # Save the json file + with open(json_path, "w") as outfile: + outfile.write(encoded) + + # Save the pickle file if we have anything to save from the name_object_map + if name_object_map is not None: + with open(pickle_path, "wb") as f: + pickle.dump(name_object_map, f, **pickle_kwargs) + + @classmethod + def load(cls, file_path): + """Loads a class from a provided json file. + + Parameters + ---------- + file_path : os.PathLike + The json file to load from. + + Returns + ------- + MSONable + An instance of the class being reloaded. + """ + + d = _d_from_path(file_path) + return cls.from_dict(d) + + +def load(path): + """Loads a json file that was saved using MSONable.save. + + Parameters + ---------- + path : os.PathLike + Path to the json file to load. 
+ + Returns + ------- + MSONable + """ + + d = _d_from_path(path) + module = d["@module"] + klass = d["@class"] + module = import_module(module) + klass = getattr(module, klass) + return klass.from_dict(d) + + +def _d_from_path(file_path): + json_path = Path(file_path) + save_dir = json_path.parent + pickle_path = save_dir / f"{json_path.stem}.pkl" + + with open(json_path, "r") as infile: + d = json.loads(infile.read()) + + if pickle_path.exists(): + name_object_map = pickle.load(open(pickle_path, "rb")) + d = _recursive_name_object_map_replacement(d, name_object_map) + return d + + +def _recursive_name_object_map_replacement(d, name_object_map): + if isinstance(d, dict): + if "@object_reference" in d: + name = d["@object_reference"] + return name_object_map.pop(name) + return { + k: _recursive_name_object_map_replacement(v, name_object_map) + for k, v in d.items() + } + elif isinstance(d, list): + return [_recursive_name_object_map_replacement(x, name_object_map) for x in d] + return d + class MontyEncoder(json.JSONEncoder): """ @@ -367,6 +547,18 @@ class MontyEncoder(json.JSONEncoder): json.dumps(object, cls=MontyEncoder) """ + def __init__(self, *args, allow_unserializable_objects=False, **kwargs): + super().__init__(*args, **kwargs) + self._allow_unserializable_objects = allow_unserializable_objects + self._name_object_map: Dict[str, Any] = {} + self._index = 0 + + def _update_name_object_map(self, o): + name = f"{self._index:012}-{str(uuid4())}" + self._index += 1 + self._name_object_map[name] = o + return {"@object_reference": name} + def default(self, o) -> dict: # pylint: disable=E0202 """ Overriding default method for JSON encoding. This method does two @@ -380,7 +572,11 @@ def default(self, o) -> dict: # pylint: disable=E0202 Python dict representation. 
""" if isinstance(o, datetime.datetime): - return {"@module": "datetime", "@class": "datetime", "string": str(o)} + return { + "@module": "datetime", + "@class": "datetime", + "string": str(o), + } if isinstance(o, UUID): return {"@module": "uuid", "@class": "UUID", "string": str(o)} if isinstance(o, Path): @@ -431,10 +627,20 @@ def default(self, o) -> dict: # pylint: disable=E0202 } if bson is not None and isinstance(o, bson.objectid.ObjectId): - return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)} + return { + "@module": "bson.objectid", + "@class": "ObjectId", + "oid": str(o), + } if callable(o) and not isinstance(o, MSONable): - return _serialize_callable(o) + try: + return _serialize_callable(o) + except AttributeError as e: + # Some callables may not have instance __name__ + if self._allow_unserializable_objects: + return self._update_name_object_map(o) + raise AttributeError(e) try: if pydantic is not None and isinstance(o, pydantic.BaseModel): @@ -450,6 +656,11 @@ def default(self, o) -> dict: # pylint: disable=E0202 d = o.as_dict() elif isinstance(o, Enum): d = {"value": o.value} + elif self._allow_unserializable_objects: + # Last resort logic. 
We keep track of some name of the object + # as a reference, and instead of the object, store that + # name, which of course is json-serializable + d = self._update_name_object_map(o) else: raise TypeError( f"Object of type {o.__class__.__name__} is not JSON serializable" @@ -639,7 +850,11 @@ class MSONError(Exception): def jsanitize( - obj, strict=False, allow_bson=False, enum_values=False, recursive_msonable=False + obj, + strict=False, + allow_bson=False, + enum_values=False, + recursive_msonable=False, ): """ This method cleans an input json-like object, either a list or a dict or @@ -690,16 +905,19 @@ def jsanitize( for i in obj ] if np is not None and isinstance(obj, np.ndarray): - return [ - jsanitize( - i, - strict=strict, - allow_bson=allow_bson, - enum_values=enum_values, - recursive_msonable=recursive_msonable, - ) - for i in obj.tolist() - ] + try: + return [ + jsanitize( + i, + strict=strict, + allow_bson=allow_bson, + enum_values=enum_values, + recursive_msonable=recursive_msonable, + ) + for i in obj.tolist() + ] + except TypeError: + return obj.tolist() if np is not None and isinstance(obj, np.generic): return obj.item() if _check_type( diff --git a/monty/serialization.py b/monty/serialization.py index 128b7ec7..0e646d35 100644 --- a/monty/serialization.py +++ b/monty/serialization.py @@ -12,7 +12,7 @@ try: from ruamel.yaml import YAML except ImportError: - YAML = None + YAML = None # type: ignore from monty.io import zopen from monty.json import MontyDecoder, MontyEncoder diff --git a/pyproject.toml b/pyproject.toml index 95ef2d93..4587a74b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ dependencies = [ ] -version = "2024.4.17" +version = "2024.5.15" [tool.setuptools] packages = ["monty"] diff --git a/requirements-optional.txt b/requirements-optional.txt index 546f4ff2..2997ce97 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -1,9 +1,9 @@ numpy==1.26.4 ruamel.yaml==0.18.6 msgpack==1.0.8 
-tqdm==4.66.2 -pymongo==4.6.3 +tqdm==4.66.4 +pymongo==4.7.2 pandas==2.2.2 -orjson==3.10.0 +orjson==3.10.3 types-orjson==3.6.2 types-requests==2.31.0.20240406 diff --git a/tests/test_files/3000_lines.txt.gz b/tests/test_files/3000_lines.txt.gz index ca27a301..9d95b137 100644 Binary files a/tests/test_files/3000_lines.txt.gz and b/tests/test_files/3000_lines.txt.gz differ diff --git a/tests/test_json.py b/tests/test_json.py index d453baee..2653ec4d 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -23,6 +23,11 @@ except ImportError: torch = None +try: + import pydantic +except ImportError: + pydantic = None + try: from bson.objectid import ObjectId except ImportError: @@ -30,7 +35,14 @@ import pytest -from monty.json import MontyDecoder, MontyEncoder, MSONable, _load_redirect, jsanitize +from monty.json import ( + MontyDecoder, + MontyEncoder, + MSONable, + _load_redirect, + jsanitize, + load, +) from . import __version__ as tests_version @@ -390,63 +402,67 @@ def test_enum_serialization_no_msonable(self): f = jsanitize(d, enum_values=True) assert f["123"] == "value_a" - # def test_save_load(self, tmp_path): - # """Tests the save and load serialization methods.""" - # - # test_good_class = GoodMSONClass( - # "Hello", - # "World", - # "Python", - # **{ - # "cant_serialize_me": GoodNOTMSONClass( - # "Hello2", "World2", "Python2", **{"values": []} - # ), - # "cant_serialize_me2": [ - # GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}), - # GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}), - # ], - # "cant_serialize_me3": [ - # { - # "tmp": GoodMSONClass( - # "Hello5", "World5", "Python5", **{"values": []} - # ), - # "tmp2": 2, - # "tmp3": [1, 2, 3], - # }, - # { - # "tmp5": GoodNOTMSONClass( - # "aHello5", "aWorld5", "aPython5", **{"values": []} - # ), - # "tmp2": 5, - # "tmp3": {"test": "test123"}, - # }, - # # Gotta check that if I hide an MSONable class somewhere - # # it still gets correctly serialized. 
- # {"actually_good": GoodMSONClass("1", "2", "3", **{"values": []})}, - # ], - # "values": [], - # }, - # ) - # - # # This will pass - # test_good_class.as_dict() - # - # # This will fail - # with pytest.raises(TypeError): - # test_good_class.to_json() - # - # # This should also pass though - # target = tmp_path / "test_dir123" - # test_good_class.save(target, json_kwargs={"indent": 4, "sort_keys": True}) - # - # # This will fail - # with pytest.raises(FileExistsError): - # test_good_class.save(target, strict=True) - # - # # Now check that reloading this, the classes are equal! - # test_good_class2 = GoodMSONClass.load(target) - # - # assert test_good_class == test_good_class2 + def test_save_load(self, tmp_path): + """Tests the save and load serialization methods.""" + + test_good_class = GoodMSONClass( + "Hello", + "World", + "Python", + **{ + "cant_serialize_me": GoodNOTMSONClass( + "Hello2", "World2", "Python2", **{"values": []} + ), + "cant_serialize_me2": [ + GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}), + GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}), + ], + "cant_serialize_me3": [ + { + "tmp": GoodMSONClass( + "Hello5", "World5", "Python5", **{"values": []} + ), + "tmp2": 2, + "tmp3": [1, 2, 3], + }, + { + "tmp5": GoodNOTMSONClass( + "aHello5", "aWorld5", "aPython5", **{"values": []} + ), + "tmp2": 5, + "tmp3": {"test": "test123"}, + }, + # Gotta check that if I hide an MSONable class somewhere + # it still gets correctly serialized. 
+ {"actually_good": GoodMSONClass("1", "2", "3", **{"values": []})}, + ], + "values": [], + }, + ) + + # This will pass + test_good_class.as_dict() + + # This will fail + with pytest.raises(TypeError): + test_good_class.to_json() + + # This should also pass though + target = tmp_path / "test.json" + test_good_class.save(target, json_kwargs={"indent": 4, "sort_keys": True}) + + # This will fail + with pytest.raises(FileExistsError): + test_good_class.save(target, strict=True) + + # Now check that reloading this, the classes are equal! + test_good_class2 = GoodMSONClass.load(target) + + # Final check using load + test_good_class3 = load(target) + + assert test_good_class == test_good_class2 + assert test_good_class == test_good_class3 class TestJson: @@ -576,6 +592,10 @@ def test_numpy(self): d = jsanitize(x, strict=True) assert isinstance(d["energies"][0], float) + x = {"energy": np.array(-1.0)} + d = jsanitize(x, strict=True) + assert isinstance(d["energy"], float) + # Test data nested in a class x = np.array([[1 + 1j, 2 + 1j], [3 + 1j, 4 + 1j]], dtype="complex64") cls = ClassContainingNumpyArray(np_a={"a": [{"b": x}]}) @@ -722,6 +742,23 @@ def test_jsanitize(self): clean_recursive_msonable = jsanitize(d, recursive_msonable=True) assert clean_recursive_msonable["hello"]["a"] == 1 assert clean_recursive_msonable["hello"]["b"] == 2 + assert clean_recursive_msonable["hello"]["c"] == 3 + assert clean_recursive_msonable["test"] == "hi" + + d = {"hello": [GoodMSONClass(1, 2, 3), "test"], "test": "hi"} + clean_recursive_msonable = jsanitize(d, recursive_msonable=True) + assert clean_recursive_msonable["hello"][0]["a"] == 1 + assert clean_recursive_msonable["hello"][0]["b"] == 2 + assert clean_recursive_msonable["hello"][0]["c"] == 3 + assert clean_recursive_msonable["hello"][1] == "test" + assert clean_recursive_msonable["test"] == "hi" + + d = {"hello": (GoodMSONClass(1, 2, 3), "test"), "test": "hi"} + clean_recursive_msonable = jsanitize(d, recursive_msonable=True) + 
assert clean_recursive_msonable["hello"][0]["a"] == 1 + assert clean_recursive_msonable["hello"][0]["b"] == 2 + assert clean_recursive_msonable["hello"][0]["c"] == 3 + assert clean_recursive_msonable["hello"][1] == "test" assert clean_recursive_msonable["test"] == "hi" d = {"dt": datetime.datetime.now()} @@ -851,6 +888,7 @@ def test_redirect_settings_file(self): } } + @pytest.mark.skipif(pydantic is None, reason="pydantic not present") def test_pydantic_integrations(self): from pydantic import BaseModel, ValidationError