From af83306e0425851a750a8f74a5110fe197fdecfe Mon Sep 17 00:00:00 2001 From: Ben Kurtovic Date: Tue, 5 Sep 2023 00:44:01 -0400 Subject: [PATCH] Fix pickling SmartLists (fixes #289) --- CHANGELOG | 1 + docs/changelog.rst | 2 ++ src/mwparserfromhell/smart_list/list_proxy.py | 14 +++++++++- src/mwparserfromhell/smart_list/smart_list.py | 18 +++++++----- src/mwparserfromhell/smart_list/utils.py | 4 ++- src/mwparserfromhell/utils.py | 20 ++++++++++--- tests/test_smart_list.py | 28 ++++++++++++++++++- tests/test_template.py | 2 +- tests/test_wikicode.py | 11 +++++++- 9 files changed, 84 insertions(+), 16 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 19641b88..f280773c 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -4,6 +4,7 @@ v0.6.5 (unreleased): - Added support for Python 3.11. - Fixed parsing of leading zeros in named HTML entities. (#288) - Fixed memory leak parsing tags. (#303) +- Fixed pickling SmartList objects. (#289) v0.6.4 (released February 14, 2022): diff --git a/docs/changelog.rst b/docs/changelog.rst index e2af82de..116623c6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,8 @@ Unreleased (`#288 `_) - Fixed memory leak parsing tags. (`#303 `_) +- Fixed pickling SmartList objects. + (`#289 `_) v0.6.4 ------ diff --git a/src/mwparserfromhell/smart_list/list_proxy.py b/src/mwparserfromhell/smart_list/list_proxy.py index 5132def3..f726754f 100644 --- a/src/mwparserfromhell/smart_list/list_proxy.py +++ b/src/mwparserfromhell/smart_list/list_proxy.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2020 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # Copyright (C) 2019-2020 Yuri Astrakhan # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -19,6 +19,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import weakref + from .utils import _SliceNormalizerMixIn, inheritdoc @@ -30,11 +32,21 @@ class ListProxy(_SliceNormalizerMixIn, list): it builds it dynamically using the :meth:`_render` method. """ + __slots__ = ("__weakref__", "_parent", "_sliceinfo") + def __init__(self, parent, sliceinfo): super().__init__() self._parent = parent self._sliceinfo = sliceinfo + def __reduce_ex__(self, protocol: int) -> tuple: + return (ListProxy, (self._parent, self._sliceinfo), ()) + + def __setstate__(self, state: tuple) -> None: + # Reregister with the parent + child_ref = weakref.ref(self, self._parent._delete_child) + self._parent._children[id(child_ref)] = (child_ref, self._sliceinfo) + def __repr__(self): return repr(self._render()) diff --git a/src/mwparserfromhell/smart_list/smart_list.py b/src/mwparserfromhell/smart_list/smart_list.py index e2fd87f8..dbd09ca0 100644 --- a/src/mwparserfromhell/smart_list/smart_list.py +++ b/src/mwparserfromhell/smart_list/smart_list.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2020 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # Copyright (C) 2019-2020 Yuri Astrakhan # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -49,12 +49,16 @@ class SmartList(_SliceNormalizerMixIn, list): [0, 1, 2, 3, 4] """ - def __init__(self, iterable=None): - if iterable: - super().__init__(iterable) - else: - super().__init__() - self._children = {} + __slots__ = ("_children",) + + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + obj._children = {} + return obj + + def __reduce_ex__(self, protocol: int) -> tuple: + # Detach children when pickling + return (SmartList, (), None, iter(self)) def __getitem__(self, key): if not isinstance(key, slice): diff --git a/src/mwparserfromhell/smart_list/utils.py b/src/mwparserfromhell/smart_list/utils.py index 1a36d0b0..ec2d3886 100644 --- a/src/mwparserfromhell/smart_list/utils.py +++ b/src/mwparserfromhell/smart_list/utils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2016 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # Copyright (C) 2019-2020 Yuri Astrakhan # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -37,6 +37,8 @@ def inheritdoc(method): class _SliceNormalizerMixIn: """MixIn that provides a private method to normalize slices.""" + __slots__ = () + def _normalize_slice(self, key, clamp=False): """Return a slice equivalent to the input *key*, standardized.""" if key.start is None: diff --git a/src/mwparserfromhell/utils.py b/src/mwparserfromhell/utils.py index 0ed1d560..d049d713 100644 --- a/src/mwparserfromhell/utils.py +++ b/src/mwparserfromhell/utils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2020 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,10 +23,20 @@ users generally won't need stuff from here. """ +from __future__ import annotations + __all__ = ["parse_anything"] +import typing +from typing import Any + +if typing.TYPE_CHECKING: + from .wikicode import Wikicode + -def parse_anything(value, context=0, skip_style_tags=False): +def parse_anything( + value: Any, context: int = 0, *, skip_style_tags: bool = False +) -> Wikicode: """Return a :class:`.Wikicode` for *value*, allowing multiple types. This differs from :meth:`.Parser.parse` in that we accept more than just a @@ -58,11 +68,13 @@ def parse_anything(value, context=0, skip_style_tags=False): if value is None: return Wikicode(SmartList()) if hasattr(value, "read"): - return parse_anything(value.read(), context, skip_style_tags) + return parse_anything(value.read(), context, skip_style_tags=skip_style_tags) try: nodelist = SmartList() for item in value: - nodelist += parse_anything(item, context, skip_style_tags).nodes + nodelist += parse_anything( + item, context, skip_style_tags=skip_style_tags + ).nodes return Wikicode(nodelist) except TypeError as exc: error = ( diff --git a/tests/test_smart_list.py b/tests/test_smart_list.py index 54ac00cd..18c0d10b 100644 --- a/tests/test_smart_list.py +++ b/tests/test_smart_list.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2020 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,8 @@ Test cases for the SmartList class and its child, ListProxy. """ +import pickle + import pytest from mwparserfromhell.smart_list import SmartList @@ -432,3 +434,27 @@ def test_influence(): assert [6, 5, 2, 3, 4, 1] == parent assert [4, 3, 2] == child2 assert 0 == len(parent._children) + + +@pytest.mark.parametrize("protocol", range(pickle.HIGHEST_PROTOCOL + 1)) +def test_pickling(protocol: int): + """test SmartList objects behave properly when pickling""" + parent = SmartList([0, 1, 2, 3, 4, 5]) + enc = pickle.dumps(parent, protocol=protocol) + assert pickle.loads(enc) == parent + + child = parent[1:3] + assert len(parent._children) == 1 + assert list(parent._children.values())[0][0]() is child + enc = pickle.dumps(parent, protocol=protocol) + parent2 = pickle.loads(enc) + assert parent2 == parent + assert parent2._children == {} + + enc = pickle.dumps(child, protocol=protocol) + child2 = pickle.loads(enc) + assert child2 == child + assert child2._parent == parent + assert child2._parent is not parent + assert len(child2._parent._children) == 1 + assert list(child2._parent._children.values())[0][0]() is child2 diff --git a/tests/test_template.py b/tests/test_template.py index f0154a45..abb4e1ea 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -725,7 +725,7 @@ def test_formatting(): ), ] - for (original, expected) in tests: + for original, expected in tests: code = parse(original) template = code.filter_templates()[0] template.add("pop", "12345example ref") diff --git a/tests/test_wikicode.py b/tests/test_wikicode.py index ce624d7c..16c7ebce 100644 --- a/tests/test_wikicode.py +++ b/tests/test_wikicode.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012-2020 Ben Kurtovic +# Copyright (C) 2012-2023 Ben Kurtovic # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ from functools import partial import re +import pickle from types import GeneratorType import pytest @@ -60,6 +61,14 @@ def test_nodes(): code.__setattr__("nodes", object) +@pytest.mark.parametrize("protocol", range(pickle.HIGHEST_PROTOCOL + 1)) +def test_pickling(protocol: int): + """test Wikicode objects can be pickled""" + code = parse("Have a {{template}} and a [[page|link]]") + enc = pickle.dumps(code, protocol=protocol) + assert pickle.loads(enc) == code + + def test_get(): """test Wikicode.get()""" code = parse("Have a {{template}} and a [[page|link]]")