diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py index eb5731231ce4..91e6691a5e96 100644 --- a/tests/python/unittest/test_meta_schedule_space_cpu.py +++ b/tests/python/unittest/test_meta_schedule_space_cpu.py @@ -1536,6 +1536,130 @@ def t2d_2(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 ) +def test_cpu_nrm(): + # fmt: off + @T.prim_func + def nrm_0(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) + C = T.alloc_buffer([1], dtype="float32") + C_rf = T.alloc_buffer([1, 32768], dtype="float32") + for i0, i1_i2_fused_0, i1_i2_fused_1 in T.grid(1, 32768, 2): + with T.block("C_rf"): + vi1_i2_fused_0, b, vi1_i2_fused_1 = T.axis.remap("SSR", [i1_i2_fused_0, i0, i1_i2_fused_1]) + T.reads(A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256]) + T.writes(C_rf[b, vi1_i2_fused_0]) + with T.init(): + C_rf[b, vi1_i2_fused_0] = T.float32(0) + C_rf[b, vi1_i2_fused_0] = C_rf[b, vi1_i2_fused_0] + A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] * A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] + for i0, i1_i2_fused_0 in T.grid(1, 32768): + with T.block("C"): + vi1_i2_fused_0, b = T.axis.remap("RS", [i1_i2_fused_0, i0]) + T.reads(C_rf[b, vi1_i2_fused_0]) + T.writes(C[b]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vi1_i2_fused_0] + for i0 in T.serial(1): + with T.block("D"): + b = T.axis.spatial(1, i0) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + @T.prim_func + def nrm_1(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64}) + C = T.alloc_buffer([1], dtype="float32") + C_rf = T.alloc_buffer([1, 2], dtype="float32") + for i0, i1_i2_fused_0, i1_i2_fused_1 in T.grid(1, 32768, 2): + with T.block("C_rf"): + vi1_i2_fused_1, b, vi1_i2_fused_0 = T.axis.remap("SSR", [i1_i2_fused_1, i0, i1_i2_fused_0]) + T.reads(A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256]) + T.writes(C_rf[b, vi1_i2_fused_1]) + with T.init(): + C_rf[b, vi1_i2_fused_1] = T.float32(0) + C_rf[b, vi1_i2_fused_1] = C_rf[b, vi1_i2_fused_1] + A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] * A[b, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) // 256, (vi1_i2_fused_0 * 2 + vi1_i2_fused_1) % 256] + for i0, i1_i2_fused_1 in T.grid(1, 2): + with T.block("C"): + vi1_i2_fused_1, b = T.axis.remap("RS", [i1_i2_fused_1, i0]) + T.reads(C_rf[b, vi1_i2_fused_1]) + T.writes(C[b]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vi1_i2_fused_1] + for i0 in T.serial(1): + with T.block("D"): + b = T.axis.spatial(1, i0) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + @T.prim_func + def nrm_2(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64}) + C = T.alloc_buffer([1], dtype="float32") + for i0, i1, i2 in T.grid(1, 256, 256): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [i0, i1, i2]) + T.reads(A[b, i, j]) + T.writes(C[b]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0 in T.serial(1): + with T.block("D"): + b = T.axis.spatial(1, i0) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [32768, 2]), + ("SampleCategorical", 0), + ("SampleComputeLocation", -1), + ("SampleComputeLocation", -1), + ] + decision_1 = [ + ("SamplePerfectTile", [32768, 2]), + ("SampleCategorical", 1), + ("SampleComputeLocation", -1), + ("SampleComputeLocation", -1), + ] + decision_2 = [ + ("SampleCategorical", 0), + ("SampleComputeLocation", -1), + ] + mod = create_te_workload("NRM", 0) + actual = ms.TuneContext( + mod=mod, + target=_target(), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules="default", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[nrm_0, nrm_1, nrm_2], + expected_decisions=[decision_0, decision_1, decision_2], + ) + + if __name__ == "__main__": test_cpu_c1d() test_cpu_c2d() @@ -1546,3 +1670,4 @@ def t2d_2(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 test_cpu_gmm() test_cpu_grp() test_cpu_t2d() + test_cpu_nrm() diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index 6440f1e199d7..e477e90b6bf8 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -833,6 +833,88 @@ def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 ) +def test_cuda_nrm(): + # fmt: off + @T.prim_func + def nrm_0(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit":512}) + C = T.alloc_buffer([1], dtype="float32") + for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + for i1, i2 in T.grid(256, 256): + with T.block("C"): + b = T.axis.spatial(1, 0) + i, j = T.axis.remap("RR", [i1, i2]) + T.reads(A[b, i, j]) + T.writes(C[b]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, 0) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + @T.prim_func + def nrm_1(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit":1024}) + C_shared = T.alloc_buffer([1], dtype="float32", scope="shared") + for i0_0_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0, ax1_ax2_fused_0 in T.grid(1, 512): + for ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("C"): + b = T.axis.spatial(1, ax0) + i = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) // 256) + j = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) % 256) + T.reads(A[b, i, j]) + T.writes(C_shared[b]) + with T.init(): + C_shared[b] = T.float32(0) + C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, i0_1) + T.where(0 * 128 + i0_1 < 1) + T.reads(C_shared[b]) + T.writes(D[b]) + D[b] = T.sqrt(C_shared[b], dtype="float32") + # fmt: on + decision_0 = [ + ("SampleCategorical", 3), + ] + decision_1 = [ + ("SampleCategorical", 5), + ("SampleCategorical", 4), + ] + mod = create_te_workload("NRM", 0) + actual = ms.TuneContext( + mod=mod, + target=_target(), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules="default", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[nrm_0, nrm_1], + expected_decisions=[decision_0, decision_1], + ) + + if __name__ == "__main__": test_cuda_c1d() test_cuda_c2d() @@ -843,3 +925,4 @@ def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 test_cuda_gmm() test_cuda_grp() test_cuda_t2d() + test_cuda_nrm()