Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DLight] Check for target in function attributes #16958

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/tvm/dlight/base/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def _is_scheduled(func: tir.PrimFunc) -> bool:
return func.attrs["tir.is_scheduled"] == 1


def _get_target(func: tir.PrimFunc) -> Target:
target = func.attrs.get("target")
if target is None:
return Target.current(allow_none=False)
else:
return target


@module_pass(opt_level=0, name="ApplyDefaultSchedule")
class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods
"""A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module."""
Expand All @@ -55,10 +63,11 @@ def transform_module( # pylint: disable=missing-function-docstring
mod: IRModule,
_: PassContext,
) -> IRModule:
target = Target.current(allow_none=False)
updated_functions = {}
for g_var, func in mod.functions_items():
if isinstance(func, tir.PrimFunc) and not _is_scheduled(func):
target = _get_target(func)

sch = _apply_rules(func, target, self.rules, tunable=False)
if sch is not None:
assert len(sch) == 1
Expand Down
78 changes: 78 additions & 0 deletions tests/python/dlight/test_gpu_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,83 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl
assert_structural_equal(mod["main"], expected)


def test_gpu_fallback_ignores_non_gpu_functions():
@I.ir_module
class Before:
# This function has no "target" attribute, and is scheduled
# using the `Target.current`.
@T.prim_func
def gpu_func(
A: T.Buffer((1, 32, 1, 128), "float16"),
C: T.Buffer((1, 1, 4096), "float16"),
):
B = T.alloc_buffer((1, 1, 32, 128), "float16")
for i, j, k, l in T.grid(1, 1, 32, 128):
with T.block("T_transpose"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
for i, j, k in T.grid(1, 1, 4096):
with T.block("T_reshape"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]

# This function is identical, except that it is explicitly
# annotated with the "target" attribute, and is scheduled
# based on the annotation's target.
@T.prim_func
def cpu_func(
A: T.Buffer((1, 32, 1, 128), "float16"),
C: T.Buffer((1, 1, 4096), "float16"),
):
T.func_attr({"target": T.target("llvm")})
B = T.alloc_buffer((1, 1, 32, 128), "float16")
for i, j, k, l in T.grid(1, 1, 32, 128):
with T.block("T_transpose"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
for i, j, k in T.grid(1, 1, 4096):
with T.block("T_reshape"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]

@I.ir_module
class After:
@T.prim_func
def gpu_func(
A: T.Buffer((1, 32, 1, 128), "float16"),
C: T.Buffer((1, 1, 4096), "float16"),
):
T.func_attr({"tir.is_scheduled": 1})
for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1)
T.reads(A[0, v0 // 128, 0, v0 % 128])
T.writes(C[0, 0, v0])
C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]

@T.prim_func
def cpu_func(
A: T.Buffer((1, 32, 1, 128), "float16"),
C: T.Buffer((1, 1, 4096), "float16"),
):
T.func_attr({"target": T.target("llvm")})
B = T.alloc_buffer((1, 1, 32, 128), "float16")
for i, j, k, l in T.grid(1, 1, 32, 128):
with T.block("T_transpose"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
for i, j, k in T.grid(1, 1, 4096):
with T.block("T_reshape"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]

with Target("cuda"):
mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
dl.gpu.Fallback(),
)(Before)
assert_structural_equal(mod, After)


if __name__ == "__main__":
tvm.testing.main()
Loading