diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2283b433bd..59f04b7c36 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4565,7 +4565,10 @@ def _visitname(self, name: str, node: ast.AST): rname = self.scope_vars[name] if rname in self.scope_arrays: rng = subsets.Range.from_array(self.scope_arrays[rname]) - rname, _ = self._add_read_access(rname, rng, node) + if isinstance(node.ctx, ast.Store): + rname, _ = self._add_write_access(rname, rng, node) + else: + rname, _ = self._add_read_access(rname, rng, node) return rname #### Visitors that return arrays @@ -4900,6 +4903,8 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: ### Subscript (slicing) handling def visit_Subscript(self, node: ast.Subscript, inference: bool = False): + is_read: bool = not isinstance(node.ctx, ast.Store) + if self.nested: defined_vars = {**self.variables, **self.scope_vars} @@ -4926,13 +4931,19 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): if inference: rng.offset(rng, True) return self.sdfg.arrays[true_name].dtype, rng.size() - new_name, new_rng = self._add_read_access(name, rng, node) + if is_read: + new_name, new_rng = self._add_read_access(name, rng, node) + else: + new_name, new_rng = self._add_write_access(name, rng, node) new_arr = self.sdfg.arrays[new_name] full_rng = subsets.Range.from_array(new_arr) if new_rng.ranges == full_rng.ranges: return new_name else: - new_name, _ = self.make_slice(new_name, new_rng) + if is_read: + new_name, _ = self.make_slice(new_name, new_rng) + else: + raise NotImplementedError('Cannot slice a write access') return new_name # Obtain array/tuple @@ -4973,8 +4984,11 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): rng = expr.subset rng.offset(rng, True) return self.sdfg.arrays[array].dtype, rng.size() - - return self._add_read_slice(array, node, expr) + + if is_read: + return self._add_read_slice(array, node, expr) + else: + raise NotImplementedError('Write slicing not implemented') def _visit_ast_or_value(self, node: ast.AST) -> Any: result = self.visit(node)