Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposed solution to hashable_uniq_dict key issue on windows. #13

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions torchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
LuaFunction = namedtuple('LuaFunction',
['size', 'dumped', 'upvalues'])

class mycontainer():
def __init__(self, val):
self.val = val
def __hash__(self):
return id(self.val)
def __eq__(self, other):
return id(self.val) == id(other.val)
def __ne__(self, other):
return id(self.val) != id(other.val)

class hashable_uniq_dict(dict):
"""
Expand All @@ -60,27 +69,37 @@ class hashable_uniq_dict(dict):
This way, dicts can be keys of other dicts.
"""

def __iter__(self):
return iter(self.keys())

def __getitem__(self, k):
for _k,v in self.items():
if str(_k) == str(k):
return v

def __setitem__(self, k, v):
dict.__setitem__(self, mycontainer(k), v)

def items(self):
return [(k.val, v) for k,v in dict.items(self)]

def keys(self):
return [k.val for k in dict.keys(self)]

def values(self):
return [v for v in dict.values(self)]

def __hash__(self):
return id(self)

def __getattr__(self, key):
if key in self:
return self[key]
if isinstance(key, (str, bytes)):
return self.get(key.encode('utf8'))

def __eq__(self, other):
return id(self) == id(other)

def __ne__(self, other):
return id(self) != id(other)

def _disabled_binop(self, other):
raise TypeError(
'hashable_uniq_dict does not support these comparisons')
__cmp__ = __ne__ = __le__ = __gt__ = __lt__ = _disabled_binop


class TorchObject(object):
"""
Simple torch object, used by `add_trivial_class_reader`.
Expand All @@ -97,16 +116,16 @@ def __init__(self, typename, obj=None, version_number=0):
self._version_number = version_number

def __getattr__(self, k):
if k in self._obj:
if k in self._obj.keys():
return self._obj[k]
if isinstance(k, (str, bytes)):
return self._obj.get(k.encode('utf8'))

return self._obj[k.encode('utf8')]
def __getitem__(self, k):
if k in self._obj:
if k in self._obj.keys():
return self._obj[k]
if isinstance(k, (str, bytes)):
return self._obj.get(k.encode('utf8'))
return self._obj[k.encode('utf8')]

def torch_typename(self):
return self._typename
Expand All @@ -118,7 +137,7 @@ def __str__(self):
return repr(self)

def __dir__(self):
keys = list(self._obj.keys())
keys = self._obj.keys()
keys.append('torch_typename')
return keys

Expand Down