diff --git a/torchfile.py b/torchfile.py index 950725b..25f2efc 100644 --- a/torchfile.py +++ b/torchfile.py @@ -48,6 +48,15 @@ LuaFunction = namedtuple('LuaFunction', ['size', 'dumped', 'upvalues']) +class mycontainer(): + def __init__(self, val): + self.val = val + def __hash__(self): + return id(self.val) + def __eq__(self, other): + return id(self.val) == id(other.val) + def __ne__(self, other): + return id(self.val) != id(other.val) class hashable_uniq_dict(dict): """ @@ -60,27 +69,37 @@ class hashable_uniq_dict(dict): This way, dicts can be keys of other dicts. """ + def __iter__(self): + return iter(self.keys()) + + def __getitem__(self, k): + for _k,v in self.items(): + if str(_k) == str(k): + return v + + def __setitem__(self, k, v): + dict.__setitem__(self, mycontainer(k), v) + + def items(self): + return [(k.val, v) for k,v in dict.items(self)] + + def keys(self): + return [k.val for k in dict.keys(self)] + + def values(self): + return [v for v in dict.values(self)] + def __hash__(self): return id(self) - def __getattr__(self, key): - if key in self: - return self[key] - if isinstance(key, (str, bytes)): - return self.get(key.encode('utf8')) - def __eq__(self, other): return id(self) == id(other) - def __ne__(self, other): - return id(self) != id(other) - def _disabled_binop(self, other): raise TypeError( 'hashable_uniq_dict does not support these comparisons') __cmp__ = __ne__ = __le__ = __gt__ = __lt__ = _disabled_binop - class TorchObject(object): """ Simple torch object, used by `add_trivial_class_reader`. @@ -97,16 +116,16 @@ def __init__(self, typename, obj=None, version_number=0): self._version_number = version_number def __getattr__(self, k): - if k in self._obj: + if k in self._obj.keys(): return self._obj[k] if isinstance(k, (str, bytes)): - return self._obj.get(k.encode('utf8')) - + return self._obj[k.encode('utf8')] + def __getitem__(self, k): - if k in self._obj: + if k in self._obj.keys(): return self._obj[k] if isinstance(k, (str, bytes)): - return self._obj.get(k.encode('utf8')) + return self._obj[k.encode('utf8')] def torch_typename(self): return self._typename @@ -118,7 +137,7 @@ def __str__(self): return repr(self) def __dir__(self): - keys = list(self._obj.keys()) + keys = self._obj.keys() keys.append('torch_typename') return keys