Skip to content

Commit

Permalink
Merge pull request #1297 from spcl/refactor-variable-visitors
Browse files Browse the repository at this point in the history
Refactoring Frontend's Variable-Access Visitors
  • Loading branch information
tbennun authored Jul 17, 2023
2 parents 4cc7177 + 026d45f commit d7c26ef
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,7 +4565,10 @@ def _visitname(self, name: str, node: ast.AST):
rname = self.scope_vars[name]
if rname in self.scope_arrays:
rng = subsets.Range.from_array(self.scope_arrays[rname])
rname, _ = self._add_read_access(rname, rng, node)
if isinstance(node.ctx, ast.Store):
rname, _ = self._add_write_access(rname, rng, node)
else:
rname, _ = self._add_read_access(rname, rng, node)
return rname

#### Visitors that return arrays
Expand Down Expand Up @@ -4900,6 +4903,8 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]:
### Subscript (slicing) handling
def visit_Subscript(self, node: ast.Subscript, inference: bool = False):

is_read: bool = not isinstance(node.ctx, ast.Store)

if self.nested:

defined_vars = {**self.variables, **self.scope_vars}
Expand All @@ -4926,13 +4931,19 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False):
if inference:
rng.offset(rng, True)
return self.sdfg.arrays[true_name].dtype, rng.size()
new_name, new_rng = self._add_read_access(name, rng, node)
if is_read:
new_name, new_rng = self._add_read_access(name, rng, node)
else:
new_name, new_rng = self._add_write_access(name, rng, node)
new_arr = self.sdfg.arrays[new_name]
full_rng = subsets.Range.from_array(new_arr)
if new_rng.ranges == full_rng.ranges:
return new_name
else:
new_name, _ = self.make_slice(new_name, new_rng)
if is_read:
new_name, _ = self.make_slice(new_name, new_rng)
else:
raise NotImplementedError('Cannot slice a write access')
return new_name

# Obtain array/tuple
Expand Down Expand Up @@ -4973,8 +4984,11 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False):
rng = expr.subset
rng.offset(rng, True)
return self.sdfg.arrays[array].dtype, rng.size()

return self._add_read_slice(array, node, expr)

if is_read:
return self._add_read_slice(array, node, expr)
else:
raise NotImplementedError('Write slicing not implemented')

def _visit_ast_or_value(self, node: ast.AST) -> Any:
result = self.visit(node)
Expand Down

0 comments on commit d7c26ef

Please sign in to comment.