From 0e16e61ab4889a82d4b220a9108fa235d0706de3 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Nov 2022 21:37:36 -0700 Subject: [PATCH] [MetaSchedule] Refactor ScheduleRule Attributes This PR refactors the existing `schedule_rule` block annotation-based dispatch into a more organized folder structure, which follows the exact same mechanism as AutoTVM. In the example below, when target is `Target("llvm -keys=cuda,gpu")` and the block annotation is `conv2d_nchw_winograd_inverse`: ```python with T.block("some_block"): T.block_attr({ "schedule_rule": "conv2d_nchw_winograd_inverse", }) ``` the system will find the following global packed functions in order: - `meta_schedule.cuda.conv2d_nchw_winograd_inverse` - `meta_schedule.gpu.conv2d_nchw_winograd_inverse` whose function signatures are: ```python def schedule_rule( sch: tvm.tir.Schedule, block: tvm.tir.schedule.BlockRV, ) -> List[tvm.tir.Schedule]: ``` In terms of code organization, for example, for target key `cuda`, the schedule functions are supposed to be defined in: - `include/tvm/meta_schedule/schedule/cuda` for public methods - `src/meta_schedule/schedule/cuda` for private methods - `python/tvm/meta_schedule/schedule/cuda` for direct python side definition --- .../tvm/meta_schedule/schedule/cpu/.gitignore | 0 .../meta_schedule/schedule/cuda/thread_bind.h | 50 ++- .../meta_schedule/schedule/generic/winograd.h | 37 ++ .../tvm/meta_schedule/schedule/x86/.gitignore | 0 include/tvm/meta_schedule/schedule_rule.h | 8 + python/tvm/meta_schedule/__init__.py | 6 +- python/tvm/meta_schedule/relay_integration.py | 16 +- python/tvm/meta_schedule/schedule/__init__.py | 18 + .../meta_schedule/schedule/cpu/__init__.py | 17 + .../meta_schedule/schedule/cuda/__init__.py | 17 + .../schedule/generic/__init__.py | 17 + .../meta_schedule/schedule/x86/__init__.py | 17 + .../meta_schedule/schedule_rule/__init__.py | 5 +- .../schedule_rule/apply_custom_rule.py | 33 ++ .../search_strategy/evolutionary_search.py | 2 +- 
.../testing/conv2d_winograd_cpu.py | 172 --------- .../testing/conv2d_winograd_cuda.py | 173 --------- .../meta_schedule/testing/relay_workload.py | 1 - .../meta_schedule/testing/space_generation.py | 2 +- .../tvm/meta_schedule/testing/te_workload.py | 150 ++++---- python/tvm/relay/backend/te_compiler.py | 4 +- python/tvm/relay/op/nn/_nn.py | 4 +- python/tvm/relay/op/strategy/adreno.py | 10 +- python/tvm/relay/op/strategy/arm_cpu.py | 12 +- python/tvm/relay/op/strategy/bifrost.py | 10 +- python/tvm/relay/op/strategy/cuda.py | 72 +++- python/tvm/relay/op/strategy/generic.py | 10 +- python/tvm/relay/op/strategy/mali.py | 10 +- python/tvm/relay/op/strategy/x86.py | 10 +- python/tvm/topi/cuda/conv2d_alter_op.py | 56 +-- python/tvm/topi/cuda/conv2d_nhwc_winograd.py | 4 +- python/tvm/topi/cuda/conv2d_winograd.py | 35 +- python/tvm/topi/nn/conv2d.py | 352 ++++++++++++++--- python/tvm/topi/utils.py | 14 +- python/tvm/topi/x86/batch_matmul.py | 8 +- python/tvm/topi/x86/dense.py | 14 +- .../postproc/rewrite_unbound_block.cc | 3 +- src/meta_schedule/schedule/cpu/winograd.cc | 101 +++++ .../schedule/cuda/thread_bind.cc | 181 +++++++++ src/meta_schedule/schedule/cuda/winograd.cc | 163 ++++++++ .../schedule/generic/winograd.cc | 46 +++ src/meta_schedule/schedule/x86/.gitignore | 0 .../schedule_rule/apply_custom_rule.cc | 92 +++++ src/meta_schedule/schedule_rule/auto_bind.cc | 138 +------ .../schedule_rule/schedule_rule.cc | 56 ++- src/meta_schedule/schedule_rule/winograd.cc | 249 ------------ .../space_generator/post_order_apply.cc | 49 +-- src/meta_schedule/utils.h | 35 +- src/target/tag.cc | 9 +- src/te/operation/create_primfunc.cc | 17 +- .../metaschedule_e2e/test_resnet50_int8.py | 14 +- ..._meta_schedule_custom_rule_winograd_cpu.py | 206 ---------- ...meta_schedule_custom_rule_winograd_cuda.py | 328 ---------------- .../test_meta_schedule_post_order_apply.py | 43 --- .../test_meta_schedule_relay_integration.py | 7 +- .../test_meta_schedule_space_cpu_winograd.py | 168 
+++++++++ .../unittest/test_meta_schedule_space_cuda.py | 169 --------- .../test_meta_schedule_space_cuda_winograd.py | 355 ++++++++++++++++++ .../test_meta_schedule_vnni_integration.py | 14 +- .../unittest/test_te_create_primfunc.py | 2 - .../test_tir_analysis_stmt_finding.py | 7 +- 61 files changed, 1966 insertions(+), 1852 deletions(-) create mode 100644 include/tvm/meta_schedule/schedule/cpu/.gitignore rename src/meta_schedule/schedule_rule/auto_bind.h => include/tvm/meta_schedule/schedule/cuda/thread_bind.h (57%) create mode 100644 include/tvm/meta_schedule/schedule/generic/winograd.h create mode 100644 include/tvm/meta_schedule/schedule/x86/.gitignore create mode 100644 python/tvm/meta_schedule/schedule/__init__.py create mode 100644 python/tvm/meta_schedule/schedule/cpu/__init__.py create mode 100644 python/tvm/meta_schedule/schedule/cuda/__init__.py create mode 100644 python/tvm/meta_schedule/schedule/generic/__init__.py create mode 100644 python/tvm/meta_schedule/schedule/x86/__init__.py create mode 100644 python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py delete mode 100644 python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py delete mode 100644 python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py create mode 100644 src/meta_schedule/schedule/cpu/winograd.cc create mode 100644 src/meta_schedule/schedule/cuda/thread_bind.cc create mode 100644 src/meta_schedule/schedule/cuda/winograd.cc create mode 100644 src/meta_schedule/schedule/generic/winograd.cc create mode 100644 src/meta_schedule/schedule/x86/.gitignore create mode 100644 src/meta_schedule/schedule_rule/apply_custom_rule.cc delete mode 100644 src/meta_schedule/schedule_rule/winograd.cc delete mode 100644 tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py delete mode 100644 tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py create mode 100644 tests/python/unittest/test_meta_schedule_space_cpu_winograd.py create mode 100644 
tests/python/unittest/test_meta_schedule_space_cuda_winograd.py diff --git a/include/tvm/meta_schedule/schedule/cpu/.gitignore b/include/tvm/meta_schedule/schedule/cpu/.gitignore new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/meta_schedule/schedule_rule/auto_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h similarity index 57% rename from src/meta_schedule/schedule_rule/auto_bind.h rename to include/tvm/meta_schedule/schedule/cuda/thread_bind.h index b397d2015c19..ae6d492bfe12 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -16,37 +16,53 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_ -#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_ +#ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_ +#define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_ -#include "../utils.h" +#include + +#include +#include +#include namespace tvm { namespace meta_schedule { /*! - * \brief Bind the given block if it is not bound to blockIdx or threadIdx. + * \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical` + * to return a random thread extent. + * \param sch The schedule + * \param thread_extents The candidate thread extents. + * \return A sampler that returns a random thread extent. + */ +std::function MakeFactorSampler(tir::Schedule sch, + Array thread_extents); + +/*! + * \brief Bind blockIdx.x and threadIdx.x to the given loop * \param sch The schedule. - * \param block The block to be bound. + * \param loop The loop to be bound. * \param max_threadblocks The maximum number of threadblocks allowed. - * \param max_threads The maximum number of threads allowed. + * \param max_threads_per_block The maximum number of threads allowed. * \param get_factor A function that returns the tiling factor. 
*/ -void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block, - int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor); +Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor = nullptr); /*! - * \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical` - * to return a random thread extent. - * \param sch The schedule - * \param thread_extents The candidate thread extents. - * \return A sampler that returns a random thread extent. + * \brief Bind the given block if it is not bound to blockIdx or threadIdx. + * \param sch The schedule. + * \param block The block to be bound. + * \param max_threadblocks The maximum number of threadblocks allowed. + * \param max_threads_per_block The maximum number of threads allowed. + * \param get_factor A function that returns the tiling factor. */ -std::function MakeFactorSampler(tir::Schedule sch, - Array thread_extents); +void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, // + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor = nullptr); } // namespace meta_schedule } // namespace tvm -#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_ +#endif // TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_ diff --git a/include/tvm/meta_schedule/schedule/generic/winograd.h b/include/tvm/meta_schedule/schedule/generic/winograd.h new file mode 100644 index 000000000000..dc9b32fd10de --- /dev/null +++ b/include/tvm/meta_schedule/schedule/generic/winograd.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_ +#define TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Get the producer block of a given block. + * If there is a constant winograd transform matrix, inline it. + * \return The only producer block. + */ +tir::BlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::BlockRV block); + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_ diff --git a/include/tvm/meta_schedule/schedule/x86/.gitignore b/include/tvm/meta_schedule/schedule/x86/.gitignore new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 1b018512146f..da8f1faa8e1d 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -99,6 +99,14 @@ class ScheduleRule : public runtime::ObjectRef { * \return The cloned schedule rule. */ using FClone = runtime::TypedPackedFunc; + /*! + * \brief Create a rule that applies customized rules registered using block attribute + * `schedule_rule`. The rule will be dispatched according to target keys. + * \return The created schedule rule. + */ + TVM_DLL static ScheduleRule ApplyCustomRule(); + /*! 
\brief Check if the rule is `ApplyCustomRule` */ + TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule); /*! * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions * \param into_producer If allows to inline a block into its producer diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 0dd679e047e0..30a4fc6d9467 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -26,6 +26,7 @@ postproc, relay_integration, runner, + schedule, schedule_rule, search_strategy, space_generator, @@ -41,10 +42,7 @@ from .mutator import Mutator from .postproc import Postproc from .profiler import Profiler -from .relay_integration import ( - is_meta_schedule_dispatch_enabled, - is_meta_schedule_enabled, -) +from .relay_integration import is_meta_schedule_enabled from .runner import Runner from .schedule_rule import ScheduleRule from .search_strategy import MeasureCandidate, SearchStrategy diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 5e77181d32bf..df76684d2d42 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -377,7 +377,7 @@ def compile_relay( mod, target, params, pass_config, executor = _normalize_params( mod, target, params, pass_config, executor ) - pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", target.kind.name != "cuda") + pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True) with Profiler.timeit("PostTuningCompilation"): with target, _autotvm_silencer(), database: with transform.PassContext( @@ -404,17 +404,3 @@ def is_meta_schedule_enabled() -> bool: "relay.backend.use_meta_schedule", False, ) - - -def is_meta_schedule_dispatch_enabled() -> bool: - """Return whether the meta-schedule dispatch is enabled. 
- - Returns - ------- - enabled: bool - Whether the meta schedule is enabled - """ - return transform.PassContext.current().config.get( - "relay.backend.use_meta_schedule_dispatch", - False, - ) diff --git a/python/tvm/meta_schedule/schedule/__init__.py b/python/tvm/meta_schedule/schedule/__init__.py new file mode 100644 index 000000000000..0f5efce9ff65 --- /dev/null +++ b/python/tvm/meta_schedule/schedule/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Per-block schedule rules in MetaSchedule""" +from . import cpu, cuda, generic, x86 diff --git a/python/tvm/meta_schedule/schedule/cpu/__init__.py b/python/tvm/meta_schedule/schedule/cpu/__init__.py new file mode 100644 index 000000000000..ddc0155ee4f4 --- /dev/null +++ b/python/tvm/meta_schedule/schedule/cpu/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Per-block schedule rules in MetaSchedule for target key 'cpu'""" diff --git a/python/tvm/meta_schedule/schedule/cuda/__init__.py b/python/tvm/meta_schedule/schedule/cuda/__init__.py new file mode 100644 index 000000000000..937a6e16a91b --- /dev/null +++ b/python/tvm/meta_schedule/schedule/cuda/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Per-block schedule rules in MetaSchedule for target key 'cuda'""" diff --git a/python/tvm/meta_schedule/schedule/generic/__init__.py b/python/tvm/meta_schedule/schedule/generic/__init__.py new file mode 100644 index 000000000000..38ba5beb6772 --- /dev/null +++ b/python/tvm/meta_schedule/schedule/generic/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Per-block schedule rules in MetaSchedule for generic cases""" diff --git a/python/tvm/meta_schedule/schedule/x86/__init__.py b/python/tvm/meta_schedule/schedule/x86/__init__.py new file mode 100644 index 000000000000..d41979638078 --- /dev/null +++ b/python/tvm/meta_schedule/schedule/x86/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Per-block schedule rules in MetaSchedule for target key 'x86'""" diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index a015d0eb1ab2..5971ad53c48c 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -20,15 +20,16 @@ blocks in a schedule. See also PostOrderApply. """ from .add_rfactor import AddRFactor +from .apply_custom_rule import ApplyCustomRule from .auto_bind import AutoBind from .auto_inline import AutoInline from .cross_thread_reduction import CrossThreadReduction from .multi_level_tiling import ( MultiLevelTiling, - MultiLevelTilingWithIntrin, - ReuseType, MultiLevelTilingTensorCore, MultiLevelTilingWideVector, + MultiLevelTilingWithIntrin, + ReuseType, ) from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation diff --git a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py new file mode 100644 index 000000000000..29e25f992930 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Create a rule that applies customized rules registered using block attribute `schedule_rule`. +The rule will be dispatched according to target keys.""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.ApplyCustomRule") +class ApplyCustomRule(ScheduleRule): + """A rule that applies customized rules registered using block attribute `schedule_rule`. + The rule will be dispatched according to target keys.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleApplyCustomRule, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 65e7ddc468b5..44f32527fad9 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -58,7 +58,7 @@ class EvolutionarySearch(SearchStrategy): def __init__( self, *, - population_size: int = 2048, + population_size: int = 512, init_measured_ratio: float = 0.2, init_min_unmeasured: int = 50, max_fail_count: int = 5, diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py deleted file mode 100644 index d6242020726b..000000000000 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py +++ /dev/null @@ -1,172 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring - -from tvm.script import tir as T - -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument,chained-comparison,misplaced-comparison-constant - - -@T.prim_func -def conv2d_winograd_cpu( - X: T.Buffer[(1, 14, 14, 128), "float32"], # type: ignore - W: T.Buffer[(6, 6, 128, 128), "float32"], # type: ignore - conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], # type: ignore -) -> None: - # body - data_pad = T.alloc_buffer([1, 16, 16, 128]) - input_tile = T.alloc_buffer([6, 6, 9, 128]) - B = T.alloc_buffer([6, 6]) - data_pack = T.alloc_buffer([6, 6, 9, 128]) - bgemm = T.alloc_buffer([6, 6, 9, 128]) - A = T.alloc_buffer([6, 4]) - inverse = T.alloc_buffer([4, 4, 9, 128]) - for i0, i1, i2, i3 in T.grid(1, 16, 16, 128): - with T.block("data_pad"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.block_attr({"schedule_rule": "None"}) - T.reads([X[i0_1, i1_1, i2_1, i3_1]]) - T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) - data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, # type: ignore - X[i0_1, i1_1, i2_1, i3_1], - T.float32(0), - dtype="float32", - ) - for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128): - with T.block("input_tile"): - eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2]) - T.block_attr({"schedule_rule": "None"}) - T.reads( - data_pad[ - T.floordiv(p, 9), # type: ignore - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore - ((T.floormod(p, 3) * 4) 
+ nu), # type: ignore - ci, - ] - ) - T.writes([input_tile[eps, nu, p, ci]]) - input_tile[eps, nu, p, ci] = data_pad[ - T.floordiv(p, 9), # type: ignore - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore - ((T.floormod(p, 3) * 4) + nu), # type: ignore - ci, - ] - for i0_3, i1_3 in T.grid(6, 6): - with T.block("B"): - i, j = T.axis.remap("SS", [i0_3, i1_3]) - T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) - T.writes([B[i, j]]) - # fmt: off - B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), 
T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) # type: ignore - # fmt: on - for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): - with T.block("data_pack"): - eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap( - "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5] - ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.llvm"}) - T.reads( - [ - data_pack[eps_1, nu_1, p_1, ci_1], - input_tile[r_a, r_b, p_1, ci_1], - B[ - T.min(r_a, r_b) : ( # type: ignore - T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) # type: ignore - ), - T.min(eps_1, nu_1) : ( # type: ignore - T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - 
T.min(eps_1, nu_1)) # type: ignore - ), - ], - ] - ) - T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) - with T.init(): - data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) - data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + ( - (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1] - ) - for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): - with T.block("bgemm"): - eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) - T.block_attr({"meta_schedule.write_cache_level": [2]}) - T.reads( - [ - bgemm[eps_2, nu_2, p_2, co], - data_pack[eps_2, nu_2, p_2, ci_2], - W[eps_2, nu_2, co, ci_2], - ] - ) - T.writes([bgemm[eps_2, nu_2, p_2, co]]) - with T.init(): - bgemm[eps_2, nu_2, p_2, co] = T.float32(0) - bgemm[eps_2, nu_2, p_2, co] = ( - bgemm[eps_2, nu_2, p_2, co] - + data_pack[eps_2, nu_2, p_2, ci_2] * W[eps_2, nu_2, co, ci_2] - ) - for i0_6, i1_6 in T.grid(6, 4): - with T.block("A"): - i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6]) - T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) - T.writes([A[i_1, j_1]]) - # fmt: off - A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and 
(T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) # type: ignore - # fmt: on - for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): - with T.block("inverse"): - vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( - "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] - ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.llvm"}) - T.reads( - [ - inverse[vh, vw, p_3, co_1], - bgemm[r_a_1, r_b_1, p_3, co_1], - A[ - T.min(r_a_1, r_b_1) : ( # type: ignore - T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) # type: ignore - ), - T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), # type: ignore - ], - ] - ) - T.writes([inverse[vh, vw, p_3, co_1]]) - with T.init(): - inverse[vh, vw, p_3, co_1] = T.float32(0) - inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + ( - (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, 
vw] - ) - for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128): - with T.block("conv2d_winograd"): - n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6]) - T.reads( - [ - inverse[ - T.floormod(h, 4), # type: ignore - T.floormod(w, 4), # type: ignore - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore - co_2, - ] - ] - ) - T.writes([conv2d_winograd[n, h, w, co_2]]) - conv2d_winograd[n, h, w, co_2] = inverse[ - T.floormod(h, 4), # type: ignore - T.floormod(w, 4), # type: ignore - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore - co_2, - ] diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py deleted file mode 100644 index e737f9b04e62..000000000000 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py +++ /dev/null @@ -1,173 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-docstring - -from tvm.script import tir as T - -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument,chained-comparison,misplaced-comparison-constant - - -@T.prim_func -def conv2d_winograd_cuda( # type: ignore - placeholder: T.Buffer[(1, 14, 14, 128), "float32"], # type: ignore - placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"], # type: ignore - conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], # type: ignore -) -> None: - # type: ignore - data_pad = T.alloc_buffer([1, 16, 16, 128]) - input_tile = T.alloc_buffer([6, 6, 9, 128]) - B = T.alloc_buffer([6, 6]) - data_pack = T.alloc_buffer([6, 6, 9, 128]) - bgemm = T.alloc_buffer([6, 6, 9, 128]) - A = T.alloc_buffer([6, 4]) - inverse = T.alloc_buffer([4, 4, 9, 128]) - for i0, i1, i2, i3 in T.grid(1, 16, 16, 128): - with T.block("data_pad"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.block_attr({"schedule_rule": "None"}) - T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) - T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) - data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, # type: ignore - placeholder[i0_1, i1_1, i2_1, i3_1], - T.float32(0), - dtype="float32", - ) - for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128): - with T.block("input_tile"): - eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2]) - T.block_attr({"schedule_rule": "None"}) - T.reads( - [ - data_pad[ - T.floordiv(p, 9), # type: ignore - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore - ((T.floormod(p, 3) * 4) + nu), # type: ignore - ci, - ] - ] - ) - T.writes([input_tile[eps, nu, p, ci]]) - input_tile[eps, nu, p, ci] = data_pad[ - T.floordiv(p, 9), # type: ignore - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore - ((T.floormod(p, 3) * 4) + nu), # type: ignore - ci, - ] - for i0_3, i1_3 in T.grid(6, 6): - with T.block("B"): - i, j = 
T.axis.remap("SS", [i0_3, i1_3]) - T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) - T.writes([B[i, j]]) - # fmt: off - B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and 
(T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) # type: ignore - # fmt: on - for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): - with T.block("data_pack"): - eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap( - "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5] - ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"}) - T.reads( - [ - data_pack[eps_1, nu_1, p_1, ci_1], - input_tile[r_a, r_b, p_1, ci_1], - B[ - T.min(r_a, r_b) : ( # type: ignore - T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) # type: ignore - ), - T.min(eps_1, nu_1) : ( # type: ignore - T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) # type: ignore - ), - ], - ] - ) - T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) - with T.init(): - data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) - data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + ( - (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1] - ) - for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 
6, 9, 128, 128): - with T.block("bgemm"): - eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) - T.block_attr({"meta_schedule.write_cache_level": [3]}) - T.reads( - [ - bgemm[eps_2, nu_2, p_2, co], - data_pack[eps_2, nu_2, p_2, ci_2], - placeholder_1[eps_2, nu_2, co, ci_2], - ] - ) - T.writes([bgemm[eps_2, nu_2, p_2, co]]) - with T.init(): - bgemm[eps_2, nu_2, p_2, co] = T.float32(0) - bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + ( - data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2] - ) - for i0_6, i1_6 in T.grid(6, 4): - with T.block("A"): - i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6]) - T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) - T.writes([A[i_1, j_1]]) - # fmt: off - A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 
4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) # type: ignore - # fmt: on - for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): - with T.block("inverse"): - vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( - "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] - ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.cuda"}) - T.reads( - [ - inverse[vh, vw, p_3, co_1], - bgemm[r_a_1, r_b_1, p_3, co_1], - A[ - T.min(r_a_1, r_b_1) : ( # type: ignore - T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) # type: ignore - ), - T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), # type: ignore - ], - ] - ) - T.writes([inverse[vh, vw, p_3, co_1]]) - with T.init(): - inverse[vh, vw, p_3, co_1] = T.float32(0) - inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + ( - (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw] - ) - for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128): - with T.block("conv2d_winograd"): - n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6]) - T.reads( - [ - inverse[ - T.floormod(h, 4), # type: ignore - T.floormod(w, 4), # type: ignore - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore - co_2, - ] - 
] - ) - T.writes([conv2d_winograd[n, h, w, co_2]]) - conv2d_winograd[n, h, w, co_2] = inverse[ - T.floormod(h, 4), # type: ignore - T.floormod(w, 4), # type: ignore - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore - co_2, - ] diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 6d1cd7f1604c..20abcfce3dc1 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -232,7 +232,6 @@ def get_network( inputs : Tuple[str, List[int], str] The name, shape and dtype of the input tensor. """ - mod: IRModule params: Dict[str, NDArray] inputs: Tuple[str, List[int], str] diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 5ac20f8fdf2f..0b7072b65afe 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -127,7 +127,7 @@ def check_sketches( def print_sketches(sketches: List[Schedule]): for i, sch in enumerate(sketches): print(f"###### {i}") - print(sch.mod.script()) + sch.mod.show() for inst in sch.trace.insts: if inst in sch.trace.decisions: print(f'("{inst.kind.name}", {sch.trace.decisions[inst]}),') diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 6fac1c2960ac..cdc430087542 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -19,6 +19,7 @@ from typing import Tuple from tvm import te, tir, topi +from tvm.target import Target def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring @@ -519,93 +520,68 @@ def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring stride: int = 1, padding: int = 0, dilation: int = 1, + tile_size: int = 4, ) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: - tile_size = 4 # 
_infer_tile_size(data, kernel) - inputs = te.placeholder((N, H, W, CI), name="inputs") - N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" - - KH = KW = kernel_size - HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride - assert HSTR == 1 and WSTR == 1 and KH == KW - - data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") - - r = KW - m = tile_size - alpha = m + r - 1 - A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") - - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - P = N * nH * nW - _rkh = te.reduce_axis((0, KH), name="r_kh") - _rkw = te.reduce_axis((0, KW), name="r_kw") - kshape = (alpha, alpha, CI, CO) - kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") - - idxdiv = te.indexdiv - idxmod = te.indexmod - # pack input tile - input_tile = te.compute( - (alpha, alpha, P, CI), - lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ - idxmod(p, nW) * m + nu - ][ci], - name="input_tile", - ) - - # transform data - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_b") - data_pack = te.compute( - (alpha, alpha, P, CI), - lambda eps, nu, p, ci: te.sum( - input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] - ), - name="data_pack", - attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + from tvm.topi.nn.conv2d import ( # pylint: disable=import-outside-toplevel + _conv2d_winograd_nhwc_impl, ) - # do batch gemm - ci = te.reduce_axis((0, CI), name="ci") - bgemm = te.compute( - (alpha, alpha, P, CO), - lambda eps, nu, p, co: te.sum( - 
data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] - ), - name="bgemm", + target = Target.current(allow_none=True) + if target is not None and target.kind.name == "cuda": + write_cache_level = 3 + else: + write_cache_level = 2 + data = te.placeholder((N, H, W, CI), "float32", name="data") + weight = te.placeholder((kernel_size, kernel_size, CO, CI), "float32", name="weight") + out = _conv2d_winograd_nhwc_impl( + data, + weight, + stride, + padding, + dilation, + "float32", + pre_computed=True, + auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, + tile_size=tile_size, + write_cache_level=write_cache_level, + ) + return (data, weight, out) + + +def conv2d_winograd_nchw( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 1, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + from tvm.topi.cuda.conv2d_winograd import ( # pylint: disable=import-outside-toplevel + _infer_tile_size, ) - - # inverse transform - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_b") - inverse = te.compute( - (m, m, P, CO), - lambda vh, vw, p, co: te.sum( - bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] - ), - name="inverse", - attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + from tvm.topi.nn.conv2d import ( # pylint: disable=import-outside-toplevel + _conv2d_winograd_nchw_impl, ) - # output - output = te.compute( - (N, H, W, CO), - lambda n, h, w, co: inverse[ - idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co - ], - name="conv2d_winograd", + data = te.placeholder((N, CI, H, W), "float32", name="data") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), "float32", name="weight") + out = _conv2d_winograd_nchw_impl( + data, + weight, + stride, + padding, + dilation, + "float32", + pre_computed=True, + 
auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, + tile_size=_infer_tile_size(data, weight), ) - - return (inputs, kernel_pack, output) + return (data, weight, out) def matmul( @@ -833,7 +809,7 @@ def create_te_workload(name: str, idx: int) -> tir.PrimFunc: "T2D": ( conv2d_transpose_nhwc, [ - # all conv2d tranpose layers in DCGAN + # all conv2d transpose layers in DCGAN (1, 4, 4, 512, 256, 4, 2, 1), (1, 8, 8, 256, 128, 4, 2, 1), (1, 16, 16, 128, 64, 4, 2, 1), @@ -886,4 +862,16 @@ def create_te_workload(name: str, idx: int) -> tir.PrimFunc: (1, 128, 12, 128), ], ), + "C2D_WIN_NHWC": ( + conv2d_winograd_nhwc, + [ + (1, 14, 14, 128, 128, 6), + ], + ), + "C2D_WIN_NCHW": ( + conv2d_winograd_nchw, + [ + (1, 56, 56, 64, 64, 6), + ], + ), } diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py index 173f31ef08f9..5594e36cb855 100644 --- a/python/tvm/relay/backend/te_compiler.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -24,7 +24,7 @@ import tvm from tvm import autotvm, te from tvm.auto_scheduler import is_auto_scheduler_enabled -from tvm.meta_schedule import is_meta_schedule_dispatch_enabled +from tvm.meta_schedule import is_meta_schedule_enabled from tvm.runtime import Object from tvm.support import libinfo from tvm.target import Target @@ -181,7 +181,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) # Disable autotvm if auto_scheduler is enabled. # (i.e., always return the implementation with the highest priority for auto-scheduler). 
- if is_auto_scheduler_enabled() or is_meta_schedule_dispatch_enabled(): + if is_auto_scheduler_enabled() or is_meta_schedule_enabled(): use_autotvm = False # If not use autotvm, always return the implementation with the highest priority diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 90a94c422992..53aec11e5816 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -459,7 +459,7 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts): # conv3d_winograd related operators reg.register_strategy( "nn.contrib_conv3d_winograd_without_weight_transform", - strategy.conv3d_winograd_without_weight_transfrom_strategy, + strategy.conv3d_winograd_without_weight_transform_strategy, ) @@ -733,7 +733,7 @@ def mirror_pad_func(attrs, inputs, _): # conv2d_winograd related operators reg.register_strategy( "nn.contrib_conv2d_winograd_without_weight_transform", - strategy.conv2d_winograd_without_weight_transfrom_strategy, + strategy.conv2d_winograd_without_weight_transform_strategy, ) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 011622d5374f..21252215fc28 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -162,14 +162,14 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): return strategy -@conv2d_winograd_without_weight_transfrom_strategy.register("adreno") -def conv2d_winograd_without_weight_transfrom_strategy_adreno(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom adreno strategy""" +@conv2d_winograd_without_weight_transform_strategy.register("adreno") +def conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform adreno strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout assert dilation == (1, 1), "Do not support dilate now" - assert 
groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout in ("NCHW", "NCHW4c"): strategy.add_implementation( @@ -187,7 +187,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_adreno(attrs, inputs, out_ ) else: raise RuntimeError( - "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index b7650480d0e4..5c25696a1ee1 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -395,9 +395,9 @@ def _compute_conv2d_nnpack(attrs, inputs, out_type): return _compute_conv2d_nnpack -@conv2d_winograd_without_weight_transfrom_strategy.register("arm_cpu") -def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom arm cpu strategy""" +@conv2d_winograd_without_weight_transform_strategy.register("arm_cpu") +def conv2d_winograd_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform arm cpu strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout @@ -405,7 +405,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out kernel = inputs[1] assert dilation == (1, 1), "Do not support dilate now" assert strides == (1, 1), "Do not support strides now" - assert groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout == "NCHW": if len(kernel.shape) == 5: @@ -436,7 +436,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out raise RuntimeError("Unsupported kernel shape: 
{}".format(kernel.shape)) else: raise RuntimeError( - "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy @@ -463,7 +463,7 @@ def _compute_conv2d_gemm(attrs, inputs, out_type): @conv2d_gemm_without_weight_transform_strategy.register("arm_cpu") def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom arm cpu strategy""" + """conv2d_winograd_without_weight_transform arm cpu strategy""" layout = attrs.data_layout data = inputs[0] strategy = _op.OpStrategy() diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index ec3edab2c8b1..46ebb6048c2d 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -100,16 +100,16 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): return strategy -@conv2d_winograd_without_weight_transfrom_strategy.register("bifrost") -def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom mali(bifrost) strategy""" +@conv2d_winograd_without_weight_transform_strategy.register("bifrost") +def conv2d_winograd_without_weight_transform_strategy_bifrost(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform mali(bifrost) strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout strides = attrs.get_int_tuple("strides") assert dilation == (1, 1), "Do not support dilate now" assert strides == (1, 1), "Do not support strides now" - assert groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout == "NCHW": strategy.add_implementation( @@ -119,7 +119,7 @@ def 
conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out ) else: raise RuntimeError( - "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9bedfe8cb038..312ec0fe2f97 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -145,7 +145,6 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): kernel_layout = attrs.kernel_layout if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") - if groups == 1: if layout == "NCHW": assert kernel_layout == "OIHW" @@ -166,9 +165,34 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), name="conv2d_nchw.cuda", ) - _, _, kh, kw = get_const_tuple(kernel.shape) - if ( - (2 < kh < 8 and 2 < kw < 8 and kh == kw) + N, _, H, W = get_const_tuple(data.shape) + CO, CI, KH, KW = get_const_tuple(kernel.shape) + (_, _, judge_winograd_auto_scheduler) = judge_winograd( + N, + H, + W, + KH, + KW, + CI, + CO, + padding, + stride_h, + stride_w, + dilation_h, + dilation_w, + data.dtype, + kernel.dtype, + pre_flag=False, + ) + if is_meta_schedule_enabled() and judge_winograd_auto_scheduler: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nchw), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nchw_winograd.cuda", + plevel=15, + ) + elif ( + (2 < KH < 8 and 2 < KW < 8 and KH == KW) and (stride_h == 1 and stride_w == 1) and (dilation_h == 1 and dilation_w == 1) ): @@ -490,9 +514,9 @@ def judge_winograd( return judge_winograd_tensorcore, judge_winograd_autotvm, judge_winograd_auto_scheduler -@conv2d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"]) -def 
conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom cuda strategy""" +@conv2d_winograd_without_weight_transform_strategy.register(["cuda", "gpu"]) +def conv2d_winograd_without_weight_transform_strategy_cuda(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform cuda strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout @@ -500,14 +524,24 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty stride_h, stride_w = attrs.get_int_tuple("strides") padding = attrs.get_int_tuple("padding") assert dilation == (1, 1), "Do not support dilate now" - assert groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout == "NCHW": - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform), - name="conv2d_nchw_winograd_without_weight_transform.cuda", - ) + if is_meta_schedule_enabled(): + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nchw_without_weight_transform), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nchw_winograd_without_weight_transform", + plevel=15, + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform + ), + name="conv2d_nchw_winograd_without_weight_transform.cuda", + ) elif layout == "NHWC": N, H, W, _ = get_const_tuple(data.shape) alpha, _, CI, CO = get_const_tuple(kernel.shape) @@ -568,7 +602,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty ) else: raise RuntimeError( - 
"Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy @@ -744,14 +778,14 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): return strategy -@conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"]) -def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target): - """conv3d_winograd_without_weight_transfrom cuda strategy""" +@conv3d_winograd_without_weight_transform_strategy.register(["cuda", "gpu"]) +def conv3d_winograd_without_weight_transform_strategy_cuda(attrs, inputs, out_type, target): + """conv3d_winograd_without_weight_transform cuda strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout assert dilation == (1, 1, 1), "Do not support dilate now" - assert groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout == "NCDHW": strategy.add_implementation( @@ -761,7 +795,7 @@ def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty ) else: raise RuntimeError( - "Unsupported conv3d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv3d_winograd_without_weight_transform layout {}".format(layout) ) return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 6ab281abeb37..1cf55f7145cd 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -386,15 +386,15 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target): # conv2d_winograd_without_weight_transform @override_native_generic_func("conv2d_winograd_without_weight_transform_strategy") -def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target): - 
"""conv2d_winograd_without_weight_transfrom generic strategy""" +def conv2d_winograd_without_weight_transform_strategy(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform generic strategy""" raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform") # conv2d_gemm_without_weight_transform @override_native_generic_func("conv2d_gemm_without_weight_transform_strategy") def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target): - """conv2d_gemm_without_weight_transfrom generic strategy""" + """conv2d_gemm_without_weight_transform generic strategy""" raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform") @@ -619,8 +619,8 @@ def conv3d_strategy(attrs, inputs, out_type, target): # conv3d_winograd_without_weight_transform @override_native_generic_func("conv3d_winograd_without_weight_transform_strategy") -def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target): - """conv3d_winograd_without_weight_transfrom generic strategy""" +def conv3d_winograd_without_weight_transform_strategy(attrs, inputs, out_type, target): + """conv3d_winograd_without_weight_transform generic strategy""" raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform") diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index dca684835ba4..c39487b16d55 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -169,9 +169,9 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): return strategy -@conv2d_winograd_without_weight_transfrom_strategy.register("mali") -def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom mali strategy""" +@conv2d_winograd_without_weight_transform_strategy.register("mali") +def conv2d_winograd_without_weight_transform_strategy_mali(attrs, 
inputs, out_type, target): + """conv2d_winograd_without_weight_transform mali strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout @@ -179,7 +179,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty kernel = inputs[1] assert dilation == (1, 1), "Do not support dilate now" assert strides == (1, 1), "Do not support strides now" - assert groups == 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() if layout == "NCHW": assert len(kernel.shape) == 5, "Kernel must be packed into 5-dim" @@ -208,7 +208,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty ) else: raise RuntimeError( - "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6575e0f5c5a2..10d7fbb3a926 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -764,16 +764,16 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target): return strategy -@conv2d_winograd_without_weight_transfrom_strategy.register("cpu") -def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom cpu strategy""" +@conv2d_winograd_without_weight_transform_strategy.register("cpu") +def conv2d_winograd_without_weight_transform_strategy_cpu(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transform cpu strategy""" dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout strides = attrs.get_int_tuple("strides") assert dilation == (1, 1), "Do not support dilate now" assert strides == (1, 1), "Do not support strides now" - assert groups 
== 1, "Do not supoort arbitrary group number" + assert groups == 1, "Do not support arbitrary group number" strategy = _op.OpStrategy() need_auto_scheduler_layout = is_auto_scheduler_enabled() need_meta_schedule_layout = is_meta_schedule_enabled() @@ -802,7 +802,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ raise RuntimeError("Both AutoScheduler and MetaSchedule are not enabled") else: raise RuntimeError( - "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + "Unsupported conv2d_winograd_without_weight_transform layout {}".format(layout) ) return strategy diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 35d50eb3673c..93512ca07d9e 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -18,15 +18,15 @@ """Conv2D alter op and legalize functions for cuda backend""" import logging + import tvm -from tvm import te, relay, autotvm +from tvm import autotvm, relay, te from .. 
import nn +from ..nn import conv2d_legalize from ..utils import get_const_tuple, is_target from .conv2d_winograd import _infer_tile_size from .tensorcore_alter_op import pad_to_tensorcore -from ..nn import conv2d_legalize - logger = logging.getLogger("topi") @@ -61,24 +61,38 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): logger.warning("Does not support weight pre-transform for dilated convolution.") return None - assert data_layout == "NHWC" and kernel_layout == "HWIO" - N, H, W, CI = get_const_tuple(data.shape) - KH, KW, _, CO = get_const_tuple(kernel.shape) - - # Pre-compute weight transformation in winograd - tile_size = _infer_tile_size(tinfos[0], tinfos[1], layout="NHWC") - - # HWIO -> OIHW - kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) - # alpha, alpha, CO, CI - weight = relay.nn.contrib_conv2d_winograd_weight_transform( - kernel_transform, tile_size=tile_size - ) - new_attrs["tile_size"] = tile_size - new_attrs["channels"] = CO - return relay.nn.contrib_conv2d_winograd_without_weight_transform( - inputs[0], weight, **new_attrs - ) + if data_layout == "NHWC" and kernel_layout == "HWIO": + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + # Pre-compute weight transformation in winograd + tile_size = _infer_tile_size(tinfos[0], tinfos[1], layout="NHWC") + # HWIO -> OIHW + kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) + # alpha, alpha, CO, CI + weight = relay.nn.contrib_conv2d_winograd_weight_transform( + kernel_transform, tile_size=tile_size + ) + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + elif data_layout == "NCHW" and kernel_layout == "OIHW": + N, CI, H, W = get_const_tuple(data.shape) + CO, _, KH, KW = get_const_tuple(kernel.shape) + # Pre-compute weight transformation in winograd + tile_size = _infer_tile_size(tinfos[0], tinfos[1], 
layout="NCHW") + # alpha, alpha, CO, CI + weight = relay.nn.contrib_conv2d_winograd_weight_transform( + inputs[1], tile_size=tile_size + ) + # alpha, alpha, CI, CO + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) return None diff --git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py index 8accbbe53273..77b332400d0b 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py @@ -408,7 +408,6 @@ def nhwc_winograd_cuda( input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] ), name="data_pack", - attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"}, ) # Convert data type of input feature maps and weights for tensorcore @@ -433,14 +432,13 @@ def nhwc_winograd_cuda( # Inverse transform r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") inverse = te.compute( (P, CO, m, m), lambda p, co, vh, vw: te.sum( bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", - attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"}, ) # Output diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py index 239d05844b40..eca51c921016 100644 --- a/python/tvm/topi/cuda/conv2d_winograd.py +++ b/python/tvm/topi/cuda/conv2d_winograd.py @@ -23,7 +23,12 @@ from tvm import autotvm, te from .. 
import nn -from ..nn.conv2d import _conv2d_winograd_nhwc_impl, conv2d_winograd_nhwc +from ..nn.conv2d import ( + _conv2d_winograd_nchw_impl, + _conv2d_winograd_nhwc_impl, + conv2d_winograd_nchw, + conv2d_winograd_nhwc, +) from ..nn.winograd_util import winograd_transform_matrices from ..utils import get_const_int, get_const_tuple, traverse_inline @@ -104,7 +109,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] ), name="kernel_pack", - attrs={"schedule_rule": "meta_schedule.winograd_kernel_pack.nchw.cuda"}, ) else: kernel_pack = kernel @@ -129,7 +133,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] ), name="data_pack", - attrs={"schedule_rule": "meta_schedule.winograd_data_pack.nchw.cuda"}, ) # do batch gemm @@ -151,7 +154,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", - attrs={"schedule_rule": "meta_schedule.winograd_inverse.nchw.cuda"}, ) # output @@ -162,10 +164,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ ], name="output", tag="conv2d_nchw_winograd", - attrs={ - "schedule_rule": "meta_schedule.winograd_output.nchw.cuda", - "winograd_tile_size": alpha - 3 + 1, - }, ) if isinstance(N, int): @@ -391,3 +389,24 @@ def conv2d_winograd_nhwc_cuda( return _conv2d_winograd_nhwc_impl( data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed ) + + +@conv2d_winograd_nchw.register(["cuda", "gpu"]) +def conv2d_winograd_nchw_cuda( + data, + weight, + strides, + padding, + dilation, + out_dtype, + pre_computed=False, + auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, +): + """Conv2D Winograd in NCHW layout. 
+ This is a clean version to be used by the auto-scheduler for both CPU and GPU. + """ + tile_size = _infer_tile_size(data, weight, layout="NCHW") + return _conv2d_winograd_nchw_impl( + data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed + ) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 5070c84c7e51..db1bcaa27694 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -548,7 +548,7 @@ def conv2d_NCHWc_int8( ), name="conv2d_NCHWc_int8", tag="conv2d_NCHWc_int8", - attrs={"schedule_rule": "meta_schedule.conv2d_NCHWc_int8"}, + attrs={"schedule_rule": "conv2d_NCHWc_int8"}, ) # for int8 group conv support ic_chunk = in_channel // ic_bn @@ -571,7 +571,7 @@ def conv2d_NCHWc_int8( ), name="conv2d_NCHWc_int8", tag="conv2d_NCHWc_int8", - attrs={"schedule_rule": "meta_schedule.conv2d_NCHWc_int8"}, + attrs={"schedule_rule": "conv2d_NCHWc_int8"}, ) @@ -989,6 +989,119 @@ def unpack_NCHWc_to_nchw(packed_out, out_dtype): return unpacked_out +@tvm.target.generic_func +def conv2d_winograd_nhwc( + data, + weight, + strides, + padding, + dilation, + out_dtype, + pre_computed=False, + auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, +): + """Conv2D Winograd in NHWC layout. + This is a clean version to be used by the auto-scheduler for both CPU and GPU. + + Parameters + ---------- + data : tvm.te.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + weight : tvm.te.Tensor + 4-D with shape [filter_height, filter_width, in_channel, num_filter] + strides : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + out_dtype : str, optional + Specifies the output data type. 
+ pre_computed: bool + Whether the kernel is precomputed + auto_scheduler_rewritten_layout: str = "" + The layout after auto-scheduler's layout rewrite pass. + meta_schedule_original_shape: Optional[List[PrimExpr]] = None + The original shape of the input tensor. + + Returns + ------- + output : tvm.te.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + tile_size = 4 + return _conv2d_winograd_nhwc_impl( + data, + weight, + strides, + padding, + dilation, + out_dtype, + tile_size, + pre_computed=pre_computed, + write_cache_level=2, + auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout, + meta_schedule_original_shape=meta_schedule_original_shape, + ) + + +@tvm.target.generic_func +def conv2d_winograd_nchw( + data, + weight, + strides, + padding, + dilation, + out_dtype, + pre_computed=False, + auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, +): + """Conv2D Winograd in NCHW layout. + This is a clean version to be used by the auto-scheduler for both CPU and GPU. + + Parameters + ---------- + data : tvm.te.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + weight : tvm.te.Tensor + 4-D with shape [num_filter, in_channel, filter_height, filter_width] + strides : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + out_dtype : str, optional + Specifies the output data type. + pre_computed: bool + Whether the kernel is precomputed + auto_scheduler_rewritten_layout: str = "" + The layout after auto-scheduler's layout rewrite pass. + meta_schedule_original_shape: Optional[List[PrimExpr]] = None + The original shape of the input tensor. 
+ + Returns + ------- + output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + tile_size = 4 + return _conv2d_winograd_nchw_impl( + data, + weight, + strides, + padding, + dilation, + out_dtype, + tile_size, + pre_computed=pre_computed, + auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout, + meta_schedule_original_shape=meta_schedule_original_shape, + ) + + def _conv2d_winograd_nhwc_impl( data, weight, @@ -998,6 +1111,7 @@ def _conv2d_winograd_nhwc_impl( out_dtype, tile_size, pre_computed=False, + write_cache_level=None, auto_scheduler_rewritten_layout="", meta_schedule_original_shape=None, ): @@ -1022,6 +1136,8 @@ def _conv2d_winograd_nhwc_impl( The size of the tile to use for the Winograd filter pre_computed: bool = False Whether the kernel is precomputed + write_cache_level: Optional[int] = None + The cache level to write to in multi-level tiling rule in MetaSchedule. auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. meta_schedule_original_shape: Optional[List[PrimExpr]] = None @@ -1085,45 +1201,48 @@ def _conv2d_winograd_nhwc_impl( kernel_pack = te.compute( (alpha, alpha, CO, CI), lambda eps, nu, co, ci: te.sum( - weight[r_kh][r_kw][ci][co] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] + weight[r_kh, r_kw, ci, co] * G[eps, r_kh] * G[nu, r_kw], + axis=[r_kh, r_kw], ), name="kernel_pack", ) - attrs = {} + bgemm_attrs = {} else: kernel_pack = weight - attrs = {"layout_free_placeholders": [kernel_pack]} + bgemm_attrs = {"layout_free_placeholders": [kernel_pack]} + if write_cache_level is not None: + if not isinstance(write_cache_level, int): + bgemm_attrs["meta_schedule.write_cache_level"] = write_cache_level + else: + bgemm_attrs["meta_schedule.write_cache_level"] = [write_cache_level] # pack data tile input_tile = te.compute( (alpha, alpha, P, CI), - lambda eps, nu, p, ci: data_pad[p // (nH * nW)][((p // nW) % nH) * m + eps][ - (p % nW) * m + nu - ][ci], + lambda eps, nu, p, ci: 
data_pad[ + p // (nH * nW), + ((p // nW) % nH) * m + eps, + 
(p % nW) * m + nu, + ci, + ], name="input_tile", attrs={"schedule_rule": "None"}, ) # transform data - target = tvm.target.Target.current(allow_none=True) - if target is not None: - target_kind = "meta_schedule.winograd_data_pack." + target.kind.name - else: - target_kind = "None" - r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_b") data_pack = te.compute( (alpha, alpha, P, CI), lambda eps, nu, p, ci: te.sum( - input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu], + axis=[r_a, r_b], ), name="data_pack", attrs={ "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"], - "schedule_rule": target_kind, + "schedule_rule": "conv2d_nhwc_winograd_data_pack", }, - # the attrs are necessary hints for the auto-scheduler ) # do batch gemm @@ -1131,59 +1250,211 @@ def _conv2d_winograd_nhwc_impl( bgemm = te.compute( (alpha, alpha, P, CO), lambda eps, nu, p, co: te.sum( - data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][co][ci], axis=[ci] + data_pack[eps, nu, p, ci] * kernel_pack[eps, nu, co, ci], + axis=[ci], ), name="bgemm", - attrs=attrs, + attrs=bgemm_attrs, ) if auto_scheduler_rewritten_layout: bgemm = auto_scheduler.rewrite_compute_body(bgemm, auto_scheduler_rewritten_layout) # inverse transform - if target is not None: - target_kind = "meta_schedule.winograd_inverse." 
+ target.kind.name - else: - target_kind = "None" r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_b") inverse = te.compute( (m, m, P, CO), lambda vh, vw, p, co: te.sum( - bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw], + axis=[r_a, r_b], ), name="inverse", attrs={ "auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"], - "schedule_rule": target_kind, + "schedule_rule": "conv2d_nhwc_winograd_inverse", }, - # the attrs are necessary hints for the auto-scheduler ) # output output = te.compute( (N, H, W, CO), - lambda n, h, w, co: inverse[h % m, w % m, n * nH * nW + (h // m) * nW + (w // m), co], + lambda n, h, w, co: inverse[ + h % m, + w % m, + n * nH * nW + (h // m) * nW + (w // m), + co, + ], name="conv2d_winograd", ) return output -@tvm.target.generic_func -def conv2d_winograd_nhwc( +def _conv2d_winograd_nchw_impl( data, weight, strides, padding, dilation, out_dtype, + tile_size, pre_computed=False, + write_cache_level=None, auto_scheduler_rewritten_layout="", meta_schedule_original_shape=None, ): - """Conv2D Winograd in NHWC layout. + """ + write_cache_level: Optional[int] = None + The cache level to write to in multi-level tiling rule in MetaSchedule. 
+ """ + del auto_scheduler_rewritten_layout + + N, CI, H, W = get_const_tuple(data.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + if meta_schedule_original_shape: + auto_scheduler.rewrite_tensor_shape(weight, meta_schedule_original_shape) + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides + + if not pre_computed: # kernel tensor is raw tensor, do strict check + CO, CI, KH, KW = get_const_tuple(weight.shape) + alpha = KW + tile_size - 1 + assert HSTR == 1 and WSTR == 1 and KH == KW + else: + alpha, _, CI, CO = get_const_tuple(weight.shape) + KH = KW = alpha + 1 - tile_size + assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 + + pad_t, pad_l, pad_b, pad_r = get_pad_tuple(padding, (KH, KW)) + assert HSTR == 1 and WSTR == 1 and KH == 3 and KW == 3 + + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) + data_pad = pad( + data, + (0, 0, pt, pl), + (0, 0, pb, pr), + name="data_pad", + ) + + r = KW + m = tile_size + A, B, G = winograd_transform_matrices(m, r, out_dtype) + + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + + P = N * nH * nW if isinstance(N, int) else nH * nW + + # transform kernel + if not pre_computed: + r_kh = te.reduce_axis((0, KH), name="r_kh") + r_kw = te.reduce_axis((0, KW), name="r_kw") + kernel_pack = te.compute( + (alpha, alpha, CI, CO), + lambda eps, nu, ci, co: te.sum( + weight[co, ci, r_kh, r_kw] * G[eps, r_kh] * G[nu, r_kw], + axis=[r_kh, r_kw], + ), + name="kernel_pack", + ) + bgemm_attrs = {} + else: + kernel_pack = weight + bgemm_attrs = {"layout_free_placeholders": [kernel_pack]} + if write_cache_level is not None: + if not isinstance(write_cache_level, int): + bgemm_attrs["meta_schedule.write_cache_level"] = write_cache_level + else: + 
bgemm_attrs["meta_schedule.write_cache_level"] = [write_cache_level] + + # pack data tile + input_tile = te.compute( + (CI, P, alpha, alpha), + lambda ci, p, eps, nu: data_pad[ + p // (nH * nW), + ci, + ((p // nW) % nH) * m + eps, + (p % nW) * m + nu, + ], + name="input_tile", + attrs={"schedule_rule": "None"}, + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, CI, P), + lambda eps, nu, ci, p: te.sum( + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu], + axis=[r_a, r_b], + ), + name="data_pack", + attrs={ + "schedule_rule": "conv2d_nchw_winograd_data_pack", + }, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, CO, P), + lambda eps, nu, co, p: te.sum( + data_pack[eps, nu, ci, p] * kernel_pack[eps, nu, ci, co], + axis=[ci], + ), + name="bgemm", + attrs=bgemm_attrs, + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (CO, P, m, m), + lambda co, p, vh, vw: te.sum( + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw], + axis=[r_a, r_b], + ), + name="inverse", + attrs={ + "schedule_rule": "conv2d_nchw_winograd_inverse", + }, + ) + + # output + output = te.compute( + (N, CO, H, W), + lambda n, co, h, w: inverse[ + co, + n * nH * nW + (h // m) * nW + (w // m), + h % m, + w % m, + ], + name="conv2d_winograd", + ) + + return output + + +def conv2d_winograd_nhwc_without_weight_transform( + data, + weight, + strides, + padding, + dilation, + out_dtype, + auto_scheduler_rewritten_layout="", + meta_schedule_original_shape=None, +): + """Conv2D Winograd without layout transform in NHWC layout. This is a clean version to be used by the auto-scheduler for both CPU and GPU. Parameters @@ -1200,8 +1471,6 @@ def conv2d_winograd_nhwc( dilation size, or [dilation_height, dilation_width] out_dtype : str, optional Specifies the output data type. 
- pre_computed: bool - Whether the kernel is precomputed auto_scheduler_rewritten_layout: str = "" The layout after auto-scheduler's layout rewrite pass. meta_schedule_original_shape: Optional[List[PrimExpr]] = None @@ -1212,23 +1481,21 @@ def conv2d_winograd_nhwc( output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ - tile_size = 4 - return _conv2d_winograd_nhwc_impl( + return conv2d_winograd_nhwc( data, weight, strides, padding, dilation, out_dtype, - tile_size, - pre_computed, - auto_scheduler_rewritten_layout, - meta_schedule_original_shape, + pre_computed=True, + auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout, + meta_schedule_original_shape=meta_schedule_original_shape, ) -def conv2d_winograd_nhwc_without_weight_transform( +def conv2d_winograd_nchw_without_weight_transform( data, weight, strides, @@ -1238,8 +1505,8 @@ def conv2d_winograd_nhwc_without_weight_transform( auto_scheduler_rewritten_layout="", meta_schedule_original_shape=None, ): - """Conv2D Winograd without layout transform in NHWC layout. - This is a clean version to be used by the auto-scheduler for both CPU and GPU. + """Conv2D Winograd without layout transform in NCHW layout. + This is a clean version to be used by meta-schedule for both CPU and GPU. Parameters ---------- @@ -1265,8 +1532,7 @@ def conv2d_winograd_nhwc_without_weight_transform( output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ - - return conv2d_winograd_nhwc( + return conv2d_winograd_nchw( data, weight, strides, diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 1fd842f2d4cc..8251dac4137b 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -24,6 +24,7 @@ import tvm from tvm import te from tvm.tir import bijective_layout, layout + from . 
import cpp, tag @@ -325,7 +326,7 @@ def unravel_index(idx, shape): return indices -def const_matrix(matrix, name="const_matrix"): +def const_matrix(matrix, name="const_matrix", attrs=None): """convert a const numpy 2-dimensional matrix to tvm tensor Parameters @@ -355,14 +356,17 @@ def select_array(i, j): ) return now + if attrs is None: + attrs = { + "const_matrix": True, + "schedule_rule": "None", + } + return te.compute( matrix.shape, select_array, name=name, - attrs={ - "const_matrix": True, - "schedule_rule": "meta_schedule.compute_inline", - }, + attrs=attrs, ) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 3d64239044e2..025f41660c9c 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -17,13 +17,13 @@ # pylint: disable=invalid-name,too-many-locals,unused-variable """x86 batch_matmul operators""" import tvm -from tvm import te -from tvm import autotvm +from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl + from .. 
import generic, nn from ..transform import layout_transform -from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor +from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline from .dense import dense_vnni_schedule from .injective import schedule_injective_from_existing @@ -47,7 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y, *_): axis=ak, ), tag="batch_matmul_vnni", - attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"}, + attrs={"schedule_rule": "batch_matmul_vnni"}, ) _, a_y, _ = z.op.axis diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 88a2499c2c1e..8ddb8d7a5c9a 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -18,18 +18,16 @@ # pylint: disable=no-value-for-parameter """x86 dense operators""" from __future__ import absolute_import as _abs + import tvm -from tvm import te -from tvm import autotvm +from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity -from tvm.contrib import cblas -from tvm.contrib import mkl -from tvm.contrib import dnnl +from tvm.contrib import cblas, dnnl, mkl -from .utils import get_simd_32bit_lanes from .. 
import generic, tag -from ..utils import traverse_inline, get_const_tuple +from ..utils import get_const_tuple, traverse_inline from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from .utils import get_simd_32bit_lanes def _schedule_dense_pack_template(cfg, s, C, O): @@ -296,7 +294,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_vnni", - attrs={"schedule_rule": "meta_schedule.dense_vnni"}, + attrs={"schedule_rule": "dense_vnni"}, ) if bias is not None: diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 1ba68538ea04..27ce34a8cb27 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -16,7 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include "../schedule_rule/auto_bind.h" +#include + #include "../utils.h" namespace tvm { diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc new file mode 100644 index 000000000000..16e53b56923a --- /dev/null +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../../utils.h" + +namespace tvm { +namespace meta_schedule { + +using namespace tvm::tir; + +static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector tiled, std::vector unrolled) { + using namespace tvm::tir; + ICHECK_EQ(tiled.size(), 2); + ICHECK_EQ(unrolled.size(), 4); + Array factors{nullptr}; + Array loops = sch->GetLoops(block); + ICHECK_EQ(loops.size(), 6); + + factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); + Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ICHECK_EQ(t0.size(), 2); + + factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); + Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ICHECK_EQ(t1.size(), 2); + + sch->Unroll(loops[unrolled[0]]); + sch->Unroll(loops[unrolled[1]]); + sch->Unroll(loops[unrolled[2]]); + sch->Unroll(loops[unrolled[3]]); + sch->Reorder({ + t0[0], + t1[0], + t0[1], + t1[1], + loops[unrolled[0]], + loops[unrolled[1]], + loops[unrolled[2]], + loops[unrolled[3]], + }); + return {t0[0], t1[0], t0[1], t1[1]}; +} + +TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), + /*preserve_unit_loops=*/true); + sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), + /*preserve_unit_loops=*/true); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") + .set_body_typed([](Schedule sch, BlockRV block) -> Array { + 
GetWinogradProducerAndInlineConst(sch, block); + ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), + /*preserve_unit_loops=*/true); + sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), + /*preserve_unit_loops=*/true); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") + .set_body_typed([](Schedule sch, BlockRV block) -> Array { + GetWinogradProducerAndInlineConst(sch, block); + ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); + return {sch}; + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc new file mode 100644 index 000000000000..e5dd5068783d --- /dev/null +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" + +namespace tvm { +namespace meta_schedule { + +using namespace tvm::tir; + +std::function MakeFactorSampler(Schedule sch, Array thread_extents) { + return [sch = std::move(sch), + thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { + Array extents; + extents.reserve(thread_extents.size()); + for (const Integer extent : thread_extents) { + if (extent->value <= max_extent) { + extents.push_back(extent); + } + } + int n = extents.size(); + if (n == 0) { + return Integer(max_extent); + } + if (n == 1) { + return Integer(extents[0]); + } + Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); + return sch->SampleCategorical(extents, probs); + }; +} + +Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, + int64_t max_threads_per_block, + std::function get_factor) { + int64_t extent = -1; + if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) { + extent = *e; + } else { + extent = std::numeric_limits::max(); + } + if (extent <= max_threadblocks * max_threads_per_block) { + if (!get_factor) { + get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); + } + ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); + Array splits = sch->Split(loop, {NullOpt, factor}); + ICHECK_EQ(splits.size(), 2); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + return {splits[0], splits[1]}; + } else { + Array splits = sch->Split(loop, {NullOpt, + Integer(max_threadblocks), // + Integer(max_threads_per_block)}); + ICHECK_EQ(splits.size(), 3); + sch->Reorder({splits[1], splits[2], splits[0]}); + sch->Bind(splits[1], "blockIdx.x"); + sch->Bind(splits[2], "threadIdx.x"); + return {splits[1], splits[2]}; + } +} + +void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // + 
int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor) { + using namespace tvm::tir; + StmtSRef block_sref = sch->GetSRef(block_rv); + if (block_sref->parent == nullptr) { + return; + } + if (tir::HasBeenMultiLevelTiled(block_sref)) { + return; + } + Array loops = tir::GetLoops(block_sref); + int n = loops.size(); + int i_block_idx = -1; + int i_thread_idx = -1; + int i_multi_child = -1; + int i_spatial_loop = -1; + for (int i = 0; i < n; ++i) { + const StmtSRef& loop_sref = loops[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + if (i_block_idx == -1) { + i_block_idx = i; + } + } + if (IsThreadIdx(thread_scope)) { + if (i_thread_idx == -1) { + i_thread_idx = i; + } + } + if (loop->kind != ForKind::kSerial) { + if (i_multi_child == -1) { + i_multi_child = i; + } + } + if (!IsSingleStmt(loop->body)) { + if (i_multi_child == -1) { + i_multi_child = i + 1; + } + } + if (GetLoopIterType(loop_sref) == IterVarType::kDataPar) { + if (i_spatial_loop == i - 1) { + ++i_spatial_loop; + } + } + } + if (i_multi_child == -1) { + i_multi_child = n; + } + if (i_block_idx != -1 && i_thread_idx != -1) { + return; + } + if (i_block_idx != -1 && i_thread_idx == -1) { + ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + throw; + } + LoopRV loop_rv{nullptr}; + { + Array loop_rvs = sch->GetLoops(block_rv); + if (i_spatial_loop == -1) { + LoopRV spatial_loop_rv{nullptr}; + if (loop_rvs.empty()) { + spatial_loop_rv = sch->AddUnitLoop(block_rv); + } else { + spatial_loop_rv = sch->AddUnitLoop(loop_rvs[0]); + } + loop_rvs.insert(loop_rvs.begin(), spatial_loop_rv); + i_spatial_loop = 0; + if (i_block_idx != -1) { + i_block_idx += 1; + } + if (i_thread_idx != -1) { + i_thread_idx += 1; + } + if (i_multi_child != -1) { + i_multi_child += 1; + } + } + if (i_block_idx == -1 && i_thread_idx != -1) { + int num_fuse = 
std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); + Array loop_rvs = sch->GetLoops(block_rv); + loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + sch->Bind(loop_rv, "blockIdx.x"); + return; + } else { // i_block_idx == -1 && i_thread_idx == -1 + int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); + loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + } + } + BindSpatialLoop(sch, loop_rv, max_threadblocks, max_threads_per_block, get_factor); +} + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc new file mode 100644 index 000000000000..5334c4df2ac9 --- /dev/null +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include + +#include "../../utils.h" + +namespace tvm { +namespace meta_schedule { + +using namespace tvm::tir; + +static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector tiled, std::vector unrolled) { + // This method is used for NHWC layout only. 
Will likely be refactored into a more schedule + using namespace tvm::tir; + ICHECK_EQ(tiled.size(), 2); + ICHECK_EQ(unrolled.size(), 4); + Array factors{nullptr}; + Array loops = sch->GetLoops(block); + ICHECK_EQ(loops.size(), 6); + + factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); + Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ICHECK_EQ(t0.size(), 2); + + factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); + Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ICHECK_EQ(t1.size(), 2); + + sch->Unroll(loops[unrolled[0]]); + sch->Unroll(loops[unrolled[1]]); + sch->Unroll(loops[unrolled[2]]); + sch->Unroll(loops[unrolled[3]]); + sch->Reorder({ + t0[0], + t1[0], + t0[1], + t1[1], + loops[unrolled[0]], + loops[unrolled[1]], + loops[unrolled[2]], + loops[unrolled[3]], + }); + return {t0[0], t1[0], t0[1], t1[1]}; +} + +TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + { + BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); + } + { + sch->ComputeAt(input_tile, /*loop_rv=*/loops.back(), /*preserve_unit_loops=*/true); + sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); + sch->ComputeInline(data_pad); + } + { + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + Array loops = sch->GetLoops(data_pack); + ICHECK_EQ(loops.size(), 8); + BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, + max_threads_per_block); + } + return {sch}; + }); + 
+TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") + .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { + GetWinogradProducerAndInlineConst(sch, inverse); + ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + Array loops = sch->GetLoops(inverse); + ICHECK_EQ(loops.size(), 8); + BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, + max_threads_per_block); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + LoopRV outer{nullptr}; + { + Array loops = sch->GetLoops(data_pack); + ICHECK_EQ(loops.size(), 6); + sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); + sch->Unroll(loops[0]); + sch->Unroll(loops[1]); + sch->Unroll(loops[4]); + sch->Unroll(loops[5]); + outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks, + max_threads_per_block)[1]; + } + { + BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true); + } + { + sch->ComputeAt(input_tile, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); + sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); + sch->ComputeInline(data_pad); + } + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") + .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { + GetWinogradProducerAndInlineConst(sch, inverse); + // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] + int64_t tile_size = 
Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; + LoopRV outer{nullptr}; + { + BlockRV output = sch->GetConsumers(inverse)[0]; + Array nchw = sch->GetLoops(output); + ICHECK_EQ(nchw.size(), 4); + Array hs = sch->Split(nchw[2], {NullOpt, Integer(tile_size)}); + Array ws = sch->Split(nchw[3], {NullOpt, Integer(tile_size)}); + sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); + outer = ws[0]; + } + { + sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); + sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); + Array loops = sch->GetLoops(inverse); + ICHECK_EQ(loops.size(), 10); + sch->Unroll(loops[6]); + sch->Unroll(loops[7]); + sch->Unroll(loops[8]); + sch->Unroll(loops[9]); + } + return {sch}; + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule/generic/winograd.cc b/src/meta_schedule/schedule/generic/winograd.cc new file mode 100644 index 000000000000..edb14667bcec --- /dev/null +++ b/src/meta_schedule/schedule/generic/winograd.cc @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +namespace tvm { +namespace meta_schedule { + +using namespace tvm::tir; + +/*! 
+ * \brief Get the producer block of a given block. + * If there is a constant winograd transform matrix, inline it. + * \return The only producer block. + */ +BlockRV GetWinogradProducerAndInlineConst(Schedule sch, BlockRV block) { + Array producers = sch->GetProducers(block); + Array results; + for (const BlockRV& producer : producers) { + if (sch->Get(producer)->reads.empty()) { + sch->ComputeInline(producer); + } else { + results.push_back(producer); + } + } + ICHECK_EQ(results.size(), 1); + return results[0]; +} + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule/x86/.gitignore b/src/meta_schedule/schedule/x86/.gitignore new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc new file mode 100644 index 000000000000..4b0fa675acc7 --- /dev/null +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class ApplyCustomRuleNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context->target.defined()) << "ValueError: Target is not defined in the tune context."; + this->target_ = context->target; + } + + static std::string GetCustomRuleName(const std::string& name, const std::string& key) { + return "meta_schedule." + key + "." + name; + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + CHECK(this->target_.defined()) + << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; + Array keys = this->target_.value()->keys; + if (Optional ann = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + if (ann.value() != "None") { + for (const String& key : keys) { + if (const runtime::PackedFunc* custom_schedule_fn = + runtime::Registry::Get(GetCustomRuleName(ann.value(), key))) { + Array result = ((*custom_schedule_fn)(sch, block_rv)); + return result; + } + } + std::ostringstream os; + os << "Unknown schedule rule \"" << ann.value() << "\" for target keys \"" << keys + << "\". 
Checked PackedFuncs:"; + for (const String& key : keys) { + os << "\n " << GetCustomRuleName(ann.value(), key); + } + LOG(WARNING) << os.str(); + } + } + return {sch}; + } + + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = make_object(*this); + n->target_ = target_; + return ScheduleRule(n); + } + + public: + Optional target_ = NullOpt; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target_", &target_); } + + static constexpr const char* _type_key = "meta_schedule.ApplyCustomRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(ApplyCustomRuleNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::ApplyCustomRule() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { + return rule->IsInstance(); +} + +TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleApplyCustomRule") + .set_body_typed(ScheduleRule::ApplyCustomRule); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 4d16a6d4d65d..fa47d1edb860 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include "./auto_bind.h" +#include #include #include @@ -26,142 +26,6 @@ namespace tvm { namespace meta_schedule { -void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, - int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor) { - using namespace tvm::tir; - StmtSRef block_sref = sch->GetSRef(block_rv); - if (block_sref->parent == nullptr) { - return; - } - if (tir::HasBeenMultiLevelTiled(block_sref)) { - return; - } - Array loops = tir::GetLoops(block_sref); - int n = loops.size(); - int i_block_idx = -1; - int i_thread_idx = -1; - int i_multi_child = -1; - int i_spatial_loop = -1; - for (int i = 0; i < n; ++i) { - const StmtSRef& loop_sref = loops[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - runtime::ThreadScope thread_scope = GetThreadScope(loop); - if (IsBlockIdx(thread_scope)) { - if (i_block_idx == -1) { - i_block_idx = i; - } - } - if (IsThreadIdx(thread_scope)) { - if (i_thread_idx == -1) { - i_thread_idx = i; - } - } - if (loop->kind != ForKind::kSerial) { - if (i_multi_child == -1) { - i_multi_child = i; - } - } - if (!IsSingleStmt(loop->body)) { - if (i_multi_child == -1) { - i_multi_child = i + 1; - } - } - if (GetLoopIterType(loop_sref) == IterVarType::kDataPar) { - if (i_spatial_loop == i - 1) { - ++i_spatial_loop; - } - } - } - if (i_multi_child == -1) { - i_multi_child = n; - } - if (i_block_idx != -1 && i_thread_idx != -1) { - return; - } - if (i_block_idx != -1 && i_thread_idx == -1) { - ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; - throw; - } - LoopRV loop_rv{nullptr}; - { - Array loop_rvs = sch->GetLoops(block_rv); - if (i_spatial_loop == -1) { - LoopRV spatial_loop_rv{nullptr}; - if (loop_rvs.empty()) { - spatial_loop_rv = sch->AddUnitLoop(block_rv); - } else { - spatial_loop_rv = sch->AddUnitLoop(loop_rvs[0]); - } - loop_rvs.insert(loop_rvs.begin(), spatial_loop_rv); - i_spatial_loop = 0; - if (i_block_idx != -1) { - i_block_idx += 
1; - } - if (i_thread_idx != -1) { - i_thread_idx += 1; - } - if (i_multi_child != -1) { - i_multi_child += 1; - } - } - if (i_block_idx == -1 && i_thread_idx != -1) { - int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); - Array loop_rvs = sch->GetLoops(block_rv); - loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); - sch->Bind(loop_rv, "blockIdx.x"); - return; - } else { // i_block_idx == -1 && i_thread_idx == -1 - int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); - loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); - } - } - int64_t extent = -1; - if (const int64_t* e = GetLoopIntExtent(sch->Get(loop_rv).get())) { - extent = *e; - } else { - extent = std::numeric_limits::max(); - } - if (extent <= max_threadblocks * max_threads_per_block) { - ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); - Array splits = sch->Split(loop_rv, {NullOpt, factor}); - ICHECK_EQ(splits.size(), 2); - sch->Bind(splits[0], "blockIdx.x"); - sch->Bind(splits[1], "threadIdx.x"); - } else { - Array splits = sch->Split(loop_rv, {NullOpt, - Integer(max_threadblocks), // - Integer(max_threads_per_block)}); - ICHECK_EQ(splits.size(), 3); - sch->Reorder({splits[1], splits[2], splits[0]}); - sch->Bind(splits[1], "blockIdx.x"); - sch->Bind(splits[2], "threadIdx.x"); - } -} - -std::function MakeFactorSampler(tir::Schedule sch, - Array thread_extents) { - return [sch = std::move(sch), - thread_extents = std::move(thread_extents)](int64_t max_extent) -> tir::ExprRV { - Array extents; - extents.reserve(thread_extents.size()); - for (const Integer extent : thread_extents) { - if (extent->value <= max_extent) { - extents.push_back(extent); - } - } - int n = extents.size(); - if (n == 0) { - return Integer(max_extent); - } - if (n == 1) { - return Integer(extents[0]); - } - Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); - return sch->SampleCategorical(extents, probs); - }; -} - class AutoBindNode 
: public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 8e4642b50ddb..141b93be5e34 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -53,7 +53,15 @@ ScheduleRule ScheduleRule::PyScheduleRule( Array ScheduleRule::DefaultLLVM() { return { - GetDefaultAutoInline("llvm"), + ScheduleRule::ApplyCustomRule(), + ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -78,6 +86,7 @@ Array ScheduleRule::DefaultLLVM() { Array ScheduleRule::DefaultCUDA() { return { + ScheduleRule::ApplyCustomRule(), ScheduleRule::MultiLevelTiling( /*structure=*/"SSSRRSRS", /*tile_binds=*/Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, @@ -91,7 +100,14 @@ Array ScheduleRule::DefaultCUDA() { Map{{"req", String("must")}, {"levels", Array{3}}, // {"scope", String("local")}}), - GetDefaultAutoInline("cuda"), + ScheduleRule::AutoInline( + /*into_producer=*/true, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/false, + /*require_injective=*/false, + /*require_ordered=*/false, + /*disallow_op=*/Array{}), ScheduleRule::CrossThreadReduction( /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( @@ -136,28 +152,32 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"store", "wmma_store_16x16x16_s32_shared"}, }, }; - Array results{ScheduleRule::MultiLevelTilingTensorCore( - /*intrin_groups=*/intrin_groups, - /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, - /*max_innermost_factor=*/Integer(4), - 
/*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, - /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared")}}, - /*reuse_write=*/ - Map{{"req", String("must")}, - {"levels", Array{2}}, // - {"scope", String("shared")}}, - /*use_software_pipeline=*/false)}; + Array results{ + ScheduleRule::ApplyCustomRule(), + ScheduleRule::MultiLevelTilingTensorCore( + /*intrin_groups=*/intrin_groups, + /*structure=*/"SSSRRSRS", + /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, + /*max_innermost_factor=*/Integer(4), + /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*reuse_read=*/ + Map{{"req", String("must")}, + {"levels", Array{4}}, // + {"scope", String("shared")}}, + /*reuse_write=*/ + Map{{"req", String("must")}, + {"levels", Array{2}}, // + {"scope", String("shared")}}, + /*use_software_pipeline=*/false) // + }; Array append = ScheduleRule::DefaultCUDA(); - results.insert(results.end(), append.begin(), append.end()); + results.insert(results.end(), append.begin() + 1, append.end()); return results; } Array ScheduleRule::DefaultHexagon() { return { + ScheduleRule::ApplyCustomRule(), ScheduleRule::AutoInline( /*into_producer=*/false, /*into_consumer=*/true, diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc deleted file mode 100644 index 22e2300d63b6..000000000000 --- a/src/meta_schedule/schedule_rule/winograd.cc +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "../utils.h" -#include "./auto_bind.h" - -namespace tvm { -namespace meta_schedule { - -using namespace tvm::tir; - -TVM_REGISTER_GLOBAL("meta_schedule.compute_inline") - .set_body_typed([](Schedule sch, BlockRV block) -> Array { - sch->ComputeInline(block); - return {sch}; - }); - -inline BlockRV GetOnlyProducer(Schedule sch, BlockRV block) { - Array producers = sch->GetProducers(block); - ICHECK_EQ(producers.size(), 1); - return producers[0]; -} - -inline BlockRV GetOnlyConsumer(Schedule sch, BlockRV block) { - Array consumers = sch->GetConsumers(block); - ICHECK_EQ(consumers.size(), 1); - return consumers[0]; -} - -inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) { - Array factors{nullptr}; - Array loops = sch->GetLoops(block); - ICHECK_EQ(loops.size(), 6); - - factors = sch->SamplePerfectTile(loops[2], /*n=*/2, /*max_innermost_factor=*/64); - Array t0 = sch->Split(loops[2], {factors.begin(), factors.end()}); - ICHECK_EQ(t0.size(), 2); - - factors = sch->SamplePerfectTile(loops[3], /*n=*/2, /*max_innermost_factor=*/64); - Array t1 = sch->Split(loops[3], {factors.begin(), factors.end()}); - ICHECK_EQ(t1.size(), 2); - - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) { - if (*i <= 16) { - sch->Unroll(loops[0]); - } - } - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) { - if (*i <= 16) { - sch->Unroll(loops[1]); - } - } - sch->Unroll(loops[4]); - sch->Unroll(loops[5]); - sch->Reorder({ - t0[0], - t1[0], - t0[1], - t1[1], - loops[0], - loops[1], - loops[4], - loops[5], - }); - 
return t1[1]; -} - -inline LoopRV ScheduleDataPackNCHW(Schedule sch, BlockRV block) { - Array loops = sch->GetLoops(block); - ICHECK_EQ(loops.size(), 6); - - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) { - if (*i <= 16) { - sch->Unroll(loops[0]); - } - } - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) { - if (*i <= 16) { - sch->Unroll(loops[1]); - } - } - sch->Unroll(loops[4]); - sch->Unroll(loops[5]); - - Array factors = sch->SamplePerfectTile(loops[3], /*n=*/2, /*max_innermost_factor=*/64); - Array split = - sch->Split(loops[3], /*factors=*/{factors[0], factors[1]}, /*preserve_unit_loops=*/true); - - LoopRV fused = sch->Fuse({loops[2], split[0]}); - sch->Reorder({fused, split[1], loops[0], loops[1]}); - return split[1]; -} - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.llvm") - .set_body_typed([](Schedule sch, BlockRV block) -> Array { - ScheduleDataPack(sch, block); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetOnlyProducer(sch, data_pack); - BlockRV data_pad = GetOnlyProducer(sch, input_tile); - ScheduleDataPack(sch, data_pack); - sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), - /*preserve_unit_loops=*/true); - sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), - /*preserve_unit_loops=*/true); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_output.nchw.cuda") - .set_body_typed([](Schedule sch, BlockRV output) -> Array { - // get loops - Array loops = sch->GetLoops(output); - ICHECK_EQ(loops.size(), 4); - - BlockRV OL{nullptr}; - - // tile - Optional tile_size = - tir::GetAnn(sch->GetSRef(output), "winograd_tile_size"); - ICHECK(tile_size.defined()) << "Winograd tile size is not defined in block annotation!"; - Array split0 = sch->Split(loops[2], {NullOpt, tile_size.value()}); - Array split1 = 
sch->Split(loops[3], {NullOpt, tile_size.value()}); - sch->Reorder({split0[0], split1[0], split0[1], split1[1]}); - - // compute_at - BlockRV inverse = GetOnlyProducer(sch, output); - sch->ComputeAt(inverse, /*loop_rv=*/split1[0], - /*preserve_unit_loops=*/true); - - // fuse - LoopRV fused = sch->Fuse({loops[0], loops[1], split0[0], split1[0]}); - - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); - BindBlockThreadIdx(sch, output, max_threadblocks, max_threads_per_block, get_factor); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.cuda") - .set_body_typed([](Schedule sch, BlockRV block) -> Array { - ScheduleDataPack(sch, block); - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); - BindBlockThreadIdx(sch, block, max_threadblocks, max_threads_per_block, get_factor); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.nchw.cuda") - .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { - sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); - Array loops = sch->GetLoops(inverse); - ICHECK_EQ(loops.size(), 6); - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[2]))) { - if (*i <= 16) { - sch->Unroll(loops[2]); - } - } - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[3]))) { - if (*i <= 16) { - sch->Unroll(loops[3]); - } - } - sch->Unroll(loops[4]); - sch->Unroll(loops[5]); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_kernel_pack.nchw.cuda") - .set_body_typed([](Schedule sch, BlockRV kernel_pack) -> Array { - Array loops = sch->GetLoops(kernel_pack); - ICHECK_EQ(loops.size(), 6); - if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) { - if (*i <= 16) { - sch->Unroll(loops[0]); - } - } - if (const int64_t* i = 
tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) { - if (*i <= 16) { - sch->Unroll(loops[1]); - } - } - sch->Unroll(loops[4]); - sch->Unroll(loops[5]); - - LoopRV fused = sch->Fuse({loops[2], loops[3]}); - - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); - BindBlockThreadIdx(sch, kernel_pack, max_threadblocks, max_threads_per_block, get_factor); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetOnlyProducer(sch, data_pack); - BlockRV data_pad = GetOnlyProducer(sch, input_tile); - LoopRV loop = ScheduleDataPack(sch, data_pack); - sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); - sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); - sch->ComputeInline(data_pad); - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); - BindBlockThreadIdx(sch, data_pack, max_threadblocks, max_threads_per_block, get_factor); - return {sch}; - }); - -TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.nchw.cuda") - .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { - BlockRV input_tile = GetOnlyProducer(sch, data_pack); - BlockRV data_pad = GetOnlyProducer(sch, input_tile); - - BlockRV data_l = sch->CacheWrite(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local"); - BlockRV d = sch->CacheRead(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local"); - LoopRV loop = ScheduleDataPackNCHW(sch, data_pack); - sch->ReverseComputeAt(data_l, loop, /*preserve_unit_loops=*/true); - sch->ComputeAt(d, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); - sch->ComputeInline(data_pad); - - int64_t max_threadblocks = 256; - int64_t max_threads_per_block = 1024; - auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 
512, 1024}); - BindBlockThreadIdx(sch, data_pack, max_threadblocks, max_threads_per_block, get_factor); - return {sch}; - }); - -} // namespace meta_schedule -} // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 8eb2760dc791..491af6e28f77 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -118,20 +118,11 @@ class PostOrderApplyNode : public SpaceGeneratorNode { std::vector stack; Array result{sch}; - // Enumerate the schedule rules first because you can - // always concat multiple schedule rules as one Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); - Array> rules{NullOpt}; - rules.insert(rules.end(), sch_rules.value().begin(), sch_rules.value().end()); - for (Optional sch_rule : rules) { - if (sch_rule.defined()) { - for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, all_blocks); - } - } else { - for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, Array{all_blocks.rbegin(), all_blocks.rend()}); - } + + for (ScheduleRule sch_rule : sch_rules.value()) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, all_blocks); } result.clear(); while (!stack.empty()) { @@ -150,33 +141,13 @@ class PostOrderApplyNode : public SpaceGeneratorNode { stack.emplace_back(sch, blocks); continue; } - - Optional ann = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule"); - const runtime::PackedFunc* custom_schedule_fn = - ann.defined() ? 
runtime::Registry::Get(ann.value()) : nullptr; - const bool has_schedule_rule = custom_schedule_fn != nullptr; - - if (ann.defined() && ann.value() != "None" && !has_schedule_rule) { - LOG(WARNING) << "Custom schedule rule not found, ignoring schedule_rule annotation: " - << ann.value(); - } - - if ((has_schedule_rule && sch_rule.defined()) || - (!has_schedule_rule && !sch_rule.defined()) || - (ann.defined() && ann.value() == "None")) { - stack.emplace_back(sch, blocks); - continue; - } - - Array applied{nullptr}; - if (sch_rule.defined()) { - applied = sch_rule.value()->Apply(sch, /*block=*/block_rv); - } else { - ICHECK(custom_schedule_fn) - << "ValueError: Custom schedule rule not found: " << ann.value(); - applied = (*custom_schedule_fn)(sch, block_rv); + if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { + if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").defined()) { + stack.emplace_back(sch, blocks); + continue; + } } - + Array applied = sch_rule->Apply(sch, /*block=*/block_rv); for (const tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 37e0d1db5e98..80264516c4ce 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -520,27 +520,24 @@ inline bool IsGPUTarget(const std::string& target_name) { * \return The AutoInline schedule rule for the given target. 
*/ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { - if (target_name == "llvm" || target_name == "hexagon") { - return ScheduleRule::AutoInline( - /*into_producer=*/false, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/true, - /*require_injective=*/true, - /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}); + Array rules{nullptr}; + if (target_name == "llvm") { + rules = ScheduleRule::DefaultLLVM(); + } else if (target_name == "hexagon") { + rules = ScheduleRule::DefaultHexagon(); } else if (IsGPUTarget(target_name)) { - return ScheduleRule::AutoInline( - /*into_producer=*/true, - /*into_consumer=*/true, - /*inline_const_tensor=*/true, - /*disallow_if_then_else=*/false, - /*require_injective=*/false, - /*require_ordered=*/false, - /*disallow_op=*/Array{}); + rules = ScheduleRule::DefaultCUDA(); + } else { + LOG(FATAL) << "ValueError: Unsupported target: " << target_name; + } + for (const ScheduleRule& rule : rules) { + if (rule->GetTypeKey() == "meta_schedule.AutoInline") { + return rule; + } } - LOG(FATAL) << "Unsupported target " << target_name; - return ScheduleRule(nullptr); + LOG(FATAL) << "ValueError: AutoInline rule is not found in the default rules for target: " + << target_name; + throw; } } // namespace meta_schedule diff --git a/src/target/tag.cc b/src/target/tag.cc index 0747769b1e04..c9f24145814b 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -97,6 +97,7 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ {"kind", String("cuda")}, \ + {"keys", Array{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ {"max_shared_memory_per_block", Integer(SharedMem)}, \ {"max_threads_per_block", Integer(1024)}, \ @@ -358,9 +359,11 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); #undef TVM_REGISTER_CUDA_TAG -#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, 
Arch) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("llvm")}, {"mcpu", String(Arch)}, {"num-cores", Integer(Cores)}}); +#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch) \ + TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ + {"keys", Array{"x86", "cpu"}}, \ + {"mcpu", String(Arch)}, \ + {"num-cores", Integer(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index c222de81f2ad..80da5a727926 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include "../../tir/ir/functor_common.h" #include "../../tir/transforms/ir_utils.h" @@ -107,17 +108,21 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { Stmt VisitStmt_(const BlockNode* _block) final { Block block = Downcast(StmtMutator::VisitStmt_(_block)); - if (Optional ann = block->annotations.Get(topi_attr)) { - Array new_buffers; + BlockNode* n = block.CopyOnWrite(); + if (Optional ann = n->annotations.Get(topi_attr)) { for (Buffer buffer : Downcast>(ann)) { auto it = buffer2index_.find(buffer); if (it != buffer2index_.end()) { layout_free_buffer_indices_.insert(it->second); - } else { - new_buffers.push_back(buffer); } } - block.CopyOnWrite()->annotations.Set(topi_attr, new_buffers); + n->annotations.erase(topi_attr); + } + for (const String& attr : this->blocklist) { + auto it = n->annotations.find(attr); + if (it != n->annotations.end()) { + n->annotations.erase(attr); + } } return std::move(block); } @@ -125,6 +130,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { std::unordered_map buffer2index_; std::set layout_free_buffer_indices_; String topi_attr = "layout_free_placeholders"; + std::vector blocklist = {"const_matrix", "auto_scheduler_simplify_const_tensor_indices", + 
"workload"}; }; BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 6970b0ac06b5..b703c79c5d3a 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -22,16 +22,19 @@ import numpy as np import pytest - import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm import relay from tvm._ffi import register_func +from tvm.contrib.hexagon.meta_schedule import ( + get_hexagon_local_builder, + get_hexagon_rpc_runner, +) from tvm.meta_schedule import postproc, schedule_rule -from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN -from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner -from tvm import meta_schedule as ms from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN + from ..infrastructure import get_hexagon_target MODEL_JSON = "resnet50_int8.json" @@ -44,6 +47,7 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher): """Tune VRMPY with auto tensorization.""" sch_rules = [ + schedule_rule.ApplyCustomRule(), schedule_rule.AutoInline( into_producer=False, into_consumer=True, @@ -269,7 +273,7 @@ def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV): _schedule_packed_8x8x32_conv2d()(sch, conv2d_block) return [sch] - register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) + register_func("meta_schedule.conv2d_NCHWc_int8.hexagon", schedule_rule_conv2d_packed_8x8x32) def schedule_conv2d_for_tune(sch: Schedule): _schedule_packed_8x8x32_conv2d()(sch) diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py 
b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py deleted file mode 100644 index ac18bab81006..000000000000 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py +++ /dev/null @@ -1,206 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-docstring - -import tvm -from tvm import meta_schedule as ms -from tvm.ir import IRModule -from tvm.meta_schedule.testing.conv2d_winograd_cpu import conv2d_winograd_cpu -from tvm.target import Target -from tvm.tir.schedule import Schedule, Trace - - -def _get_mod(): - # pylint: disable=invalid-name - def inline(sch: Schedule): - b1 = sch.get_block(name="A") - b2 = sch.get_block(name="B") - sch.compute_inline(block=b1) - sch.compute_inline(block=b2) - - def input_tile_data_pad(sch: Schedule): - b78 = sch.get_block(name="input_tile") - l80 = sch.sample_compute_location(block=b78, decision=4) - sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True) - - b81 = sch.get_block(name="data_pad") - l83 = sch.sample_compute_location(block=b81, decision=-2) - sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True) - - def data_pack(sch: Schedule): - b18 = sch.get_block(name="data_pack") - l19, l20, l21, l22, l23, l24 = sch.get_loops(block=b18) - sch.unroll(loop=l19) - sch.unroll(loop=l20) - v25, v26 = sch.sample_perfect_tile( - n=2, - loop=l21, - max_innermost_factor=64, - decision=[9, 1], - ) - l27, l28 = sch.split(loop=l21, factors=[v25, v26]) - v29, v30 = sch.sample_perfect_tile( - n=2, - loop=l22, - max_innermost_factor=64, - decision=[32, 4], - ) - l31, l32 = sch.split(loop=l22, factors=[v29, v30]) - sch.unroll(loop=l23) - sch.unroll(loop=l24) - sch.reorder(l27, l31, l28, l32, l19, l20, l23, l24) - - def bgemm(sch: Schedule): - bgemm = sch.get_block(name="bgemm") - write_cache = sch.cache_write( - block=bgemm, - write_buffer_index=0, - storage_scope="global", - ) - sch.annotate( - block_or_loop=bgemm, - ann_key="meta_schedule.tiling_structure", - ann_val="SSRSRS", - ) - # b33, b34 = b34, b33 - l35, l36, l37, l38, l39 = sch.get_loops(block=bgemm) - v40, v41, v42, v43 = sch.sample_perfect_tile( - n=4, - loop=l35, - max_innermost_factor=64, - decision=[1, 2, 3, 1], - ) - l44, l45, l46, l47 = sch.split(loop=l35, factors=[v40, v41, v42, 
v43]) - v48, v49, v50, v51 = sch.sample_perfect_tile( - n=4, - loop=l36, - max_innermost_factor=64, - decision=[1, 1, 1, 6], - ) - l52, l53, l54, l55 = sch.split(loop=l36, factors=[v48, v49, v50, v51]) - v56, v57, v58, v59 = sch.sample_perfect_tile( - n=4, - loop=l37, - max_innermost_factor=64, - decision=[1, 1, 1, 9], - ) - l60, l61, l62, l63 = sch.split(loop=l37, factors=[v56, v57, v58, v59]) - v64, v65, v66, v67 = sch.sample_perfect_tile( - n=4, - loop=l38, - max_innermost_factor=64, - decision=[2, 1, 16, 4], - ) - l68, l69, l70, l71 = sch.split(loop=l38, factors=[v64, v65, v66, v67]) - v72, v73 = sch.sample_perfect_tile( - n=2, - loop=l39, - max_innermost_factor=64, - decision=[16, 8], - ) - l74, l75 = sch.split(loop=l39, factors=[v72, v73]) - sch.reorder( - # fmt: off - l44, l52, l60, l68, - l45, l53, l61, l69, - l74, - l46, l54, l62, l70, - l75, - l47, l55, l63, l71, - # fmt: on - ) - sch.reverse_compute_at(block=write_cache, loop=l69, preserve_unit_loops=True) - - def inverse(sch: Schedule): - b3 = sch.get_block(name="inverse") - l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b3) - sch.unroll(loop=l4) - sch.unroll(loop=l5) - v10, v11 = sch.sample_perfect_tile( - n=2, - loop=l6, - max_innermost_factor=64, - decision=[1, 9], - ) - l12, l13 = sch.split(loop=l6, factors=[v10, v11]) - v14, v15 = sch.sample_perfect_tile( - n=2, - loop=l7, - max_innermost_factor=64, - decision=[2, 64], - ) - l16, l17 = sch.split(loop=l7, factors=[v14, v15]) - sch.unroll(loop=l8) - sch.unroll(loop=l9) - sch.reorder(l12, l16, l13, l17, l4, l5, l8, l9) - - # pylint: enable=invalid-name - - sch = Schedule(mod=conv2d_winograd_cpu) - inline(sch) - data_pack(sch) - input_tile_data_pad(sch) - bgemm(sch) - inverse(sch) - return sch.mod - - -def test_conv2d_winograd_cpu(): - mod = conv2d_winograd_cpu - mod = IRModule({"main": mod}) - target = Target("llvm --num-cores=16") - context = ms.TuneContext( - mod=mod, - target=target, - task_name="Custom Search Space Task", - 
space_generator=ms.space_generator.PostOrderApply(), - ) - post_order_apply = context.space_generator - (sch,) = post_order_apply.generate_design_space(mod) - decisions = dict( - zip( - [i for i in sch.trace.insts[:-4] if i.kind.name.startswith("Sample")], - [ - # data_pack - [9, 1], - [32, 4], - # input_tile - 4, - # data_pad - -2, - # inverse - [1, 9], - [2, 64], - # bgemm - [1, 2, 3, 1], - [1, 1, 1, 6], - [1, 1, 1, 9], - [2, 1, 16, 4], - [16, 8], - ], - ) - ) - trace = Trace(sch.trace.insts[:-4], decisions=decisions) - sch = Schedule(mod=mod) - trace.apply_to_schedule(sch, remove_postproc=False) - answer = sch.mod - expected = _get_mod() - tvm.ir.assert_structural_equal(answer, expected) - - -if __name__ == "__main__": - test_conv2d_winograd_cpu() diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py deleted file mode 100644 index 89a04a9464ce..000000000000 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py +++ /dev/null @@ -1,328 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-docstring - -import tvm -from tvm import meta_schedule as ms -from tvm.ir import IRModule -from tvm.meta_schedule.testing.conv2d_winograd_cuda import conv2d_winograd_cuda -from tvm.target import Target -from tvm.tir.schedule import Schedule, Trace - - -def _get_mod(): - # pylint: disable=invalid-name - def inline(sch: Schedule): - b125 = sch.get_block(name="A") - sch.compute_inline(block=b125) - b126 = sch.get_block(name="B") - sch.compute_inline(block=b126) - - def input_tile_data_pad(sch: Schedule): - b115 = sch.get_block(name="input_tile") - (b116,) = sch.get_consumers(block=b115) - _, _, _, l120, _, _, _, _ = sch.get_loops(block=b116) - sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True) - sch.set_scope(block=b115, buffer_index=0, storage_scope="local") - - b127 = sch.get_block(name="data_pad") - sch.compute_inline(block=b127) - - b3 = sch.get_block(name="data_pack") - l25, l26, l27, l28, _, _, _, _ = sch.get_loops(block=b3) - l33 = sch.fuse(l25, l26, l27, l28) - v34 = sch.sample_categorical( - candidates=[32, 64, 128, 256, 512, 1024], - probs=[ - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - ], - decision=2, - ) - l35, l36 = sch.split(loop=l33, factors=[None, v34]) - sch.bind(loop=l35, thread_axis="blockIdx.x") - sch.bind(loop=l36, thread_axis="threadIdx.x") - - def data_pack(sch: Schedule): - b16 = sch.get_block(name="data_pack") - l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16) - sch.unroll(loop=l17) - sch.unroll(loop=l18) - v23, v24 = sch.sample_perfect_tile( - n=2, - loop=l19, - max_innermost_factor=64, - decision=[3, 3], - ) - l25, l26 = sch.split(loop=l19, factors=[v23, v24]) - v27, v28 = sch.sample_perfect_tile( - n=2, - loop=l20, - max_innermost_factor=64, - decision=[64, 2], - ) - l29, l30 = sch.split(loop=l20, factors=[v27, v28]) - sch.unroll(loop=l21) - sch.unroll(loop=l22) - sch.reorder(l25, l29, l26, l30, l17, 
l18, l21, l22) - - def bgemm(sch: Schedule): - b31 = sch.get_block(name="bgemm") - sch.annotate( - block_or_loop=b31, - ann_key="meta_schedule.tiling_structure", - ann_val="SSSRRSRS", - ) - sch.annotate( - block_or_loop=b31, - ann_key="meta_schedule.thread_extent_low_inclusive", - ann_val=32, - ) - sch.annotate( - block_or_loop=b31, - ann_key="meta_schedule.thread_extent_high_inclusive", - ann_val=1024, - ) - b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local") - b31, b32 = b32, b31 - l33, l34, l35, l36, l37 = sch.get_loops(block=b32) - v38, v39, v40, v41, v42 = sch.sample_perfect_tile( - n=5, - loop=l33, - max_innermost_factor=64, - decision=[1, 1, 1, 1, 6], - ) - l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42]) - v48, v49, v50, v51, v52 = sch.sample_perfect_tile( - n=5, - loop=l34, - max_innermost_factor=64, - decision=[1, 1, 1, 3, 2], - ) - l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52]) - v58, v59, v60, v61, v62 = sch.sample_perfect_tile( - n=5, - loop=l35, - max_innermost_factor=64, - decision=[3, 1, 1, 1, 3], - ) - l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62]) - v68, v69, v70, v71, v72 = sch.sample_perfect_tile( - n=5, - loop=l36, - max_innermost_factor=64, - decision=[4, 2, 1, 4, 4], - ) - l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72]) - v78, v79, v80 = sch.sample_perfect_tile( - n=3, - loop=l37, - max_innermost_factor=64, - decision=[32, 1, 4], - ) - l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80]) - sch.reorder( - # fmt: off - l43, l53, l63, l73, - l44, l54, l64, l74, - l45, l55, l65, l75, - l81, - l82, - l46, l56, l66, l76, - l83, - l47, l57, l67, l77, - # fmt: on - ) - l84 = sch.fuse(l43, l53, l63, l73) - sch.bind(loop=l84, thread_axis="blockIdx.x") - l85 = sch.fuse(l44, l54, l64, l74) - sch.bind(loop=l85, thread_axis="vthread.x") - l86 = sch.fuse(l45, l55, l65, l75) - 
sch.bind(loop=l86, thread_axis="threadIdx.x") - - b87 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared") - sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True) - _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87) - sch.fuse(l92, l93, l94, l95) - v97 = sch.sample_categorical( - candidates=[1, 2, 3, 4], - probs=[0.25, 0.25, 0.25, 0.25], - decision=1, - ) - sch.annotate( - block_or_loop=b87, - ann_key="meta_schedule.cooperative_fetch", - ann_val=v97, - ) - - b101 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared") - sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True) - _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101) - sch.fuse(l106, l107, l108, l109) - v110 = sch.sample_categorical( - candidates=[1, 2, 3, 4], - probs=[0.25, 0.25, 0.25, 0.25], - decision=1, - ) - sch.annotate( - block_or_loop=b101, - ann_key="meta_schedule.cooperative_fetch", - ann_val=v110, - ) - - sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True) - - def inverse(sch: Schedule): - b1 = sch.get_block(name="inverse") - l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1) - sch.unroll(loop=l2) - sch.unroll(loop=l3) - v8, v9 = sch.sample_perfect_tile( - n=2, - loop=l4, - max_innermost_factor=64, - decision=[3, 3], - ) - l10, l11 = sch.split(loop=l4, factors=[v8, v9]) - v12, v13 = sch.sample_perfect_tile( - n=2, - loop=l5, - max_innermost_factor=64, - decision=[2, 64], - ) - l14, l15 = sch.split(loop=l5, factors=[v12, v13]) - sch.unroll(loop=l6) - sch.unroll(loop=l7) - sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7) - l59 = sch.fuse(l10, l14, l11, l15) - v60 = sch.sample_categorical( - candidates=[32, 64, 128, 256, 512, 1024], - probs=[ - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - ], - decision=2, - ) - l61, l62 = sch.split(loop=l59, factors=[None, v60]) - sch.bind(loop=l61, thread_axis="blockIdx.x") - 
sch.bind(loop=l62, thread_axis="threadIdx.x") - - def conv2d(sch: Schedule): - b7 = sch.get_block(name="conv2d_winograd") - l141, l142, l143, l144 = sch.get_loops(block=b7) - l145 = sch.fuse(l141, l142, l143, l144) - v146 = sch.sample_categorical( - candidates=[32, 64, 128, 256, 512, 1024], - probs=[ - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - 0.16666666666666666, - ], - decision=2, - ) - l147, l148 = sch.split(loop=l145, factors=[None, v146]) - sch.bind(loop=l147, thread_axis="blockIdx.x") - sch.bind(loop=l148, thread_axis="threadIdx.x") - - def root_anno(sch: Schedule): - b8 = sch.get_block(name="root", func_name="main") - v140 = sch.sample_categorical( - candidates=[0, 16, 64, 512, 1024], - probs=[ - 0.20000000000000001, - 0.20000000000000001, - 0.20000000000000001, - 0.20000000000000001, - 0.20000000000000001, - ], - decision=2, - ) - sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v140) - - # pylint: enable=invalid-name - - sch = Schedule(mod=conv2d_winograd_cuda) - inline(sch) - data_pack(sch) - input_tile_data_pad(sch) - bgemm(sch) - inverse(sch) - conv2d(sch) - root_anno(sch) - - return sch.mod - - -def test_conv2d_winograd_cuda(): - mod = conv2d_winograd_cuda - mod = IRModule({"main": mod}) - context = ms.TuneContext( - mod=mod, - target=Target("nvidia/geforce-rtx-3090", host="llvm"), - task_name="Custom Search Space Task", - space_generator=ms.space_generator.PostOrderApply(), - ) - post_order_apply = context.space_generator - (sch,) = post_order_apply.generate_design_space(mod) - decisions = dict( - zip( - [i for i in sch.trace.insts if i.kind.name.startswith("Sample")], - [ - # data_pack - [3, 3], - [64, 2], - 2, - # inverse - [3, 3], - [2, 64], - 2, - # bgemm - [1, 1, 1, 1, 6], - [1, 1, 1, 3, 2], - [3, 1, 1, 1, 3], - [4, 2, 1, 4, 4], - [32, 1, 4], - 1, - 1, - # root anno - 2, - # conv2d - 2, - ], - ) - ) - trace = Trace(sch.trace.insts, 
decisions=decisions) - sch = Schedule(mod=mod) - trace.apply_to_schedule(sch, remove_postproc=False) - answer = sch.mod - expected = _get_mod() - tvm.ir.assert_structural_equal(answer, expected) - - -if __name__ == "__main__": - test_conv2d_winograd_cuda() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 9026feb9e08e..c1d2dc3d0788 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -122,23 +122,6 @@ def main(a: T.handle, d: T.handle) -> None: D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) -@tvm.script.ir_module -class MatmulCustomized: - @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main"}) - A = T.match_buffer(a, (1024, 1024), "float32") - B = T.match_buffer(b, (1024, 1024), "float32") - C = T.match_buffer(c, (1024, 1024), "float32") - with T.block("root"): - for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): - T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"}) - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -382,32 +365,6 @@ def correct_trace(a, b, c, d): ) -def test_meta_schedule_custom_search_space(): - mod = MatmulCustomized - context = TuneContext( - mod=mod, - target=Target("llvm"), - task_name="Custom Search Space Task", - space_generator=PostOrderApply( - sch_rules=[], - postprocs=[], - mutator_probs={}, - ), - ) - post_order_apply = context.space_generator - post_order_apply.generate_design_space(mod) - called = False - - def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: - nonlocal called - called = True - return [sch] - - 
register_func("tvm.meta_schedule.test.custom_search_space", custom_search_space_func) - post_order_apply.generate_design_space(mod) - assert called - - def test_target_blocks_search_space(): # Test that specific blocks of trinity matmul can be targeted. def filter_fn(block, target_names) -> bool: diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index c689a15c56b2..bf302cd0e5bf 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Integration test for MetaSchedule""" -from typing import List import tempfile +from typing import List + import numpy as np import pytest import tvm @@ -27,7 +28,7 @@ from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir.transform import PassContext -from tvm.meta_schedule.database import Workload, TuningRecord +from tvm.meta_schedule.database import TuningRecord, Workload from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base from tvm.meta_schedule.tune_context import _normalize_mod @@ -333,7 +334,6 @@ def _test(mod, params, target): assert "schedule_rule" in annotations assert "vnni" in annotations["schedule_rule"] - ... 
mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) _test(mod, params, target="llvm -mcpu=cascadelake") @@ -445,7 +445,6 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore diff --git a/tests/python/unittest/test_meta_schedule_space_cpu_winograd.py b/tests/python/unittest/test_meta_schedule_space_cpu_winograd.py new file mode 100644 index 000000000000..78b75d592ed4 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_space_cpu_winograd.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for MetaSchedule search space on CPU""" +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + print_sketches, +) +from tvm.meta_schedule.testing.te_workload import create_te_workload +from tvm.script import tir as T +from tvm.target import Target + + +def _target(): + return Target("aws/cpu/c5.9xlarge") + + +def _design_space(mod): + return generate_design_space( + kind="llvm", + mod=mod, + target=_target(), + types=ms.ScheduleRule, + ) + + +def test_cpu_nhwc(): + # fmt: off + @T.prim_func + def cpu_nhwc_0(X: T.Buffer[(1, 14, 14, 128), "float32"], W: T.Buffer[(6, 6, 128, 128), "float32"], conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64}) + data_pad = T.alloc_buffer([1, 16, 16, 128], dtype="float32") + input_tile = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + data_pack = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + inverse = T.alloc_buffer([4, 4, 9, 128], dtype="float32") + bgemm_global = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + for i2_0 in T.serial(9): + for ax0, ax1, ax2, ax3 in T.grid(1, 6, 6, 128): + with T.block("data_pad"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(16, i2_0 // 3 * 4 + ax1) + i2 = T.axis.spatial(16, i2_0 % 3 * 4 + ax2) + i3 = T.axis.spatial(128, ax3) + T.reads(X[i0, i1, i2, i3]) + T.writes(data_pad[i0, i1, i2, i3]) + T.block_attr({"schedule_rule":"None"}) + data_pad[i0, i1, i2, i3] = T.if_then_else(0 <= i1 and i1 < 14 and 0 <= i2 and i2 < 14, X[i0, i1, i2, i3], 
T.float32(0), dtype="float32") + for i3_0 in T.serial(2): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 64): + with T.block("input_tile"): + eps, nu = T.axis.remap("SS", [ax0, ax1]) + p = T.axis.spatial(9, i2_0 + ax2) + ci = T.axis.spatial(128, i3_0 * 64 + ax3) + T.reads(data_pad[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci]) + T.writes(input_tile[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile[eps, nu, p, ci] = data_pad[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci] + for i2_1, i3_1 in T.grid(1, 64): + for i0 in T.unroll(6): + for i1 in T.unroll(6): + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("data_pack"): + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(9, i2_0 + i2_1) + ci = T.axis.spatial(128, i3_0 * 64 + i3_1) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(input_tile[r_a, r_b, p, ci]) + T.writes(data_pack[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_data_pack"}) + with T.init(): + data_pack[eps, nu, p, ci] = T.float32(0) + data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), 
T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, 
T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(3, 2, 3, 1, 1, 1, 1, 1): + for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(32, 1, 1, 1, 2, 4, 2, 3, 3, 64): + with T.block("bgemm"): + eps = T.axis.spatial(6, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) + nu = T.axis.spatial(6, i1_0 * 3 + i1_1 * 3 + i1_2 * 3 + i1_3) + p = T.axis.spatial(9, i2_0 * 3 + i2_1 * 3 + i2_2 * 3 + i2_3) + co = T.axis.spatial(128, i3_0 * 128 + i3_1 * 128 + i3_2 * 64 + i3_3) + ci = T.axis.reduce(128, i4_0 * 4 + i4_1) + T.reads(data_pack[eps, nu, p, ci], W[eps, nu, co, ci]) + T.writes(bgemm_global[eps, nu, p, co]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS", "meta_schedule.write_cache_level":[2]}) + with T.init(): + bgemm_global[eps, nu, p, co] = T.float32(0) + bgemm_global[eps, nu, p, co] = bgemm_global[eps, nu, p, co] + 
data_pack[eps, nu, p, ci] * W[eps, nu, co, ci] + for ax0, ax1, ax2, ax3 in T.grid(2, 3, 3, 128): + with T.block("bgemm_global"): + v0 = T.axis.spatial(6, i0_0 * 2 + ax0) + v1 = T.axis.spatial(6, i1_0 * 3 + ax1) + v2 = T.axis.spatial(9, i2_0 * 3 + ax2) + v3 = T.axis.spatial(128, ax3) + T.reads(bgemm_global[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + bgemm[v0, v1, v2, v3] = bgemm_global[v0, v1, v2, v3] + for i2_0, i3_0, i2_1, i3_1 in T.grid(3, 8, 3, 16): + for i0 in T.unroll(4): + for i1 in T.unroll(4): + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("inverse"): + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(9, i2_0 * 3 + i2_1) + co = T.axis.spatial(128, i3_0 * 16 + i3_1) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(bgemm[r_a, r_b, p, co]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_inverse"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), 
T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0, i1, i2, i3 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[h % 4, w % 4, n * 9 + h // 4 * 3 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 9 + h // 4 * 3 
+ w // 4, co] + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [3, 3]), + ("SamplePerfectTile", [8, 16]), + ("SamplePerfectTile", [9, 1]), + ("SamplePerfectTile", [2, 64]), + ("SampleComputeLocation", 1), + ("SampleComputeLocation", 0), + ("SamplePerfectTile", [3, 1, 1, 2]), + ("SamplePerfectTile", [2, 1, 1, 3]), + ("SamplePerfectTile", [3, 1, 1, 3]), + ("SamplePerfectTile", [1, 1, 2, 64]), + ("SamplePerfectTile", [32, 4]), + ("SampleCategorical", 2), + ] + with _target(): + mod = create_te_workload("C2D_WIN_NHWC", 0) + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[cpu_nhwc_0], + expected_decisions=[decision_0], + ) + + +if __name__ == "__main__": + test_cpu_nhwc() diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index f0f6e91ea655..324d8a9ec4f8 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Tests for MetaSchedule search space on CUDA""" -from tvm import autotvm from tvm import meta_schedule as ms -from tvm import te, topi from tvm.meta_schedule.testing.space_generation import ( check_sketches, generate_design_space, @@ -41,27 +39,6 @@ def _design_space(mod): ) -def _conv2d_winograd_nchw(): - data = te.placeholder((1, 64, 224, 224), name="data", dtype="float32") - kernel = te.placeholder((6, 6, 64, 64), name="kernel", dtype="float32") - return te.create_prim_func( - [ - data, - kernel, - topi.cuda.conv2d_winograd.winograd_cuda( - cfg=autotvm.ConfigSpace(), - data=data, - kernel=kernel, - strides=(1, 1), - padding=(1, 1), - dilation=(1, 1), - out_dtype="float32", - pre_computed=True, - ), - ] - ) - - def test_cuda_c1d(): # fmt: off @T.prim_func @@ -1272,151 +1249,6 @@ def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, ) -def test_cuda_winograd_nchw_conv2d(): - # fmt: off - @T.prim_func - def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T.Buffer[(6, 6, 64, 64), "float32"], output: T.Buffer[(1, 64, 224, 224), "float32"]) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - data_pack = T.alloc_buffer([6, 6, 64, 3136], dtype="float32") - bgemm = T.alloc_buffer([6, 6, 64, 3136], dtype="float32") - inverse_local = T.alloc_buffer([64, 3136, 4, 4], dtype="float32", scope="local") - data_pack_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local") - d_local = T.alloc_buffer([64, 3136, 6, 6], dtype="float32", scope="local") - bgemm_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local") - kernel_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") - data_pack_shared = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="shared") - for i2_i3_0_fused_i3_1_fused_0 in T.thread_binding(3136, 
thread="blockIdx.x"): - for i2_i3_0_fused_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6): - with T.block("d_local"): - v0 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136 + ax0) - v1 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7 + ax1) - v2, v3 = T.axis.remap("SS", [ax2, ax3]) - T.reads(data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1]) - T.writes(d_local[v0, v1, v2, v3]) - d_local[v0, v1, v2, v3] = T.if_then_else(1 <= v1 % 3136 // 56 * 4 + v2 and v1 % 3136 // 56 * 4 + v2 < 225 and 1 <= v1 % 56 * 4 + v3 and v1 % 56 * 4 + v3 < 225, data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1], T.float32(0), dtype="float32") - for i0 in T.unroll(6): - for i1 in T.unroll(6): - for i4 in T.unroll(6): - for i5 in T.unroll(6): - with T.block("data_pack"): - eps, nu = T.axis.remap("SS", [i0, i1]) - ci = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136) - p = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7) - r_a, r_a_1 = T.axis.remap("RR", [i4, i5]) - T.reads(d_local[ci, p, r_a, r_a_1]) - T.writes(data_pack_local[eps, nu, ci, p]) - T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"}) - with T.init(): - data_pack_local[eps, nu, ci, p] = T.float32(0) - data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + d_local[ci, p, r_a, r_a_1] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 
and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, 
T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) - for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): - with T.block("data_pack_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + 
i2_i3_0_fused_i3_1_fused_1) // 3136 + ax2) - v3 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7 + ax3) - T.reads(data_pack_local[v0, v1, v2, v3]) - T.writes(data_pack[v0, v1, v2, v3]) - data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(96, thread="blockIdx.x"): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(12, thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(224, thread="threadIdx.x"): - for i4_0 in T.serial(32): - for ax0_ax1_ax2_ax3_fused in T.serial(192): - with T.block("kernel_shared"): - v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 32) - v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 16) - v2 = T.axis.spatial(64, i4_0 * 2 + ax0_ax1_ax2_ax3_fused % 32 // 16) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 16 // 4 * 16 + ax0_ax1_ax2_ax3_fused % 16) - T.reads(kernel[v0, v1, v2, v3]) - T.writes(kernel_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) - kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused in T.serial(9408): - with T.block("data_pack_shared"): - v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 1568) - v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 16) - v2 = T.axis.spatial(64, i4_0 * 2 + ax0_ax1_ax2_ax3_fused % 1568 // 784) - v3 = T.axis.spatial(3136, i0_0_i1_0_i2_0_i3_0_fused % 4 * 784 + ax0_ax1_ax2_ax3_fused % 784) - T.reads(data_pack[v0, v1, v2, v3]) - T.writes(data_pack_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] - for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(2, 1, 1, 1, 14, 1, 1, 1, 2, 1): - with T.block("bgemm"): - eps = T.axis.spatial(6, i0_4 + i0_1_i1_1_i2_1_i3_1_fused // 4 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 112 + i0_3) - nu = 
T.axis.spatial(6, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 16 + i1_3) - co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 16 // 4 * 16 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 112 // 56 * 2 + i2_3 * 2 + i2_4) - p = T.axis.spatial(3136, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 4 * 784 + i0_2_i1_2_i2_2_i3_2_fused % 56 * 14 + i3_3) - ci = T.axis.reduce(64, i4_0 * 2 + i4_1 + i4_2) - T.reads(kernel_shared[eps, nu, ci, co], data_pack_shared[eps, nu, ci, p]) - T.writes(bgemm_local[eps, nu, co, p]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) - with T.init(): - bgemm_local[eps, nu, co, p] = T.float32(0) - bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + kernel_shared[eps, nu, ci, co] * data_pack_shared[eps, nu, ci, p] - for ax0, ax1, ax2, ax3 in T.grid(1, 1, 2, 14): - with T.block("bgemm_local"): - v0 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 4 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 112 + ax0) - v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 16 + ax1) - v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 16 // 4 * 16 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 112 // 56 * 2 + ax2) - v3 = T.axis.spatial(3136, i0_0_i1_0_i2_0_i3_0_fused % 4 * 784 + i0_2_i1_2_i2_2_i3_2_fused % 56 * 14 + ax3) - T.reads(bgemm_local[v0, v1, v2, v3]) - T.writes(bgemm[v0, v1, v2, v3]) - bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] - for i0_i1_i2_0_i3_0_fused_fused_0 in T.thread_binding(6272, thread="blockIdx.x"): - for i0_i1_i2_0_i3_0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 4, 4, 6, 6): - with T.block("inverse"): - co = T.axis.spatial(64, ax0 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136) - p = T.axis.spatial(3136, ax1 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 56 + 
(i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56) - vh, vw, r_a_2, r_a_3 = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) - T.reads(bgemm[r_a_2, r_a_3, co, p]) - T.writes(inverse_local[co, p, vh, vw]) - T.block_attr({"schedule_rule":"meta_schedule.winograd_inverse.nchw.cuda"}) - with T.init(): - inverse_local[co, p, vh, vw] = T.float32(0) - inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a_2, r_a_3, co, p] * T.Select(r_a_2 % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a_2 % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a_2 % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a_2 % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a_2 % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a_2 % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a_2 % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a_2 % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_a_3 % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 5 and vw % 4 == 1, T.float32(0), 
T.Select(r_a_3 % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_a_3 % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_a_3 % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_a_3 % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_a_3 % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_a_3 % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_a_3 % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_a_3 % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) - for i2_1, i3_1 in T.grid(4, 4): - with T.block("output"): - n = T.axis.spatial(1, 0) - co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136) - h = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 4 + i2_1) - w = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56 * 4 + i3_1) - T.reads(inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4]) - T.writes(output[n, co, h, w]) - T.block_attr({"schedule_rule":"meta_schedule.winograd_output.nchw.cuda", "winograd_tile_size":4}) - output[n, co, h, w] = inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4] - # fmt: on - decision_0 = [ - ("SamplePerfectTile", [448, 7]), - 
("SampleCategorical", 1), - ("SampleCategorical", 0), - ("SamplePerfectTile", [1, 3, 2, 1, 1]), - ("SamplePerfectTile", [6, 1, 1, 1, 1]), - ("SamplePerfectTile", [4, 4, 2, 1, 2]), - ("SamplePerfectTile", [4, 1, 56, 14, 1]), - ("SamplePerfectTile", [32, 2, 1]), - ("SampleCategorical", 0), - ("SampleCategorical", 2), - ("SampleCategorical", 4), - ] - mod = _conv2d_winograd_nchw() - actual = _design_space(mod) - check_sketches( - mod, - sketches=actual, - expected_mods=[winograd_nchw_conv2d], - expected_decisions=[decision_0], - ) - - if __name__ == "__main__": test_cuda_c1d() test_cuda_c2d() @@ -1431,4 +1263,3 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T test_cuda_sfm() test_cuda_cbr() test_cuda_tbg() - test_cuda_winograd_nchw_conv2d() diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py new file mode 100644 index 000000000000..16f9e64252ad --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -0,0 +1,355 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for MetaSchedule search space on CUDA""" +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + print_sketches, +) +from tvm.meta_schedule.testing.te_workload import create_te_workload +from tvm.script import tir as T +from tvm.target import Target + + +def _target(): + return Target("nvidia/geforce-rtx-3070") + + +def _design_space(mod): + return generate_design_space( + kind="cuda", + mod=mod, + target=_target(), + types=ms.ScheduleRule, + ) + + +def test_cuda_nhwc(): + # fmt: off + @T.prim_func + def cuda_nhwc_0(data: T.Buffer[(1, 14, 14, 128), "float32"], weight: T.Buffer[(6, 6, 128, 128), "float32"], conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit":16}) + input_tile_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") + data_pack = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 9, 128], dtype="float32") + inverse = T.alloc_buffer([4, 4, 9, 128], dtype="float32") + data_pack_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") + bgemm_local = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="local") + data_pack_shared = T.alloc_buffer([6, 6, 9, 128], dtype="float32", scope="shared") + weight_shared = T.alloc_buffer([6, 6, 128, 128], dtype="float32", scope="shared") + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(2, thread="blockIdx.x"): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): + with T.block("input_tile"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) + eps, nu = T.axis.remap("SS", [ax0, ax1]) + p = T.axis.spatial(9, 
(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8 + ax2) + ci = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) + T.reads(data[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci]) + T.writes(input_tile_local[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"None"}) + input_tile_local[eps, nu, p, ci] = T.if_then_else(0 <= p % 9 // 3 * 4 + eps and p % 9 // 3 * 4 + eps < 14 and 0 <= p % 3 * 4 + nu and p % 3 * 4 + nu < 14, data[p // 9, p % 9 // 3 * 4 + eps, p % 3 * 4 + nu, ci], T.float32(0), dtype="float32") + for i0 in T.unroll(6): + for i1 in T.unroll(6): + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("data_pack"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) + eps, nu = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8) + ci = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(input_tile_local[r_a, r_b, p, ci]) + T.writes(data_pack_local[eps, nu, p, ci]) + T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_data_pack"}) + with T.init(): + data_pack_local[eps, nu, p, ci] = T.float32(0) + data_pack_local[eps, nu, p, ci] = data_pack_local[eps, nu, p, ci] + input_tile_local[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, 
T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 
== 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): + with T.block("data_pack_local"): + T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1 < 1152) + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(9, 
(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8 + ax2) + v3 = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) + T.reads(data_pack_local[v0, v1, v2, v3]) + T.writes(data_pack[v0, v1, v2, v3]) + data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(96, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(27, thread="threadIdx.x"): + for i4_0 in T.serial(8): + for ax0_ax1_ax2_ax3_fused in T.serial(1728): + with T.block("data_pack_shared"): + v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 864) + v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 864 // 144) + v2 = T.axis.spatial(9, ax0_ax1_ax2_ax3_fused % 144 // 16) + v3 = T.axis.spatial(128, i4_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + T.reads(data_pack[v0, v1, v2, v3]) + T.writes(data_pack_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(768): + with T.block("weight_shared"): + v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0_ax1_ax2_ax3_fused // 384) + v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 384 // 64) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + ax0_ax1_ax2_ax3_fused % 64 // 16) + v3 = T.axis.spatial(128, i4_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":3}) + weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 2, 1, 1, 2, 16, 
1, 1, 1, 1): + with T.block("bgemm"): + eps = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + i0_3 + i0_4) + nu = T.axis.spatial(6, i1_3 + i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 9) + p = T.axis.spatial(9, i0_2_i1_2_i2_2_i3_2_fused % 9 + i2_3 + i2_4) + co = T.axis.spatial(128, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 2 + i3_3) + ci = T.axis.reduce(128, i4_0 * 16 + i4_1 * 16 + i4_2) + T.reads(data_pack_shared[eps, nu, p, ci], weight_shared[eps, nu, co, ci]) + T.writes(bgemm_local[eps, nu, p, co]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS", "meta_schedule.write_cache_level":[3]}) + with T.init(): + bgemm_local[eps, nu, p, co] = T.float32(0) + bgemm_local[eps, nu, p, co] = bgemm_local[eps, nu, p, co] + data_pack_shared[eps, nu, p, ci] * weight_shared[eps, nu, co, ci] + for ax0, ax1, ax2, ax3 in T.grid(2, 1, 1, 2): + with T.block("bgemm_local"): + v0 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 32 * 2 + ax0) + v1 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 9 + ax1) + v2 = T.axis.spatial(9, i0_2_i1_2_i2_2_i3_2_fused % 9 + ax2) + v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused % 32 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 2 + ax3) + T.reads(bgemm_local[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] + for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(18, thread="blockIdx.x"): + for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for i0 in T.unroll(4): + for i1 in T.unroll(4): + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("inverse"): + vh, vw = T.axis.remap("SS", [i0, i1]) + p = T.axis.spatial(9, (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) // 384 * 3 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + 
i2_0_i3_0_i2_1_i3_1_fused_1) % 24 // 8) + co = T.axis.spatial(128, (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) % 384 // 24 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 64 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(bgemm[r_a, r_b, p, co]) + T.writes(inverse[vh, vw, p, co]) + T.block_attr({"schedule_rule":"conv2d_nhwc_winograd_inverse"}) + with T.init(): + inverse[vh, vw, p, co] = T.float32(0) + inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, 
T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_i1_i2_i3_fused_0 in T.thread_binding(144, thread="blockIdx.x"): + for i0_i1_i2_i3_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("conv2d_winograd"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(12, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) // 1536) + w = T.axis.spatial(12, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 1536 // 128) + co = T.axis.spatial(128, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 128) + T.reads(inverse[h % 4, w % 4, n * 9 + h // 4 * 3 + w // 4, co]) + T.writes(conv2d_winograd[n, h, w, co]) + conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 9 + h // 4 * 3 + w // 4, co] + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [3, 3]), + ("SamplePerfectTile", [16, 8]), + ("SampleCategorical", 1), + ("SamplePerfectTile", [3, 3]), + 
("SamplePerfectTile", [16, 8]), + ("SampleCategorical", 5), + ("SamplePerfectTile", [3, 1, 1, 2, 1]), + ("SamplePerfectTile", [1, 2, 3, 1, 1]), + ("SamplePerfectTile", [1, 1, 9, 1, 1]), + ("SamplePerfectTile", [32, 2, 1, 2, 1]), + ("SamplePerfectTile", [8, 1, 16]), + ("SampleCategorical", 0), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ("SampleCategorical", 2), + ] + with _target(): + mod = create_te_workload("C2D_WIN_NHWC", 0) + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[cuda_nhwc_0], + expected_decisions=[decision_0], + ) + + +def test_cuda_nchw(): + # fmt: off + @T.prim_func + def cuda_nchw_0(data: T.Buffer[(1, 64, 56, 56), "float32"], weight: T.Buffer[(6, 6, 64, 64), "float32"], conv2d_winograd: T.Buffer[(1, 64, 56, 56), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit":16}) + input_tile_local = T.alloc_buffer([64, 196, 6, 6], dtype="float32", scope="local") + data_pack = T.alloc_buffer([6, 6, 64, 196], dtype="float32") + bgemm = T.alloc_buffer([6, 6, 64, 196], dtype="float32") + inverse_local = T.alloc_buffer([64, 196, 4, 4], dtype="float32", scope="local") + data_pack_local = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="local") + bgemm_local = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="local") + data_pack_shared = T.alloc_buffer([6, 6, 64, 196], dtype="float32", scope="shared") + weight_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") + for i2_i3_fused_0 in T.thread_binding(25, thread="blockIdx.x"): + for i2_i3_fused_1 in T.thread_binding(512, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6): + with T.block("input_tile"): + T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) + ci = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196 
+ ax0) + p = T.axis.spatial(196, (i2_i3_fused_0 * 120 + i2_i3_fused_1) % 196 + ax1) + eps, nu = T.axis.remap("SS", [ax2, ax3]) + T.reads(data[p // 196, ci, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1]) + T.writes(input_tile_local[ci, p, eps, nu]) + T.block_attr({"schedule_rule":"None"}) + input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 196 // 14 * 4 + eps and p % 196 // 14 * 4 + eps < 57 and 1 <= p % 14 * 4 + nu and p % 14 * 4 + nu < 57, data[p // 196, ci, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1], T.float32(0), dtype="float32") + for i0 in T.unroll(6): + for i1 in T.unroll(6): + for i4 in T.unroll(6): + for i5 in T.unroll(6): + with T.block("data_pack"): + T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) + eps, nu = T.axis.remap("SS", [i0, i1]) + ci = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196) + p = T.axis.spatial(196, (i2_i3_fused_0 * 512 + i2_i3_fused_1) % 196) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(input_tile_local[ci, p, r_a, r_b]) + T.writes(data_pack_local[eps, nu, ci, p]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) + with T.init(): + data_pack_local[eps, nu, ci, p] = T.float32(0) + data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + input_tile_local[ci, p, r_a, r_b] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, 
T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and 
nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): + with T.block("data_pack_local"): + T.where(i2_i3_fused_0 * 512 + i2_i3_fused_1 < 12544) + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(64, (i2_i3_fused_0 * 512 + i2_i3_fused_1) // 196 + ax2) + v3 = T.axis.spatial(196, (i2_i3_fused_0 * 120 + i2_i3_fused_1) % 196 + ax3) + T.reads(data_pack_local[v0, v1, v2, v3]) + T.writes(data_pack[v0, v1, v2, v3]) + data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(14, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(224, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in 
T.thread_binding(2, thread="threadIdx.x"): + for i4_0 in T.serial(2): + for ax0_ax1_ax2_ax3_fused in T.serial(32256): + with T.block("data_pack_shared"): + v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 5376) + v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 5376 // 896) + v2 = T.axis.spatial(64, i4_0 * 32 + ax0_ax1_ax2_ax3_fused % 896 // 28) + v3 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + ax0_ax1_ax2_ax3_fused % 28) + T.reads(data_pack[v0, v1, v2, v3]) + T.writes(data_pack_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(36864): + with T.block("weight_shared"): + v0 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused // 6144) + v1 = T.axis.spatial(6, ax0_ax1_ax2_ax3_fused % 6144 // 1024) + v2 = T.axis.spatial(64, i4_0 * 32 + ax0_ax1_ax2_ax3_fused % 1024 // 32) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":3}) + weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(16, 2, 3, 1, 4, 2, 3, 1, 1, 1): + with T.block("bgemm"): + eps = T.axis.spatial(6, i0_3 * 3 + i0_4) + nu = T.axis.spatial(6, i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 112 * 3 + i1_3) + co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 112 // 7 * 2 + i0_2_i1_2_i2_2_i3_2_fused + i2_3 + i2_4) + p = T.axis.spatial(196, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 7 * 4 + i3_3) + ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 2 + i4_2) + T.reads(data_pack_shared[eps, nu, ci, p], weight_shared[eps, nu, ci, co]) + T.writes(bgemm_local[eps, nu, co, p]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, 
"meta_schedule.tiling_structure":"SSSRRSRS"}) + with T.init(): + bgemm_local[eps, nu, co, p] = T.float32(0) + bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + data_pack_shared[eps, nu, ci, p] * weight_shared[eps, nu, ci, co] + for ax0, ax1, ax2, ax3 in T.grid(6, 3, 1, 4): + with T.block("bgemm_local"): + v0 = T.axis.spatial(6, ax0) + v1 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 112 * 3 + ax1) + v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused // 7 * 32 + i0_1_i1_1_i2_1_i3_1_fused % 112 // 7 * 2 + i0_2_i1_2_i2_2_i3_2_fused + ax2) + v3 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 7 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 7 * 4 + ax3) + T.reads(bgemm_local[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] + for i0_i1_i2_0_i3_0_fused_0 in T.thread_binding(196, thread="blockIdx.x"): + for i0_i1_i2_0_i3_0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1 in T.grid(1, 1): + for ax2 in T.unroll(4): + for ax3 in T.unroll(4): + for ax4 in T.unroll(6): + for ax5 in T.unroll(6): + with T.block("inverse"): + co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) // 196 + ax0) + p = T.axis.spatial(196, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 196 // 14 * 14 + (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 14 + ax1) + vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) + T.reads(bgemm[r_a, r_b, co, p]) + T.writes(inverse_local[co, p, vh, vw]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) + with T.init(): + inverse_local[co, p, vh, vw] = T.float32(0) + inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), 
T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw 
% 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) + for i2_1, i3_1 in T.grid(4, 4): + with T.block("conv2d_winograd"): + n = T.axis.spatial(1, 0) + co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) // 196) + h = T.axis.spatial(56, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 196 // 14 * 4 + i2_1) + w = T.axis.spatial(56, (i0_i1_i2_0_i3_0_fused_0 * 64 + i0_i1_i2_0_i3_0_fused_1) % 14 * 4 + i3_1) + T.reads(inverse_local[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4]) + T.writes(conv2d_winograd[n, co, h, w]) + conv2d_winograd[n, co, h, w] = inverse_local[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4] + # fmt: on + decision_0 = [ + ("SampleCategorical", 4), + ("SamplePerfectTile", [1, 1, 1, 2, 3]), + ("SamplePerfectTile", [1, 2, 1, 3, 1]), + ("SamplePerfectTile", [2, 16, 2, 1, 1]), + ("SamplePerfectTile", [7, 7, 1, 4, 1]), + ("SamplePerfectTile", [2, 16, 2]), + ("SampleCategorical", 3), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ("SampleCategorical", 1), + ] + with _target(): + mod = create_te_workload("C2D_WIN_NCHW", 0) + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[cuda_nchw_0], + expected_decisions=[decision_0], + debug_mask=0, + ) + + +if __name__ == "__main__": + test_cuda_nhwc() + test_cuda_nchw() diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py index 1f91dc593143..3bbe916472f5 100644 --- a/tests/python/unittest/test_meta_schedule_vnni_integration.py +++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py @@ -20,8 +20,8 @@ from typing 
import Optional import numpy as np # type: ignore -import pytest import tvm +import tvm.testing from tvm import meta_schedule as ms from tvm import relay from tvm._ffi import register_func @@ -176,29 +176,29 @@ def test_vnni_schedule_fn_tune(): C = te.compute( ... - attrs={"schedule_rule": "meta_schedule.dense_vnni"}, + attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, ) When the MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation, it looks up the packed func registry for a function that is associated with the given schedule - rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule - functions must be + rule key ("meta_schedule.x86.dense_vnni" in this example). The signature of such custom + schedule functions must be (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. The BlockRV argument corresponds to the TE compute annotated with "schedule_rule". - The relevant code is in meta_schedule/space_generator/post_order_apply.cc. + The relevant code is in `src/meta_schedule/schedule_rule/apply_custom_rule.cc`. 
""" def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): _schedule_dense(m=None, do_tune=True)(sch, dense_block) return [sch] - register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni) + register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni) m, n, k = 1024, 1024, 1024 - target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4") + target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4") dev = tvm.cpu(0) relay_mod, params, f_check = _relay_dense(m, n, k) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 4c216cdbc53a..b59880758e5d 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -385,14 +385,12 @@ def expected_layout_attr( for i0, i1, i2 in T.grid(128, 128, 128): with T.block("C"): x, y, k = T.axis.remap("SSR", [i0, i1, i2]) - T.block_attr({"layout_free_placeholders": []}) with T.init(): C[x, y] = T.float32(0) C[x, y] = C[x, y] + A[x, k] * B[y, k] for i0, i1 in T.grid(128, 128): with T.block("D"): x, y = T.axis.remap("SS", [i0, i1]) - T.block_attr({"layout_free_placeholders": [C]}) D[x, y] = C[x, y] + T.float32(1) diff --git a/tests/python/unittest/test_tir_analysis_stmt_finding.py b/tests/python/unittest/test_tir_analysis_stmt_finding.py index 791699e4e4ed..acb5faa0de12 100644 --- a/tests/python/unittest/test_tir_analysis_stmt_finding.py +++ b/tests/python/unittest/test_tir_analysis_stmt_finding.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. 
import pytest - import tvm -from tvm.tir.analysis import find_anchor_block from tvm import te, topi -from tvm.meta_schedule.testing.te_workload import matmul, conv2d_winograd_nhwc +from tvm.meta_schedule.testing.te_workload import conv2d_winograd_nhwc, matmul +from tvm.tir.analysis import find_anchor_block def test_matmul_add(): @@ -35,7 +34,7 @@ def test_matmul_add(): def test_winograd(): mod = tvm.IRModule() - mod["main"] = te.create_prim_func(conv2d_winograd_nhwc(1, 56, 56, 64, 64, 3)) + mod["main"] = te.create_prim_func(conv2d_winograd_nhwc(1, 14, 14, 128, 128, 6)) block = find_anchor_block(mod)