Skip to content

Commit

Permalink
Stop checking for input alias in Function.__call__
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 17, 2024
1 parent 266dee2 commit 2faae23
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 121 deletions.
81 changes: 12 additions & 69 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,41 +393,6 @@ def __init__(
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)

# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
potential_aliased_input_groups = []
for inp in maker.inputs:
# If the input is a shared variable, the memory region is under PyTensor control
# and can't be aliased.
if not (
isinstance(inp, In)
and inp.borrow
and not inp.shared
and hasattr(inp.variable.type, "may_share_memory")
):
continue

for group in potential_aliased_input_groups:
# If one is super of the other, that means one could be replaced by the other
if any(
inp.variable.type.is_super(other_inp.variable.type)
or other_inp.variable.type.is_super(inp.variable.type)
for other_inp in group
):
group.append(inp)
break
else: # no break
# Input makes a new group
potential_aliased_input_groups.append([inp])

# Potential aliased inputs are those that belong to the same group
self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
tuple(maker.inputs.index(inp) for inp in group)
for group in potential_aliased_input_groups
if len(group) > 1
)

# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
finder = {}
Expand Down Expand Up @@ -844,11 +809,18 @@ def __call__(self, *args, **kwargs):
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]

# Reinitialize each container's 'provided' counter
if self.trust_input:
# Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False):
arg_container.storage[0] = arg

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

else:
# Reinitialize each container's 'provided' counter
for arg_container in input_storage:
arg_container.provided = 0

Expand Down Expand Up @@ -899,39 +871,10 @@ def __call__(self, *args, **kwargs):
raise
arg_container.provided += 1

# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

if not self.trust_input:
# Collect aliased inputs among the storage space
for potential_group in self._potential_aliased_input_groups:
args_share_memory: list[list[int]] = []
for i in potential_group:
i_type = self.maker.inputs[i].variable.type
i_val = input_storage[i].storage[0]

# Check if value is aliased with any of the values in one of the groups
for j_group in args_share_memory:
if any(
i_type.may_share_memory(input_storage[j].storage[0], i_val)
for j in j_group
):
j_group.append(i)
break
else: # no break
# Create a new group
args_share_memory.append([i])

# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for i in group[1:]:
input_storage[i].storage[0] = copy.copy(
input_storage[i].storage[0]
)
# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg

# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
Expand Down
52 changes: 0 additions & 52 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,52 +752,6 @@ def test_default_values(self):
except TypeError:
assert funct(first=1) == x

def test_check_for_aliased_inputs(self):
b = np.random.random((5, 4))
s1 = shared(b)
s2 = shared(b)
x1 = vector()
x2 = vector(shape=(3,))
x3 = vector(shape=(1,))

# Assert cases we should not check for aliased inputs
for d in [
dict(outputs=[s1 + 1]),
dict(outputs=[s1 + 1, s2 + 3]),
dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(
inputs=[In(x1, mutable=True)], outputs=[x1 + 1], updates=[(s2, s2 + 3)]
),
dict(
inputs=[In(x2, mutable=True), In(x3, mutable=True)],
outputs=[x2 + 2, x3 + 3],
),
]:
if "inputs" not in d:
d["inputs"] = []
f = function(**d)
assert not f._potential_aliased_input_groups, d

# Assert cases we should check for aliased inputs
for d in [
dict(
inputs=[In(x1, mutable=True), In(x2, mutable=True)],
outputs=[x1 + 1, x2 + 2],
updates=[(s2, s2 + 3)],
),
dict(
inputs=[In(x1, mutable=True), In(x3, mutable=True)],
outputs=[x1 + 1, x3 + 3],
updates=[(s2, s2 + 3)],
),
]:
if "inputs" not in d:
d["inputs"] = []
f = function(**d)

assert f._potential_aliased_input_groups, d

def test_output_dictionary(self):
# Tests that function works when outputs is a dictionary

Expand Down Expand Up @@ -939,12 +893,6 @@ def test_deepcopy(self):
assert x not in g.container
assert x not in g.value
assert len(f.defaults) == len(g.defaults)
# Shared variable is the first input
assert (
f._potential_aliased_input_groups
== g._potential_aliased_input_groups
== ((1, 2),)
)
assert f.name == g.name
assert f.maker.fgraph.name == g.maker.fgraph.name
# print(f"{f.defaults = }")
Expand Down

0 comments on commit 2faae23

Please sign in to comment.