Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Nov 1, 2022
1 parent 1f8b2e9 commit 3d002b9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 8 deletions.
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,10 @@ constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_
/*! \brief Mark the tiling structure chosen for a block by the Multi-Level-Tiling schedule rule.
 * NOTE(review): the rule writes this annotation back onto the block after deciding the structure
 * (see MultiLevelTilingNode::Apply) — presumably so downstream passes can read the decision;
 * confirm against the consumers of this attribute. */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*! \brief Override the tiling structure the Multi-Level-Tiling schedule rule would otherwise use
 * for a block. When this annotation is present on a block, the rule uses its value (e.g. "SSSRRS")
 * in place of the rule's configured default structure. */
constexpr const char* meta_schedule_override_tiling_structure =
"meta_schedule.override_tiling_structure";

/*!
* \brief Mark that the loop should be further skip and bound to environment threads to enable
* cooperative fetching.
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ def _conv2d_winograd_nhwc_impl(
else:
kernel_pack = weight
bgemm_attrs = {"layout_free_placeholders": [kernel_pack]}
bgemm_attrs["meta_schedule.override_tiling_structure"] = "SSSRRS" # type: ignore[assignment]
if write_cache_level is not None:
if not isinstance(write_cache_level, int):
bgemm_attrs["meta_schedule.write_cache_level"] = write_cache_level
Expand Down Expand Up @@ -1379,6 +1380,7 @@ def _conv2d_winograd_nchw_impl(
else:
kernel_pack = weight
bgemm_attrs = {"layout_free_placeholders": [kernel_pack]}
bgemm_attrs["meta_schedule.override_tiling_structure"] = "SSSRRS" # type: ignore[assignment]
if write_cache_level is not None:
if not isinstance(write_cache_level, int):
bgemm_attrs["meta_schedule.write_cache_level"] = write_cache_level
Expand Down
7 changes: 3 additions & 4 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&
return {sch};
}
String s = this->structure;
if (Optional<String> ann =
tir::GetAnn<String>(sch->GetSRef(block_rv), tir::attr::meta_schedule_tiling_structure)) {
if (Optional<String> ann = tir::GetAnn<String>(
sch->GetSRef(block_rv), tir::attr::meta_schedule_override_tiling_structure)) {
s = ann.value();
} else {
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, s);
}
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, s);
std::vector<int> s_indices;
std::vector<int> r_indices;
std::tie(s_indices, r_indices) = ParseTileStructure(s);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
return {sch};
}
String s = this->structure;
if (Optional<String> ann = tir::GetAnn<String>(
sch->GetSRef(block_rv), tir::attr::meta_schedule_override_tiling_structure)) {
s = ann.value();
}
std::vector<int> s_indices;
std::vector<int> r_indices;
std::tie(s_indices, r_indices) = ParseTileStructure(s);

std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info;
for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
Expand All @@ -212,13 +220,10 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
Schedule new_sch = sch->Copy();
new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, s);
initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
}
Array<Schedule> results;
std::vector<int> s_indices;
std::vector<int> r_indices;
std::tie(s_indices, r_indices) = ParseTileStructure(structure);
for (auto&& state : ApplySubRules(initial_states, s_indices, r_indices)) {
TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with "
<< state.as<TensorCoreStateNode>()->intrin_group.compute_intrin;
Expand Down

0 comments on commit 3d002b9

Please sign in to comment.