Skip to content

Commit

Permalink
Merge branch 'main' into kweidegh/mymachine_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
kweide committed Dec 18, 2024
2 parents d3ea15b + 262c450 commit f3ccf44
Show file tree
Hide file tree
Showing 7 changed files with 1,598 additions and 235 deletions.
2 changes: 1 addition & 1 deletion src/Milhoja_ThreadTeam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ void* ThreadTeam::threadRoutine(void* varg) {
if (auto tileWrapperPrototype =
dynamic_cast<const TileWrapper*>(team->receiverPrototype_)) {
// NOTE: this is the case where dataItem is a TileWrapper,
// and the team->dataReceiver_ is another TileWrapper.
// and the team->receiverPrototype_ is another TileWrapper.
// Need to transfer dataItem initialized with data receiver's
// tileProtoType, as it may differ.
// TODO: very dirty ownership transfers
Expand Down
182 changes: 163 additions & 19 deletions tools/milhoja_pypkg/src/milhoja/TaskFunctionGenerator_OpenACC_F.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def generate_source_code(self, destination, overwrite):
interface = interface.rstrip(".F90")
fptr.write(f"{INDENT*2}use {interface}, ONLY : {subroutine}\n")
offloading.append("#ifndef SUPPRESS_ACC_ROUTINE_FOR_METH_IN_APP\n")
offloading.append(f"{INDENT*2}!$acc routine ({subroutine}) vector\n")
offloading.append(f"{INDENT*2}!$acc routine ({self._get_wrapper_name(subroutine)}) vector\n"
offloading.append("#endif\n")
fptr.writelines(["\n", *offloading, "\n"])
# No implicit variables
Expand Down Expand Up @@ -245,6 +245,7 @@ def generate_source_code(self, destination, overwrite):
# Data packet sent on dataQ_h
current_queues = ["dataQ_h"]

subroutine_wrappers = {}
for node in self._tf_spec.internal_subroutine_graph:
# Insert waits if needed before next round of kernel launches
extras = [f"queue{i}_h" for i in range(2, len(node) + 1)]
Expand Down Expand Up @@ -288,33 +289,25 @@ def generate_source_code(self, destination, overwrite):
current_queues = next_queues.copy()
assert len(current_queues) == len(node)
for subroutine, queue in zip(node, current_queues):
# subroutine wrapper
# to prevent passing a slice of array
# which may introduce unnecessary device to host maps
wrapper_name, wrapper_lines = self._generate_subroutine_wrapper(INDENT, subroutine)
subroutine_wrappers[wrapper_name] = wrapper_lines

fptr.write(f"{INDENT*2}!$acc parallel loop gang default(none) &\n")
fptr.write(f"{INDENT*2}!$acc& async({queue})\n")
fptr.write(f"{INDENT*2}do n = 1, nTiles_d\n")
fptr.write(f"{INDENT*3}CALL {subroutine}( &\n")
fptr.write(f"{INDENT*3}CALL {wrapper_name}( &\n")
actual_args = \
self._tf_spec.subroutine_actual_arguments(subroutine)
arg_list = []
arg_list = [f"{INDENT*5}n"]
for argument in actual_args:
spec = self._tf_spec.argument_specification(argument)
extents = ""
offs = ""
if spec["source"] in points:
extents = "(:, n)"
elif spec["source"] == TILE_DELTAS_ARGUMENT:
extents = "(:, n)"
elif spec["source"] == TILE_LEVEL_ARGUMENT:
extents = "(1, n)"
if spec["source"] == TILE_LEVEL_ARGUMENT:
offs = " + 1"
elif spec["source"] in bounds:
extents = "(:, :, n)"
elif spec["source"] == GRID_DATA_ARGUMENT:
extents = "(:, :, :, :, n)"
elif spec["source"] == SCRATCH_ARGUMENT:
dimension = len(parse_extents(spec["extents"]))
tmp = [":" for _ in range(dimension)]
extents = "(" + ", ".join(tmp) + ", n)"
arg_list.append(f"{INDENT*5}{argument}_d{extents}{offs}")
arg_list.append(f"{INDENT*5}{argument}_d{offs}")
fptr.write(", &\n".join(arg_list) + " &\n")
fptr.write(f"{INDENT*5})\n")
fptr.write(f"{INDENT*2}end do\n")
Expand Down Expand Up @@ -344,5 +337,156 @@ def generate_source_code(self, destination, overwrite):
# End subroutine declaration
fptr.write(f"{INDENT}end subroutine {self._tf_spec.function_name}\n")
fptr.write("\n")

# Write subroutine wrappers
for wrapper, lines in subroutine_wrappers.items():
for line in lines:
fptr.write(line + "\n")

# End module declaration
fptr.write(f"end module {module}\n\n")

def _get_wrapper_name(self, subroutine):
"""
A helper function to determine the name of subroutine wrapper, consisntently
"""
return "wrapper_" + subroutine

def _generate_subroutine_wrapper(self, indent, subroutine):
    """
    Build the Fortran source lines of a wrapper around one subroutine.

    The wrapper's dummy arguments are a block index ``nblk`` followed by
    one ``<arg>_d`` per actual argument of ``subroutine``.  For every
    array argument the wrapper declares a pointer, associates it with the
    ``nblk`` entry along the array's last dimension, and passes the
    pointer to ``subroutine``.  Scalar external arguments are forwarded
    unchanged.

    :param indent: String used as one level of indentation in the output
    :param subroutine: Name of the subroutine to wrap
    :return: Tuple ``(wrapper_name, lines)``.  ``lines`` holds the source
        text without a trailing newline per entry, except for the dummy
        argument entry, which spans several lines and ends in a newline.
    :raises NotImplementedError: If an external argument is not a scalar
    :raises LogicError: If an argument's source is not recognized or a
        grid data argument is not in any of the tile in/in-out/out sets
    """
    subroutine_wrapper = self._get_wrapper_name(subroutine)
    lines = []

    actual_args = self._tf_spec.subroutine_actual_arguments(subroutine)
    # nblk is always present, so the dummy argument list is never empty.
    dummy_args = ["nblk"] + [f"{arg}_d" for arg in actual_args]

    lines.append(f"{indent*1}subroutine {subroutine_wrapper} ( &")
    lines.append(
        f"{indent*5}"
        + f", &\n{indent*5}".join(dummy_args)
        + f" &\n{indent*3})\n"
    )

    # Remove the file extension as a suffix.  str.rstrip(".F90") would
    # instead strip any trailing '.', 'F', '9' and '0' characters and so
    # mangle names such as "Eos_F.F90".
    suffix = ".F90"
    interface = self._tf_spec.subroutine_interface_file(subroutine).strip()
    if interface.endswith(suffix):
        interface = interface[:-len(suffix)]
    lines.append(f"{indent*2}use {interface}, ONLY: {subroutine}")
    lines.append("")

    lines.append(f"{indent*2}!$acc routine vector")
    lines.append(f"{indent*2}!$acc routine ({subroutine}) vector")
    lines.append("")

    lines.append(f"{indent*2}implicit none")
    lines.append("")

    lines.append(f"{indent*2}! Arguments")
    lines.append(f"{indent*2}integer, intent(IN) :: nblk")

    points = {
        TILE_LO_ARGUMENT, TILE_HI_ARGUMENT, TILE_LBOUND_ARGUMENT,
        TILE_UBOUND_ARGUMENT, LBOUND_ARGUMENT
    }
    bounds = {TILE_INTERIOR_ARGUMENT, TILE_ARRAY_BOUNDS_ARGUMENT}

    # For each argument: rank of the pointer handed to the wrapped
    # subroutine (0 means the argument is forwarded as is) and the
    # Fortran type used to declare that pointer.
    pointer_extents = {}
    pointer_types = {}
    for arg in actual_args:
        spec = self._tf_spec.argument_specification(arg)
        src = spec["source"]
        if src == EXTERNAL_ARGUMENT:
            extents = spec["extents"]
            if extents != "()":
                msg = "No test case for non-scalar externals"
                raise NotImplementedError(msg)

            # is this okay? Should we fail if there is no type mapping?
            arg_type = C2F_TYPE_MAPPING.get(spec["type"], spec["type"])
            pointer_extents[arg] = 0
            pointer_types[arg] = arg_type
            lines.append(
                f"{indent*2}{arg_type}, target, intent(IN) :: {arg}_d"
            )

        elif src in points:
            pointer_extents[arg] = 1
            pointer_types[arg] = "integer"
            lines.append(
                f"{indent*2}integer, target, intent(IN) :: {arg}_d(:, :)"
            )

        elif src == TILE_DELTAS_ARGUMENT:
            pointer_extents[arg] = 1
            pointer_types[arg] = "real"
            lines.append(
                f"{indent*2}real, target, intent(IN) :: {arg}_d(:, :)"
            )

        elif src in bounds:
            pointer_extents[arg] = 2
            pointer_types[arg] = "integer"
            lines.append(
                f"{indent*2}integer, target, intent(IN) :: {arg}_d(:, :, :)"
            )

        elif src == TILE_LEVEL_ARGUMENT:
            pointer_extents[arg] = 1
            pointer_types[arg] = "integer"
            lines.append(
                f"{indent*2}integer, target, intent(IN) :: {arg}_d(:, :)"
            )

        elif src == GRID_DATA_ARGUMENT:
            if arg in self._tf_spec.tile_in_arguments:
                intent = "IN"
            elif arg in self._tf_spec.tile_in_out_arguments:
                intent = "INOUT"
            elif arg in self._tf_spec.tile_out_arguments:
                intent = "OUT"
            else:
                raise LogicError("Unknown grid data variable class")

            pointer_extents[arg] = 4
            pointer_types[arg] = "real"
            lines.append(
                f"{indent*2}real, target, intent({intent}) :: "
                f"{arg}_d(:, :, :, :, :)"
            )

        elif src == SCRATCH_ARGUMENT:
            arg_type = spec["type"]
            dimension = len(parse_extents(spec["extents"]))
            assert dimension > 0
            # One extra dimension: the last index selects the block.
            array = "(" + ", ".join([":"] * (dimension + 1)) + ")"
            pointer_extents[arg] = dimension
            pointer_types[arg] = arg_type
            lines.append(
                f"{indent*2}{arg_type}, target, intent(INOUT) :: "
                f"{arg}_d{array}"
            )

        else:
            raise LogicError(f"{arg} of unknown argument class")
    lines.append("")

    lines.append(f"{indent*2}! Local variables")
    # Only array arguments get a pointer; scalars keep their dummy name.
    pointer_mapping = {}
    for arg in actual_args:
        ptr_extents = pointer_extents[arg]
        if ptr_extents > 0:
            arg_p = f"{arg}_d_p"
            pointer_mapping[arg] = arg_p
            _ext_str = ", ".join([":"] * ptr_extents)
            lines.append(
                f"{indent*2}{pointer_types[arg]}, pointer :: "
                f"{arg_p}({_ext_str})"
            )
    lines.append("")

    lines.append(f"{indent*2}! Attach pointers")
    for arg, ptr in pointer_mapping.items():
        _ext_str = ", ".join([":"] * pointer_extents[arg]) + ", nblk"
        lines.append(f"{indent*2}{ptr} => {arg}_d({_ext_str})")
    lines.append("")

    lines.append(f"{indent*2}! Call subroutine")
    lines.append(f"{indent*2}CALL {subroutine}( &")
    arg_list = [pointer_mapping.get(arg, f"{arg}_d") for arg in actual_args]
    lines.append(f"{indent*5}" + f", &\n{indent*5}".join(arg_list) + " &")
    lines.append(f"{indent*4})")

    lines.append("")
    lines.append(f"{indent*1}end subroutine {subroutine_wrapper}")
    lines.append("")

    return subroutine_wrapper, lines
Loading

0 comments on commit f3ccf44

Please sign in to comment.