Skip to content

Commit

Permalink
Merge pull request #281 from arshad-ml/fix/handle-freeze-defaultdict
Browse files Browse the repository at this point in the history
Fix: Enable freezing collections.defaultdict objects
  • Loading branch information
tobgu authored Oct 16, 2023
2 parents b091106 + 1b783d1 commit ae193f2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pyrsistent/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
from functools import wraps
from pyrsistent._pmap import PMap, pmap
from pyrsistent._pset import PSet, pset
Expand All @@ -10,6 +11,7 @@ def freeze(o, strict=True):
- list is converted to pvector, recursively
- dict is converted to pmap, recursively on values (but not keys)
- defaultdict is converted to pmap, recursively on values (but not keys)
- set is converted to pset, but not recursively
- tuple is converted to tuple, recursively.
Expand All @@ -33,6 +35,8 @@ def freeze(o, strict=True):
typ = type(o)
if typ is dict or (strict and isinstance(o, PMap)):
return pmap({k: freeze(v, strict) for k, v in o.items()})
if typ is collections.defaultdict or (strict and isinstance(o, PMap)):
return pmap({k: freeze(v, strict) for k, v in o.items()})
if typ is list or (strict and isinstance(o, PVector)):
curried_freeze = lambda x: freeze(x, strict)
return pvector(map(curried_freeze, o))
Expand Down
16 changes: 15 additions & 1 deletion tests/freeze_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Tests for freeze and thaw."""

import collections
from pyrsistent import v, m, s, freeze, thaw, PRecord, field, mutant


Expand All @@ -17,6 +17,13 @@ def test_freeze_dict():
assert result == m(a='b')
assert type(freeze({'a': 'b'})) is type(m())

def test_freeze_defaultdict():
test_dict = collections.defaultdict(dict)
test_dict['a'] = 'b'
result = freeze(test_dict)
assert result == m(a='b')
assert type(freeze({'a': 'b'})) is type(m())

def test_freeze_set():
result = freeze(set([1, 2, 3]))
assert result == s(1, 2, 3)
Expand All @@ -27,6 +34,13 @@ def test_freeze_recurse_in_dictionary_values():
assert result == m(a=v(1))
assert type(result['a']) is type(v())

def test_freeze_recurse_in_defaultdict_values():
test_dict = collections.defaultdict(dict)
test_dict['a'] = [1]
result = freeze(test_dict)
assert result == m(a=v(1))
assert type(result['a']) is type(v())

def test_freeze_recurse_in_pmap_values():
input = {'a': m(b={'c': 1})}
result = freeze(input)
Expand Down

0 comments on commit ae193f2

Please sign in to comment.