Skip to content

Commit

Permalink
Merge branch 'master' into alternative-fix-for-1300
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 authored Jul 13, 2023
2 parents b61accc + 0a29384 commit 34510c4
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 8 deletions.
5 changes: 3 additions & 2 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ def get_generated_codeobjects(self):
// Create {backend} streams and events
for(int i = 0; i < {nstreams}; ++i) {{
DACE_GPU_CHECK({backend}StreamCreateWithFlags(&__state->gpu_context->streams[i], {backend}StreamNonBlocking));
DACE_GPU_CHECK({backend}StreamCreateWithFlags(&__state->gpu_context->internal_streams[i], {backend}StreamNonBlocking));
__state->gpu_context->streams[i] = __state->gpu_context->internal_streams[i]; // Allow for externals to modify streams
}}
for(int i = 0; i < {nevents}; ++i) {{
DACE_GPU_CHECK({backend}EventCreateWithFlags(&__state->gpu_context->events[i], {backend}EventDisableTiming));
Expand All @@ -398,7 +399,7 @@ def get_generated_codeobjects(self):
// Destroy {backend} streams and events
for(int i = 0; i < {nstreams}; ++i) {{
DACE_GPU_CHECK({backend}StreamDestroy(__state->gpu_context->streams[i]));
DACE_GPU_CHECK({backend}StreamDestroy(__state->gpu_context->internal_streams[i]));
}}
for(int i = 0; i < {nevents}; ++i) {{
DACE_GPU_CHECK({backend}EventDestroy(__state->gpu_context->events[i]));
Expand Down
10 changes: 9 additions & 1 deletion dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _create_einsum_internal(sdfg: SDFG,

if init_output is None:
init_output = (beta != 1.0)

if alpha is None:
alpha = 1.0
if beta is None:
Expand Down Expand Up @@ -373,6 +373,14 @@ def _create_einsum_internal(sdfg: SDFG,
strides['sCN'] = 1
strides['sCB'] = strides['sCM'] = strides['N']

# Transposed output, swap order
if strides['sCM'] == 1:
strides['sCM'], strides['sCN'] = strides['sCN'], strides['sCM']
strides['M'], strides['N'] = strides['N'], strides['M']
(strides['sAM'], strides['sAK'], strides['sAB'], strides['sBK'], strides['sBN'], strides['sBB']) = \
(strides['sBN'], strides['sBK'], strides['sBB'], strides['sAK'], strides['sAM'], strides['sAB'])
a, b = b, a

# Create nested SDFG for GEMM
nsdfg = create_batch_gemm_sdfg(dtype, strides, alpha, beta)

Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,8 @@ def _visit_test(self, node: ast.Expr):
# Visit test-condition
if not is_test_simple:
parsed_node = self.visit(node)
if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1:
parsed_node = parsed_node[0]
if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays:
datadesc = self.sdfg.arrays[parsed_node]
if isinstance(datadesc, data.Array):
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4370,6 +4370,8 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt
'outputs': ['__out'],
'code': "__out = dace.{}(__inp)".format(dtype.to_string())
}
if dtype in (dace.bool, dace.bool_):
impl['code'] = "__out = dace.bool_(__inp)"
tasklet_params = _set_tasklet_params(impl, [arg])

# Visitor input only needed when `has_where == True`.
Expand Down
2 changes: 2 additions & 0 deletions dace/runtime/include/dace/cuda/cudacommon.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ struct Context {
int num_streams;
int num_events;
gpuStream_t *streams;
gpuStream_t *internal_streams;
gpuEvent_t *events;
gpuError_t lasterror;
Context(int nstreams, int nevents)
: num_streams(nstreams), num_events(nevents), lasterror((gpuError_t)0) {
streams = new gpuStream_t[nstreams];
internal_streams = new gpuStream_t[nstreams];
events = new gpuEvent_t[nevents];
}
~Context() {
Expand Down
7 changes: 7 additions & 0 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,13 @@ def validate_state(state: 'dace.sdfg.SDFGState',
src_node = path[0].src
dst_node = path[-1].dst

# NestedSDFGs must connect to AccessNodes
if not e.data.is_empty():
if isinstance(src_node, nd.NestedSDFG) and not isinstance(dst_node, nd.AccessNode):
raise InvalidSDFGEdgeError("Nested SDFG source nodes must be AccessNodes", sdfg, state_id, eid)
if isinstance(dst_node, nd.NestedSDFG) and not isinstance(src_node, nd.AccessNode):
raise InvalidSDFGEdgeError("Nested SDFG destination nodes must be AccessNodes", sdfg, state_id, eid)

# Set up memlet-specific SDFG context
memlet_context = copy.copy(context)
for pe in path:
Expand Down
4 changes: 4 additions & 0 deletions dace/transformation/auto/auto_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ def make_transients_persistent(sdfg: SDFG,
not_persistent.add(dnode.data)
continue

if desc.lifetime == dtypes.AllocationLifetime.External:
not_persistent.add(dnode.data)
continue

persistent.add(dnode.data)

for aname in (persistent - not_persistent):
Expand Down
15 changes: 12 additions & 3 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,16 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
graph.node_id(edge.dst),
edge.dst_conn,
)
# Add intermediate memory between subgraphs. If a scalar,
# uses direct connection. If an array, adds a transient node
if edge.data.subset.num_elements() == 1:
# Add intermediate memory between subgraphs.
# If a scalar, uses direct connection. If an array, adds a transient node.
# NOTE: If any of the src/dst nodes is a nested SDFG, treat it as an array.
is_scalar = edge.data.subset.num_elements() == 1
accesses = (
[graph.memlet_path(e1)[0].src for e0 in graph.in_edges(access_node) for e1 in graph.memlet_tree(e0)] +
[graph.memlet_path(e1)[-1].dst for e0 in graph.out_edges(access_node) for e1 in graph.memlet_tree(e0)])
if any(isinstance(a, nodes.NestedSDFG) for a in accesses):
is_scalar = False
if is_scalar:
local_name, _ = sdfg.add_scalar(
local_name,
dtype=access_node.desc(graph).dtype,
Expand Down Expand Up @@ -520,5 +527,7 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
# Modify data and memlets on all surrounding edges to match array
for neighbor in graph.all_edges(local_node):
for e in graph.memlet_tree(neighbor):
if e.data.data == local_name:
continue
e.data.data = local_name
e.data.subset.offset(old_edge.data.subset, negative=True)
16 changes: 15 additions & 1 deletion dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,21 @@ def _candidates(
continue
in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset))))

# TODO: Check in_candidates in interstate edges as well
# Check read memlets in interstate edges for candidates
for e in nsdfg.sdfg.edges():
for m in e.data.get_read_memlets(nsdfg.sdfg.arrays):
# If more than one unique element detected, remove from candidates
if m.data in in_candidates:
memlet, ns, indices = in_candidates[m.data]
# Try to find dimensions in which there is a mismatch and remove them from list
for i, (s1, s2) in enumerate(zip(m.subset, memlet.subset)):
if s1 != s2 and i in indices:
indices.remove(i)
if len(indices) == 0:
ignore.add(m.data)
in_candidates[m.data] = (memlet, ns, indices)
continue
in_candidates[m.data] = (m, None, set(range(len(m.subset))))

# Check in/out candidates
for cand in in_candidates.keys() & out_candidates.keys():
Expand Down
26 changes: 25 additions & 1 deletion tests/numpy/einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def test_general_einsum():

@dace.program
def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M], C: dace.float64[M]):
return np.einsum('ij,ji,i->', A, B, C)
Expand All @@ -20,6 +21,7 @@ def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M], C: dace.float64[M])


def test_matmul():

@dace.program
def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M]):
return np.einsum('ik,kj', A, B)
Expand All @@ -30,6 +32,7 @@ def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M]):


def test_batch_matmul():

@dace.program
def einsumtest(A: dace.float64[4, M, N], B: dace.float64[4, N, M]):
return np.einsum('bik,bkj->bij', A, B)
Expand All @@ -40,6 +43,7 @@ def einsumtest(A: dace.float64[4, M, N], B: dace.float64[4, N, M]):


def test_opteinsum_sym():

@dace.program
def einsumtest(A: dace.float64[N, N, N, N], B: dace.float64[N, N, N, N], C: dace.float64[N, N, N, N],
D: dace.float64[N, N, N, N], E: dace.float64[N, N, N, N]):
Expand Down Expand Up @@ -175,6 +179,7 @@ def tester(A, B):
sdfg(A, B)
assert np.allclose(B, np.einsum('ijk->', A))


def test_lift_einsum_reduce_partial():
from dace.libraries.standard.nodes.reduce import Reduce
from dace.libraries.blas.nodes.einsum import Einsum
Expand All @@ -197,7 +202,7 @@ def tester(A, B):
# Specialize to ensure Reduce node is there
sdfg.expand_library_nodes(recursive=False)
rnode = next(node for node, _ in sdfg.all_nodes_recursive() if isinstance(node, Reduce))
assert tuple(rnode.axes) == (1,)
assert tuple(rnode.axes) == (1, )

sdfg(A, B)
assert np.allclose(B, np.einsum('ijk->ik', A))
Expand Down Expand Up @@ -297,6 +302,24 @@ def tester(A, B):
assert np.allclose(sdfg(A, B), C)


def test_c_transposed():
    """Check an einsum whose output index order is transposed ('nm,nf->fm').

    The DaCe-compiled program is compared against the result of calling
    ``fn.f`` with the same inputs (presumably the undecorated Python
    function -- confirm against the DaCe API).
    """
    batch, in_features, out_features = 2, 3, 3
    out_shape = (out_features, in_features)

    @dace.program
    def fn(a, b, c):
        c[:] = np.einsum('nm,nf->fm', a, b)

    # Inputs are drawn in the same order as before: `a` first, then `b`.
    a = np.random.rand(batch, in_features)
    b = np.random.rand(batch, out_features)

    # Separate zero-initialized buffers for the reference and the tested run.
    c_expected = np.zeros(out_shape)
    c = np.zeros(out_shape)

    fn.f(a, b, c_expected)
    fn(a, b, c)

    assert np.allclose(c, c_expected)


if __name__ == '__main__':
test_general_einsum()
test_matmul()
Expand All @@ -312,3 +335,4 @@ def tester(A, B):
test_lift_einsum_beta()
test_lift_einsum_alpha_beta(False)
test_lift_einsum_alpha_beta(True)
test_c_transposed()
14 changes: 14 additions & 0 deletions tests/python_frontend/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ def if_return_chain(i: dace.int64):
assert if_return_chain(15)[0] == 4


def test_if_test_call():
    """Check that a call expression (``bool(a)``) works as an ``if`` condition.

    Both a falsy (0) and a truthy (1) first argument are exercised, and the
    first element of the DaCe result is compared with ``if_test_call.f``
    (presumably the plain Python function -- confirm against the DaCe API).
    """

    @dace.program
    def if_test_call(a, b):
        if bool(a):
            return a
        else:
            return b

    # Same evaluation order as two explicit asserts: DaCe call first, then
    # the reference call, for 0 and then for 1.
    for cond_value in (0, 1):
        assert if_test_call(cond_value, 2)[0] == if_test_call.f(cond_value, 2)


if __name__ == "__main__":
test_simple_if()
test_call_if()
Expand All @@ -169,3 +182,4 @@ def if_return_chain(i: dace.int64):
test_call_while()
test_if_return_both()
test_if_return_chain()
test_if_test_call()
61 changes: 61 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,71 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]):
assert np.allclose(val[0], ref)


def test_fusion_with_nested_sdfg_0():
    """Apply MapFusion when the first map's body contains a conditional.

    After the transformation, every memlet path that leaves a nested SDFG
    node must terminate at an AccessNode.
    """

    @dace.program
    def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]):
        tmp = np.empty([10], dtype=np.int32)
        for i in dace.map[0:10]:
            if C[i] < 0:
                tmp[i] = B[i] - A[i]
            else:
                tmp[i] = B[i] + A[i]
        for i in dace.map[0:10]:
            A[i] = tmp[i] * 2

    sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True)
    sdfg.apply_transformations(MapFusion)

    for nested in sdfg.all_sdfgs_recursive():
        # The top-level SDFG is not nested inside anything; skip it.
        if nested is sdfg:
            continue
        nsdfg_node = nested.parent_nsdfg_node
        parent_state = nested.parent
        for out_edge in parent_state.out_edges(nsdfg_node):
            for tree_edge in parent_state.memlet_tree(out_edge):
                # Follow the memlet path to its final destination node.
                sink = parent_state.memlet_path(tree_edge)[-1].dst
                assert isinstance(sink, dace.nodes.AccessNode)


def test_fusion_with_nested_sdfg_1():
    """Apply MapFusion when the second map's body contains a conditional.

    If the resulting SDFG has exactly one state, every memlet path that
    enters a nested SDFG node must originate at an AccessNode. Otherwise
    the test returns without checking anything.
    """

    @dace.program
    def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]):
        tmp = np.empty([10], dtype=np.int32)
        for i in dace.map[0:10]:
            with dace.tasklet:
                a << A[i]
                b << B[i]
                t >> tmp[i]
                t = b - a
        for i in dace.map[0:10]:
            if C[i] < 0:
                A[i] = tmp[i] * 2
            else:
                B[i] = tmp[i] * 2

    sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True)
    sdfg.apply_transformations(MapFusion)

    # The structural check below only runs for a single-state SDFG.
    if len(sdfg.states()) != 1:
        return

    for nested in sdfg.all_sdfgs_recursive():
        # The top-level SDFG is not nested inside anything; skip it.
        if nested is sdfg:
            continue
        nsdfg_node = nested.parent_nsdfg_node
        parent_state = nested.parent
        for in_edge in parent_state.in_edges(nsdfg_node):
            for tree_edge in parent_state.memlet_tree(in_edge):
                # Follow the memlet path back to its original source node.
                origin = parent_state.memlet_path(tree_edge)[0].src
                assert isinstance(origin, dace.nodes.AccessNode)


if __name__ == '__main__':
test_fusion_simple()
test_multiple_fusions()
test_fusion_chain()
test_fusion_with_transient()
test_fusion_with_inverted_indices()
test_fusion_with_empty_memlet()
test_fusion_with_nested_sdfg_0()
test_fusion_with_nested_sdfg_1()
Loading

0 comments on commit 34510c4

Please sign in to comment.