From e980a3c567d81be43ebaa0b640db92570ce2c3da Mon Sep 17 00:00:00 2001 From: Bradley Bauer Date: Wed, 13 Dec 2017 16:34:33 -0500 Subject: [PATCH] Fix hashable_uniq_dict key issue on windows. Unhashable items can not be stored as keys in hashable_uniq_dict. So to fix this, wrap all keys in hashable_uniq_dict with a hashable type. --- torchfile.py | 51 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/torchfile.py b/torchfile.py index 950725b..25f2efc 100644 --- a/torchfile.py +++ b/torchfile.py @@ -48,6 +48,15 @@ LuaFunction = namedtuple('LuaFunction', ['size', 'dumped', 'upvalues']) +class mycontainer(): + def __init__(self, val): + self.val = val + def __hash__(self): + return id(self.val) + def __eq__(self, other): + return id(self.val) == id(other.val) + def __ne__(self, other): + return id(self.val) != id(other.val) class hashable_uniq_dict(dict): """ @@ -60,27 +69,37 @@ class hashable_uniq_dict(dict): This way, dicts can be keys of other dicts. """ + def __iter__(self): + return iter(self.keys()) + + def __getitem__(self, k): + for _k,v in self.items(): + if str(_k) == str(k): + return v + + def __setitem__(self, k, v): + dict.__setitem__(self, mycontainer(k), v) + + def items(self): + return [(k.val, v) for k,v in dict.items(self)] + + def keys(self): + return [k.val for k in dict.keys(self)] + + def values(self): + return [v for v in dict.values(self)] + def __hash__(self): return id(self) - def __getattr__(self, key): - if key in self: - return self[key] - if isinstance(key, (str, bytes)): - return self.get(key.encode('utf8')) - def __eq__(self, other): return id(self) == id(other) - def __ne__(self, other): - return id(self) != id(other) - def _disabled_binop(self, other): raise TypeError( 'hashable_uniq_dict does not support these comparisons') __cmp__ = __ne__ = __le__ = __gt__ = __lt__ = _disabled_binop - class TorchObject(object): """ Simple torch object, used by `add_trivial_class_reader`. @@ -97,16 +116,16 @@ def __init__(self, typename, obj=None, version_number=0): self._version_number = version_number def __getattr__(self, k): - if k in self._obj: + if k in self._obj.keys(): return self._obj[k] if isinstance(k, (str, bytes)): - return self._obj.get(k.encode('utf8')) - + return self._obj[k.encode('utf8')] + def __getitem__(self, k): - if k in self._obj: + if k in self._obj.keys(): return self._obj[k] if isinstance(k, (str, bytes)): - return self._obj.get(k.encode('utf8')) + return self._obj[k.encode('utf8')] def torch_typename(self): return self._typename @@ -118,7 +137,7 @@ def __str__(self): return repr(self) def __dir__(self): - keys = list(self._obj.keys()) + keys = self._obj.keys() keys.append('torch_typename') return keys