From cef4d70e1b302dbe4f62ffacb22c3dcf557a9eae Mon Sep 17 00:00:00 2001 From: Graham Esau Date: Sat, 25 Nov 2023 13:37:23 +0000 Subject: [PATCH] Refactor jsonpickle logic out of visitors --- src/localstack_persist/hooks.py | 2 - src/localstack_persist/prepare_service.py | 14 +--- .../s3/migrate_ephemeral_object_store.py | 10 ++- .../serialization/__init__.py | 0 .../serialization/jsonpickle/__init__.py | 0 .../jsonpickle/handlers.py} | 21 +++--- .../serialization/jsonpickle/serializer.py | 48 +++++++++++++ src/localstack_persist/utils.py | 13 ++++ src/localstack_persist/visitors.py | 70 +++++++------------ 9 files changed, 103 insertions(+), 75 deletions(-) create mode 100644 src/localstack_persist/serialization/__init__.py create mode 100644 src/localstack_persist/serialization/jsonpickle/__init__.py rename src/localstack_persist/{jsonpickle.py => serialization/jsonpickle/handlers.py} (98%) create mode 100644 src/localstack_persist/serialization/jsonpickle/serializer.py create mode 100644 src/localstack_persist/utils.py diff --git a/src/localstack_persist/hooks.py b/src/localstack_persist/hooks.py index 8a2bcdf..cb7193b 100644 --- a/src/localstack_persist/hooks.py +++ b/src/localstack_persist/hooks.py @@ -3,14 +3,12 @@ from localstack.runtime import hooks from .state import STATE_TRACKER -from .jsonpickle import register_handlers LOG = logging.getLogger(__name__) @hooks.on_infra_start(priority=1) def on_infra_start(): - register_handlers() STATE_TRACKER.load_all_services_state() STATE_TRACKER.start() diff --git a/src/localstack_persist/prepare_service.py b/src/localstack_persist/prepare_service.py index 38fa3a5..2653557 100644 --- a/src/localstack_persist/prepare_service.py +++ b/src/localstack_persist/prepare_service.py @@ -3,7 +3,9 @@ import os import sys from localstack.services.plugins import SERVICE_PLUGINS + from .config import BASE_DIR +from .utils import once def prepare_service(service_name: str): @@ -13,18 +15,6 @@ def prepare_service(service_name: str): prepare_s3() -def once(f: Callable): - has_run = False - - def wrapper(*args, **kwargs): - nonlocal has_run - if not has_run: - has_run = True - return f(*args, **kwargs) - - return wrapper - - @once def prepare_lambda(): # Define localstack.services.awslambda as a backward-compatible alias for localstack.services.lambda_ diff --git a/src/localstack_persist/s3/migrate_ephemeral_object_store.py b/src/localstack_persist/s3/migrate_ephemeral_object_store.py index 2ce8eaa..ccf5ed8 100644 --- a/src/localstack_persist/s3/migrate_ephemeral_object_store.py +++ b/src/localstack_persist/s3/migrate_ephemeral_object_store.py @@ -1,5 +1,4 @@ import base64 -import json from typing import cast import jsonpickle @@ -9,7 +8,9 @@ LockedSpooledTemporaryFile, ) from localstack.utils.files import mkdir + from .storage import PersistedS3ObjectStore +from ..serialization.jsonpickle.serializer import JsonPickleSerializer class LockedSpooledTemporaryFileHandler(jsonpickle.handlers.BaseHandler): @@ -35,11 +36,8 @@ def __init__(self, id: str) -> None: def migrate_ephemeral_object_store(file_path: str, store: PersistedS3ObjectStore): jsonpickle.register(LockedSpooledTemporaryFile, LockedSpooledTemporaryFileHandler) - with open(file_path) as file: - envelope: dict = json.load(file) - - unpickler = jsonpickle.Unpickler(keys=True, safe=True) - ephemeral_store = cast(EphemeralS3ObjectStore, unpickler.restore(envelope["data"])) + serializer = JsonPickleSerializer() + ephemeral_store: EphemeralS3ObjectStore = serializer.deserialize(file_path) if not ephemeral_store._filesystem: # create the root directory to avoid trying to re-migrate empty store again in future diff --git a/src/localstack_persist/serialization/__init__.py b/src/localstack_persist/serialization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/localstack_persist/serialization/jsonpickle/__init__.py b/src/localstack_persist/serialization/jsonpickle/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/localstack_persist/jsonpickle.py b/src/localstack_persist/serialization/jsonpickle/handlers.py similarity index 98% rename from src/localstack_persist/jsonpickle.py rename to src/localstack_persist/serialization/jsonpickle/handlers.py index 60d28a0..7a306ce 100644 --- a/src/localstack_persist/jsonpickle.py +++ b/src/localstack_persist/serialization/jsonpickle/handlers.py @@ -7,6 +7,18 @@ from jsonpickle.handlers import DatetimeHandler as DefaultDatetimeHandler from moto.acm.models import CertBundle +from localstack_persist.utils import once + + +@once +def register_handlers(): + CertBundleHandler.handles(CertBundle) + ConditionHandler.handles(Condition) + PriorityQueueHandler.handles(PriorityQueue) + DatetimeHandler.handles(datetime.datetime) + DatetimeHandler.handles(datetime.date) + DatetimeHandler.handles(datetime.time) + class ConditionHandler(jsonpickle.handlers.BaseHandler): def flatten(self, obj, data: dict): @@ -83,12 +95,3 @@ def restore(self, data: dict): raise TypeError("DatetimeHandler: unexpected object type " + cls_name) return cls.fromisoformat(data["isoformat"]) - - -def register_handlers(): - CertBundleHandler.handles(CertBundle) - ConditionHandler.handles(Condition) - PriorityQueueHandler.handles(PriorityQueue) - DatetimeHandler.handles(datetime.datetime) - DatetimeHandler.handles(datetime.date) - DatetimeHandler.handles(datetime.time) diff --git a/src/localstack_persist/serialization/jsonpickle/serializer.py b/src/localstack_persist/serialization/jsonpickle/serializer.py new file mode 100644 index 0000000..62a1269 --- /dev/null +++ b/src/localstack_persist/serialization/jsonpickle/serializer.py @@ -0,0 +1,48 @@ +import json +import logging +from typing import Any +import jsonpickle.handlers +import jsonpickle.tags + +from .handlers import register_handlers + +# Track version for future handling of backward (or forward) incompatible changes. +# This is the "serialisation format" version, which is different to the localstack-persist version. +SER_VERSION_KEY = "v" +SER_VERSION = 1 + +DATA_KEY = "data" + +LOG = logging.getLogger(__name__) + + +class JsonPickleSerializer: + _json_encoder = json.JSONEncoder(check_circular=False, separators=(",", ":")) + + def serialize(self, file_path: str, data: Any): + register_handlers() + + pickler = jsonpickle.Pickler(keys=True, warn=True) + + envelope = {SER_VERSION_KEY: SER_VERSION, DATA_KEY: pickler.flatten(data)} + + with open(file_path, "w") as file: + for chunk in self._json_encoder.iterencode(envelope): + file.write(chunk) + + def deserialize(self, file_path: str) -> Any: + register_handlers() + + with open(file_path) as file: + envelope: dict = json.load(file) + + version = envelope.get(SER_VERSION_KEY, None) + if version != SER_VERSION: + LOG.warning( + "Persisted state at %s has unsupported version %s - trying to load it anyway...", + file_path, + version, + ) + + unpickler = jsonpickle.Unpickler(keys=True, safe=True, on_missing="error") + return unpickler.restore(envelope[DATA_KEY]) diff --git a/src/localstack_persist/utils.py b/src/localstack_persist/utils.py new file mode 100644 index 0000000..993a7f4 --- /dev/null +++ b/src/localstack_persist/utils.py @@ -0,0 +1,13 @@ +from collections.abc import Callable + + +def once(f: Callable[[], None]) -> Callable[[], None]: + has_run = False + + def wrapper(): + nonlocal has_run + if not has_run: + has_run = True + return f() + + return wrapper diff --git a/src/localstack_persist/visitors.py b/src/localstack_persist/visitors.py index 5aab504..5cf7911 100644 --- a/src/localstack_persist/visitors.py +++ b/src/localstack_persist/visitors.py @@ -3,7 +3,6 @@ import shutil from typing import Dict, Optional, Any, TypeAlias -import jsonpickle import logging import localstack.config @@ -20,22 +19,18 @@ from moto.s3.models import s3_backends from .config import BASE_DIR +from .serialization.jsonpickle.serializer import JsonPickleSerializer -JsonSerializableState: TypeAlias = BackendDict | AccountRegionBundle +SerializableState: TypeAlias = BackendDict | AccountRegionBundle logging.getLogger("watchdog").setLevel(logging.INFO) LOG = logging.getLogger(__name__) -# Track version for future handling of backward (or forward) incompatible changes. -# This is the "serialisation format" version, which is different to the localstack-persist version. -SER_VERSION_KEY = "v" -SER_VERSION = 1 - -DATA_KEY = "data" +serializer = JsonPickleSerializer() def get_json_file_path( - state_container: JsonSerializableState, + state_container: SerializableState, ): file_name = "backend" if isinstance(state_container, BackendDict) else "store" @@ -75,7 +70,7 @@ def __init__(self, service_name: str) -> None: def visit(self, state_container: StateContainer): if isinstance(state_container, BackendDict | AccountRegionBundle): - self._load_json(state_container) + self._load_state(state_container) elif isinstance(state_container, AssetDirectory): if state_container.path.startswith(BASE_DIR): # nothing to do - assets are read directly from the volume @@ -94,60 +89,49 @@ def visit(self, state_container: StateContainer): else: LOG.warning("Unexpected state_container type: %s", type(state_container)) - def _load_json(self, state_container: JsonSerializableState): + def _load_state(self, state_container: SerializableState): + state_container_type = state_type(state_container) + file_path = get_json_file_path(state_container) if not os.path.isfile(file_path): return - with open(file_path) as file: - envelope: dict = json.load(file) - - version = envelope.get(SER_VERSION_KEY, None) - if version != SER_VERSION: - LOG.warning( - "Persisted state at %s has unsupported version %s - trying to load it anyway...", - file_path, - version, - ) + deserialized = serializer.deserialize(file_path) - unpickler = jsonpickle.Unpickler(keys=True, safe=True, on_missing="error") - deserialised = unpickler.restore(envelope[DATA_KEY]) - - state_container_type = state_type(state_container) - deserialised_type = state_type(deserialised) + deserialized_type = state_type(deserialized) if ( state_container_type == AccountRegionBundle[V3S3Store] - and deserialised_type == AccountRegionBundle[LegacyS3Store] + and deserialized_type == AccountRegionBundle[LegacyS3Store] ): try: from .s3.migrate_to_v3 import migrate_to_v3 LOG.info("Migrating S3 state to V3 provider...") - self._load_json(s3_backends) - deserialised = migrate_to_v3(s3_backends) + self._load_state(s3_backends) + deserialized = migrate_to_v3(s3_backends) except: LOG.exception("Error migrating S3 state to V3 provider") return - elif not are_same_type(state_container_type, deserialised_type): + elif not are_same_type(state_container_type, deserialized_type): LOG.warning( "Unexpected deserialised state_container type: %s, expected %s", - deserialised_type, + deserialized_type, state_container_type, ) return # Set Processing because after loading state, it will take some time for opensearch/elasticsearch to start. - if deserialised_type == AccountRegionBundle[OpenSearchStore]: - for region_bundle in deserialised.values(): # type: ignore + if deserialized_type == AccountRegionBundle[OpenSearchStore]: + for region_bundle in deserialized.values(): # type: ignore store: OpenSearchStore for store in region_bundle.values(): for domain in store.opensearch_domains.values(): domain["Processing"] = True - if isinstance(state_container, dict) and isinstance(deserialised, dict): - state_container.update(deserialised) - state_container.__dict__.update(deserialised.__dict__) + if isinstance(state_container, dict) and isinstance(deserialized, dict): + state_container.update(deserialized) + state_container.__dict__.update(deserialized.__dict__) class SaveStateVisitor(StateVisitor): @@ -159,7 +143,7 @@ def __init__(self, service_name: str) -> None: def visit(self, state_container: StateContainer): if isinstance(state_container, BackendDict | AccountRegionBundle): - self._save_json(state_container) + self._save_state(state_container) elif isinstance(state_container, AssetDirectory): if state_container.path.startswith(BASE_DIR): # nothing to do - assets are written directly to the volume @@ -173,17 +157,11 @@ def visit(self, state_container: StateContainer): else: LOG.warning("Unexpected state_container type: %s", type(state_container)) - def _save_json(self, state_container: JsonSerializableState): + def _save_state(self, state_container: SerializableState): file_path = get_json_file_path(state_container) - pickler = jsonpickle.Pickler(keys=True, warn=True) - flattened = pickler.flatten(state_container) - - envelope = {SER_VERSION_KEY: SER_VERSION, DATA_KEY: flattened} - os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, "w") as file: - for chunk in self.json_encoder.iterencode(envelope): - file.write(chunk) + + serializer.serialize(file_path, state_container) @staticmethod def _sync_directories(src: str, dst: str):