From 680cb6199feded6f2883d5c3127970a65f30ae3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Sun, 10 Dec 2023 11:46:54 -0800 Subject: [PATCH] Fix unpickling of Configuration --- CHANGELOG.md | 5 +++++ src/sdsstools/configuration.py | 15 +++++++++++---- test/test_configuration.py | 12 ++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45786d0..61ed6f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## [1.5.4](https://github.com/sdss/sdsstools/compare/1.5.3...1.5.4) - 2023-12-10 + +- Fix unpickling of `Configuration` instances. This only seems relevant when trying to pass a `Configuration` object to a `multiprocessing` callback. + + ## [1.5.3](https://github.com/sdss/sdsstools/compare/1.5.2...1.5.3) - 2023-12-08 - Vendorise `pydl`'s `yanny` module which can be accessed as `sdsstools.yanny`. diff --git a/src/sdsstools/configuration.py b/src/sdsstools/configuration.py index 885838f..195bc2e 100644 --- a/src/sdsstools/configuration.py +++ b/src/sdsstools/configuration.py @@ -223,15 +223,22 @@ def __init__( def __getitem__(self, __key: str) -> Any: if self.strict_mode: - return super().__getitem__(__key) + return dict.__getitem__(self, __key) return self.get(__key) def __setitem__(self, __key: str, __value: Any) -> None: - if isinstance(__value, dict) and self.propagate_type: - __value = self.__class__(__value, strict_mode=self.strict_mode) + # We use getattr to give default values to the non-default dict attributes. + # This only seems to matter when pickling/unpickling, and at the end of + # unpickling the values are set correctly anyway. - return super().__setitem__(__key, __value) + if isinstance(__value, dict) and getattr(self, "propagate_type", True): + __value = self.__class__( + __value, + strict_mode=getattr(self, "strict_mode", False), + ) + + return dict.__setitem__(self, __key, __value) def get(self, __key: str, default: Any = None, strict: bool | None = None) -> Any: if (strict is None and self.strict_mode is True) or strict is True: diff --git a/test/test_configuration.py b/test/test_configuration.py index dbb0f01..fba699b 100644 --- a/test/test_configuration.py +++ b/test/test_configuration.py @@ -9,6 +9,7 @@ import copy import inspect import io +import multiprocessing import os import unittest.mock @@ -379,3 +380,14 @@ def test_configuration_assignment_dot(): conf["cat1.key1"] = {"subkey2": 2} assert len(conf) == 2 assert conf["cat1.key1"] == {"subkey2": 2} + + +def _process(config): + assert config["cat1"]["key1"] == "another_value" + + +def test_configuration_with_multiprocessing(config_file): + config = Configuration(config_file) + config.propagate_type = False + with multiprocessing.Pool(2) as pool: + pool.apply(_process, (config,))