Skip to content

Commit

Permalink
Dim match_priority
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 19, 2021
1 parent 1fb6279 commit 22bdf2e
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
vocab=None,
dyn_size=None, dyn_size_ext=None,
undefined=False, generic=False, special=False,
match_priority=0,
derived_from_tag=None, derived_from_op=None,
batch=None, control_flow_ctx=None,
src_data=None, src_axis=None):
Expand All @@ -79,6 +80,8 @@ def __init__(self, kind=Types.Unspecified, description=None,
the behavior is to consider them as equal,
and assume that the chain of operations (e.g. padding + valid conv) results in the same dim.
:param Dim.Op|None derived_from_op:
:param int match_priority: when there is ambiguity between multiple dim tags, this can be used to resolve it.
E.g. for a square matrix used for a linear transformation, the reduce dim tag should have a higher priority.
:param BatchInfo|None batch: for batch-dim, or dynamic dims per batch
:param ControlFlowContext|None control_flow_ctx:
:param Data|None src_data:
Expand All @@ -95,6 +98,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
self.derived_from_op = derived_from_op
if derived_from_op and not derived_from_op.output:
derived_from_op.output = self
self.match_priority = match_priority
if src_data:
assert isinstance(src_data, Data) and isinstance(src_axis, int)
if not batch and dyn_size_ext:
Expand Down Expand Up @@ -170,18 +174,20 @@ def __deepcopy__(self, memo=None):
"""
return self

def copy(self, same_as_self, description=None, kind=None):
def copy(self, same_as_self=True, description=None, kind=None, match_priority=None):
"""
:param bool same_as_self:
:param str|None description: new description
:param Entity|None kind: if set, overwrites self.kind
:param int|None match_priority:
:return: copy, maybe as new kind. setting same_as to self
:rtype: Dim
"""
if not same_as_self:
assert description is not None, "%s copy with not same_as_self should have a new description" % self
tag = Dim(
kind=kind or self.kind, description=description or self.description,
match_priority=match_priority if match_priority is not None else self.match_priority,
dimension=self.dimension, dyn_size_ext=self.dyn_size_ext,
batch=self.batch,
src_data=self.src_data, src_axis=self.src_axis)
Expand Down Expand Up @@ -4366,6 +4372,9 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
# Once we have not guaranteed unique dim tags, multiple axes could match.
# https://github.com/rwth-i6/returnn/issues/632
dims = [i for (i, tag) in enumerate(self.dim_tags) if tag == axes]
if len(dims) > 1:
max_match_priority = max(self.dim_tags[i].match_priority for i in dims)
dims = [i for i in dims if self.dim_tags[i].match_priority == max_match_priority]
assert len(dims) == 1, "%s: matching dim %s must be unique" % (self, axes)
return dims
if isinstance(axes, int):
Expand Down

0 comments on commit 22bdf2e

Please sign in to comment.