Skip to content

Commit

Permalink
Refactor jsonpickle logic out of visitors
Browse files Browse the repository at this point in the history
  • Loading branch information
GREsau committed Nov 25, 2023
1 parent e638d71 commit cef4d70
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 75 deletions.
2 changes: 0 additions & 2 deletions src/localstack_persist/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from localstack.runtime import hooks

from .state import STATE_TRACKER
from .jsonpickle import register_handlers

LOG = logging.getLogger(__name__)


@hooks.on_infra_start(priority=1)
def on_infra_start():
register_handlers()
STATE_TRACKER.load_all_services_state()
STATE_TRACKER.start()

Expand Down
14 changes: 2 additions & 12 deletions src/localstack_persist/prepare_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
import sys
from localstack.services.plugins import SERVICE_PLUGINS

from .config import BASE_DIR
from .utils import once


def prepare_service(service_name: str):
Expand All @@ -13,18 +15,6 @@ def prepare_service(service_name: str):
prepare_s3()


def once(f: Callable):
has_run = False

def wrapper(*args, **kwargs):
nonlocal has_run
if not has_run:
has_run = True
return f(*args, **kwargs)

return wrapper


@once
def prepare_lambda():
# Define localstack.services.awslambda as a backward-compatible alias for localstack.services.lambda_
Expand Down
10 changes: 4 additions & 6 deletions src/localstack_persist/s3/migrate_ephemeral_object_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
from typing import cast
import jsonpickle

Expand All @@ -9,7 +8,9 @@
LockedSpooledTemporaryFile,
)
from localstack.utils.files import mkdir

from .storage import PersistedS3ObjectStore
from ..serialization.jsonpickle.serializer import JsonPickleSerializer


class LockedSpooledTemporaryFileHandler(jsonpickle.handlers.BaseHandler):
Expand All @@ -35,11 +36,8 @@ def __init__(self, id: str) -> None:

def migrate_ephemeral_object_store(file_path: str, store: PersistedS3ObjectStore):
jsonpickle.register(LockedSpooledTemporaryFile, LockedSpooledTemporaryFileHandler)
with open(file_path) as file:
envelope: dict = json.load(file)

unpickler = jsonpickle.Unpickler(keys=True, safe=True)
ephemeral_store = cast(EphemeralS3ObjectStore, unpickler.restore(envelope["data"]))
serializer = JsonPickleSerializer()
ephemeral_store: EphemeralS3ObjectStore = serializer.deserialize(file_path)

if not ephemeral_store._filesystem:
# create the root directory to avoid trying to re-migrate empty store again in future
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
from jsonpickle.handlers import DatetimeHandler as DefaultDatetimeHandler
from moto.acm.models import CertBundle

from localstack_persist.utils import once


@once
def register_handlers():
    """Register all custom jsonpickle handlers defined in this module.

    Wrapped in @once so repeated calls (e.g. from both serialize and
    deserialize) only perform the registration a single time.
    """
    # Map each handler class to the type(s) it should be registered for,
    # in the same order the registrations were originally performed.
    handler_registrations = {
        CertBundleHandler: [CertBundle],
        ConditionHandler: [Condition],
        PriorityQueueHandler: [PriorityQueue],
        DatetimeHandler: [datetime.datetime, datetime.date, datetime.time],
    }
    for handler_cls, handled_types in handler_registrations.items():
        for handled_type in handled_types:
            handler_cls.handles(handled_type)


class ConditionHandler(jsonpickle.handlers.BaseHandler):
def flatten(self, obj, data: dict):
Expand Down Expand Up @@ -83,12 +95,3 @@ def restore(self, data: dict):
raise TypeError("DatetimeHandler: unexpected object type " + cls_name)

return cls.fromisoformat(data["isoformat"])


def register_handlers():
CertBundleHandler.handles(CertBundle)
ConditionHandler.handles(Condition)
PriorityQueueHandler.handles(PriorityQueue)
DatetimeHandler.handles(datetime.datetime)
DatetimeHandler.handles(datetime.date)
DatetimeHandler.handles(datetime.time)
48 changes: 48 additions & 0 deletions src/localstack_persist/serialization/jsonpickle/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
import logging
from typing import Any
import jsonpickle.handlers
import jsonpickle.tags

from .handlers import register_handlers

# Track version for future handling of backward (or forward) incompatible changes.
# This is the "serialisation format" version, which is different from the localstack-persist version.
SER_VERSION_KEY = "v"
SER_VERSION = 1

DATA_KEY = "data"

LOG = logging.getLogger(__name__)


class JsonPickleSerializer:
    """Serializes arbitrary state to/from JSON files via jsonpickle.

    Payloads are wrapped in a small versioned envelope so that future
    incompatible format changes can be detected on load.
    """

    # Shared compact encoder; circular-reference checking is disabled
    # because jsonpickle's flatten step already produces acyclic data.
    _json_encoder = json.JSONEncoder(check_circular=False, separators=(",", ":"))

    def serialize(self, file_path: str, data: Any):
        """Flatten *data* with jsonpickle and write the versioned envelope to *file_path*."""
        register_handlers()

        flattened = jsonpickle.Pickler(keys=True, warn=True).flatten(data)
        envelope = {SER_VERSION_KEY: SER_VERSION, DATA_KEY: flattened}

        with open(file_path, "w") as file:
            # Stream encoded chunks to the file instead of building one big string.
            file.writelines(self._json_encoder.iterencode(envelope))

    def deserialize(self, file_path: str) -> Any:
        """Read the envelope at *file_path* and restore the pickled payload."""
        register_handlers()

        with open(file_path) as file:
            envelope: dict = json.load(file)

        found_version = envelope.get(SER_VERSION_KEY)
        if found_version != SER_VERSION:
            # Best-effort: warn but still attempt to restore the data.
            LOG.warning(
                "Persisted state at %s has unsupported version %s - trying to load it anyway...",
                file_path,
                found_version,
            )

        unpickler = jsonpickle.Unpickler(keys=True, safe=True, on_missing="error")
        return unpickler.restore(envelope[DATA_KEY])
13 changes: 13 additions & 0 deletions src/localstack_persist/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import functools
from collections.abc import Callable


def once(f: Callable[[], None]) -> Callable[[], None]:
    """Decorate a zero-argument function so its body executes at most once.

    The first call invokes *f*; every later call is a silent no-op and
    returns None. The guard flag is set *before* invoking *f*, so a call
    that raises still counts as the one run.
    """
    has_run = False

    # functools.wraps preserves f's __name__/__doc__ on the wrapper,
    # which the bare closure previously discarded.
    @functools.wraps(f)
    def wrapper() -> None:
        nonlocal has_run
        if not has_run:
            has_run = True
            return f()

    return wrapper
70 changes: 24 additions & 46 deletions src/localstack_persist/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import shutil
from typing import Dict, Optional, Any, TypeAlias

import jsonpickle
import logging

import localstack.config
Expand All @@ -20,22 +19,18 @@
from moto.s3.models import s3_backends

from .config import BASE_DIR
from .serialization.jsonpickle.serializer import JsonPickleSerializer

JsonSerializableState: TypeAlias = BackendDict | AccountRegionBundle
SerializableState: TypeAlias = BackendDict | AccountRegionBundle

logging.getLogger("watchdog").setLevel(logging.INFO)
LOG = logging.getLogger(__name__)

# Track version for future handling of backward (or forward) incompatible changes.
# This is the "serialisation format" version, which is different to the localstack-persist version.
SER_VERSION_KEY = "v"
SER_VERSION = 1

DATA_KEY = "data"
serializer = JsonPickleSerializer()


def get_json_file_path(
state_container: JsonSerializableState,
state_container: SerializableState,
):
file_name = "backend" if isinstance(state_container, BackendDict) else "store"

Expand Down Expand Up @@ -75,7 +70,7 @@ def __init__(self, service_name: str) -> None:

def visit(self, state_container: StateContainer):
if isinstance(state_container, BackendDict | AccountRegionBundle):
self._load_json(state_container)
self._load_state(state_container)
elif isinstance(state_container, AssetDirectory):
if state_container.path.startswith(BASE_DIR):
# nothing to do - assets are read directly from the volume
Expand All @@ -94,60 +89,49 @@ def visit(self, state_container: StateContainer):
else:
LOG.warning("Unexpected state_container type: %s", type(state_container))

def _load_json(self, state_container: JsonSerializableState):
def _load_state(self, state_container: SerializableState):
state_container_type = state_type(state_container)

file_path = get_json_file_path(state_container)
if not os.path.isfile(file_path):
return

with open(file_path) as file:
envelope: dict = json.load(file)

version = envelope.get(SER_VERSION_KEY, None)
if version != SER_VERSION:
LOG.warning(
"Persisted state at %s has unsupported version %s - trying to load it anyway...",
file_path,
version,
)
deserialized = serializer.deserialize(file_path)

unpickler = jsonpickle.Unpickler(keys=True, safe=True, on_missing="error")
deserialised = unpickler.restore(envelope[DATA_KEY])

state_container_type = state_type(state_container)
deserialised_type = state_type(deserialised)
deserialized_type = state_type(deserialized)

if (
state_container_type == AccountRegionBundle[V3S3Store]
and deserialised_type == AccountRegionBundle[LegacyS3Store]
and deserialized_type == AccountRegionBundle[LegacyS3Store]
):
try:
from .s3.migrate_to_v3 import migrate_to_v3

LOG.info("Migrating S3 state to V3 provider...")
self._load_json(s3_backends)
deserialised = migrate_to_v3(s3_backends)
self._load_state(s3_backends)
deserialized = migrate_to_v3(s3_backends)
except:
LOG.exception("Error migrating S3 state to V3 provider")
return
elif not are_same_type(state_container_type, deserialised_type):
elif not are_same_type(state_container_type, deserialized_type):
LOG.warning(
"Unexpected deserialised state_container type: %s, expected %s",
deserialised_type,
deserialized_type,
state_container_type,
)
return

# Set Processing because after loading state, it will take some time for opensearch/elasticsearch to start.
if deserialised_type == AccountRegionBundle[OpenSearchStore]:
for region_bundle in deserialised.values(): # type: ignore
if deserialized_type == AccountRegionBundle[OpenSearchStore]:
for region_bundle in deserialized.values(): # type: ignore
store: OpenSearchStore
for store in region_bundle.values():
for domain in store.opensearch_domains.values():
domain["Processing"] = True

if isinstance(state_container, dict) and isinstance(deserialised, dict):
state_container.update(deserialised)
state_container.__dict__.update(deserialised.__dict__)
if isinstance(state_container, dict) and isinstance(deserialized, dict):
state_container.update(deserialized)
state_container.__dict__.update(deserialized.__dict__)


class SaveStateVisitor(StateVisitor):
Expand All @@ -159,7 +143,7 @@ def __init__(self, service_name: str) -> None:

def visit(self, state_container: StateContainer):
if isinstance(state_container, BackendDict | AccountRegionBundle):
self._save_json(state_container)
self._save_state(state_container)
elif isinstance(state_container, AssetDirectory):
if state_container.path.startswith(BASE_DIR):
# nothing to do - assets are written directly to the volume
Expand All @@ -173,17 +157,11 @@ def visit(self, state_container: StateContainer):
else:
LOG.warning("Unexpected state_container type: %s", type(state_container))

def _save_json(self, state_container: JsonSerializableState):
def _save_state(self, state_container: SerializableState):
file_path = get_json_file_path(state_container)
pickler = jsonpickle.Pickler(keys=True, warn=True)
flattened = pickler.flatten(state_container)

envelope = {SER_VERSION_KEY: SER_VERSION, DATA_KEY: flattened}

os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as file:
for chunk in self.json_encoder.iterencode(envelope):
file.write(chunk)

serializer.serialize(file_path, state_container)

@staticmethod
def _sync_directories(src: str, dst: str):
Expand Down

0 comments on commit cef4d70

Please sign in to comment.