Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Refactor ScheduleRule Attributes #13195

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,53 @@
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
#ifndef TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
#define TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_

#include "../utils.h"
#include <tvm/tir/schedule/schedule.h>

#include <algorithm>
#include <limits>
#include <utility>

namespace tvm {
namespace meta_schedule {

/*!
* \brief Bind the given block if it is not bound to blockIdx or threadIdx.
* \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical`
* to return a random thread extent.
* \param sch The schedule
* \param thread_extents The candidate thread extents.
* \return A sampler that returns a random thread extent.
*/
std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
Array<Integer> thread_extents);

/*!
* \brief Bind blockIdx.x and threadIdx.x to the given loop
* \param sch The schedule.
* \param block The block to be bound.
* \param loop The loop to be bound.
* \param max_threadblocks The maximum number of threadblocks allowed.
* \param max_threads The maximum number of threads allowed.
* \param max_threads_per_block The maximum number of threads allowed.
* \param get_factor A function that returns the tiling factor.
*/
void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block,
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t max_extent)> get_factor);
Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t)> get_factor = nullptr);

/*!
* \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical`
* to return a random thread extent.
* \param sch The schedule
* \param thread_extents The candidate thread extents.
* \return A sampler that returns a random thread extent.
* \brief Bind the given block if it is not bound to blockIdx or threadIdx.
* \param sch The schedule.
* \param block The block to be bound.
* \param max_threadblocks The maximum number of threadblocks allowed.
* \param max_threads_per_block The maximum number of threads allowed.
* \param get_factor A function that returns the tiling factor.
*/
std::function<tir::ExprRV(int64_t max_extent)> MakeFactorSampler(tir::Schedule sch,
Array<Integer> thread_extents);
void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t max_extent)> get_factor = nullptr);

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_BIND_H_
#endif // TVM_META_SCHEDULE_SCHEDULE_CUDA_THREAD_BIND_H_
37 changes: 37 additions & 0 deletions include/tvm/meta_schedule/schedule/generic/winograd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
#define TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

/*!
* \brief Get the producer block of a given block.
* If there is a constant winograd transform matrix, inline it.
* \return The only producer block.
*/
tir::BlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::BlockRV block);

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_GENERIC_WINOGRAD_H_
Empty file.
8 changes: 8 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The cloned schedule rule.
*/
using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
/*!
* \brief Create a rule that applies customized rules registered using block attribute
* `schedule_rule`. The rule will be dispatched according to target keys.
* \return The created schedule rule.
*/
TVM_DLL static ScheduleRule ApplyCustomRule();
/*! \brief Check if the rule is `ApplyCustomRule` */
TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
postproc,
relay_integration,
runner,
schedule,
schedule_rule,
search_strategy,
space_generator,
Expand All @@ -41,10 +42,7 @@
from .mutator import Mutator
from .postproc import Postproc
from .profiler import Profiler
from .relay_integration import (
is_meta_schedule_dispatch_enabled,
is_meta_schedule_enabled,
)
from .relay_integration import is_meta_schedule_enabled
from .runner import Runner
from .schedule_rule import ScheduleRule
from .search_strategy import MeasureCandidate, SearchStrategy
Expand Down
16 changes: 1 addition & 15 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def compile_relay(
mod, target, params, pass_config, executor = _normalize_params(
mod, target, params, pass_config, executor
)
pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", target.kind.name != "cuda")
pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True)
with Profiler.timeit("PostTuningCompilation"):
with target, _autotvm_silencer(), database:
with transform.PassContext(
Expand All @@ -404,17 +404,3 @@ def is_meta_schedule_enabled() -> bool:
"relay.backend.use_meta_schedule",
False,
)


def is_meta_schedule_dispatch_enabled() -> bool:
"""Return whether the meta-schedule dispatch is enabled.
Returns
-------
enabled: bool
Whether the meta schedule is enabled
"""
return transform.PassContext.current().config.get(
"relay.backend.use_meta_schedule_dispatch",
False,
)
18 changes: 18 additions & 0 deletions python/tvm/meta_schedule/schedule/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule"""
from . import cpu, cuda, generic, x86
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule/cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for target key 'cpu'"""
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for target key 'cuda'"""
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for generic cases"""
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule/x86/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for target key 'x86'"""
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
blocks in a schedule. See also PostOrderApply.
"""
from .add_rfactor import AddRFactor
from .apply_custom_rule import ApplyCustomRule
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import (
MultiLevelTiling,
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
MultiLevelTilingWideVector,
MultiLevelTilingWithIntrin,
ReuseType,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Create a rule that applies customized rules registered using block attribute `schedule_rule`.
The rule will be dispatched according to target keys."""
from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.ApplyCustomRule")
class ApplyCustomRule(ScheduleRule):
"""A rule that applies customized rules registered using block attribute `schedule_rule`.
The rule will be dispatched according to target keys."""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleApplyCustomRule, # type: ignore # pylint: disable=no-member
)
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class EvolutionarySearch(SearchStrategy):
def __init__(
self,
*,
population_size: int = 2048,
population_size: int = 512,
Copy link
Member

@masahi masahi Nov 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @zxybazh Is this change intended to go with this PR, or is it a left-over from development?

This change is so small and looks unrelated to this PR, but it has huge implications. Tuning time will become much shorter, which I like, but could there be a concern for perf regression due to this change?

I'm asking this because I have been doing perf and tuning time improvement for int8 TC, and after rebasing I'm directly affected by this change.

Copy link
Member Author

@junrushao junrushao Nov 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masahi Thanks for asking, and it's definitely a valid concern. The reason is that we found it's somehow set to an unreasonably large number (we were consistently using 512 on downstream), and now wanted to scale it back, because it negatively affects tuning time. Also note that the population calculation is slightly different than AutoScheduler, which includes invalid candidates in population, while ours always prune ahead of time.

We did make sure we have proper numbers in hand before merging, and please refer to the table below for details:

Mainline (ms) This PR (ms) Difference
resnet_50 1.829128008 1.73111331 5.66%
mobilenet_v2 0.4773168361 0.4806015715 -0.68%
resnet_18 0.678600832 0.6301909627 7.68%
mobilenet_v3 0.6649458484 0.668666894 -0.56%
wide_resnet_50 3.753781691 3.077549256 21.97%
densenet_121 2.361973117 2.288411393 3.21%
inception_v3 3.53431975 3.478368823 1.61%
resnet3d_18 7.791658449 7.465777971 4.36%

init_measured_ratio: float = 0.2,
init_min_unmeasured: int = 50,
max_fail_count: int = 5,
Expand Down
Loading