Skip to content

Commit

Permalink
Merge branch 'master' into free_symbols_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Jul 14, 2023
2 parents 4ec4451 + 28b93bd commit 4638a0a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
15 changes: 12 additions & 3 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,11 +1082,20 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT
]

if isinstance(visited_slice, ast.Tuple):
if len(strides) != len(visited_slice.elts):
# If slice is multi-dimensional and writes to array with more than 1 elements, then:
# - Assume this is indirection (?)
# - Soft-squeeze the slice (remove unit-modes) to match the treatment of the strides above.
if target not in self.constants:
desc = self.sdfg.arrays[dname]
if isinstance(desc, data.Array) and data._prod(desc.shape) != 1:
elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1]
else:
elts = visited_slice.elts
if len(strides) != len(elts):
raise SyntaxError('Invalid number of dimensions in expression (expected %d, '
'got %d)' % (len(strides), len(visited_slice.elts)))
'got %d)' % (len(strides), len(elts)))

return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(visited_slice.elts, strides))
return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(elts, strides))

if len(strides) != 1:
raise SyntaxError('Missing dimensions in expression (expected %d, got one)' % len(strides))
Expand Down
22 changes: 22 additions & 0 deletions tests/python_frontend/indirections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,27 @@ def test_spmv():
assert (np.allclose(y, ref))


def test_indirection_size_1():

def compute_index(scal: dc.int32[5]):
result = 0
with dace.tasklet:
s << scal
r >> result
r = s[1] + 1 - 1
return result

@dc.program
def tester(a: dc.float64[1, 2, 3], scal: dc.int32[5]):
ind = compute_index(scal)
a[0, ind, 0] = 1

arr = np.random.rand(1, 2, 3)
scal = np.array([1, 1, 1, 1, 1], dtype=np.int32)
tester(arr, scal)
assert arr[0, 1, 0] == 1


if __name__ == "__main__":
test_indirection_scalar()
test_indirection_scalar_assign()
Expand All @@ -412,3 +433,4 @@ def test_spmv():
test_indirection_array_nested()
test_indirection_array_nested_nsdfg()
test_spmv()
test_indirection_size_1()

0 comments on commit 4638a0a

Please sign in to comment.