Skip to content

Commit

Permalink
[TIR] Well-Formed Verifier (apache#12166)
Browse files Browse the repository at this point in the history
* tir_well_formed_verifier

* fix typo

* lint

* fix testcase
  • Loading branch information
Hzfengsy authored and Mikael Sevenier committed Jul 26, 2022
1 parent 754d40a commit a919ece
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 15 deletions.
11 changes: 10 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file tvm/tir/analysis.h
* \brief Analysis utilitie and passes for TIR.
* \brief Analysis utilities and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
Expand Down Expand Up @@ -220,6 +220,15 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
*/
TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);

/*!
* \brief Verify if the given TIR is well-formed. The verification includes:
* - Check if expressions not contain vars that is defined outside the block.
* \param func The PrimFunc to be verified.
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
* \return Whether it is a well-formed TIR function.
*/
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,23 @@ def apply_prim_func_arg_and_result_memory_constraints(
return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
func, relay_func_type, arg_and_result_memory_scopes
)


def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
"""Verify if the given TIR is well-formed. The verification includes:
- Check if expressions not contain vars that is defined outside the block.
Parameters
----------
func: tvm.tir.PrimFunc
The function to be verified.
assert_mode: bool
The indicator if it raises an error when the function is not well-formed.
Returns
-------
result: bool
Whether it is a well-formed TIR function.
"""
return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member
137 changes: 137 additions & 0 deletions src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/analysis/verify_well_formed.cc
* \brief Check if schedulable tir is well-formed.
*/

#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include "../ir/functor_common.h"

namespace tvm {
namespace tir {

/*! \brief Verify all Expr inside the block does not contain:
* 1. loop vars outside the current block.
* 2. block vars of parent blocks.
*/
class BlockVarAccessVerifier : public StmtExprVisitor {
public:
static bool Verify(const PrimFunc& func, bool assert_mode) {
BlockVarAccessVerifier verifier(assert_mode);
verifier(func->body);
return !verifier.has_error_;
}

private:
explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {}

void VisitStmt(const Stmt& stmt) final {
if (!has_error_) {
StmtExprVisitor::VisitStmt(stmt);
}
}

void VisitExpr(const PrimExpr& expr) final {
if (!has_error_) {
StmtExprVisitor::VisitExpr(expr);
}
}

void VisitExpr_(const VarNode* op) final {
auto it = loop_vars_.find(op);
if (it != loop_vars_.end() && it->second < cur_block_level_) {
has_error_ = true;
if (assert_mode_) {
report_error(op);
}
}
}

void VisitStmt_(const ForNode* op) final {
ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end());
loop_vars_[op->loop_var.get()] = cur_block_level_;
StmtExprVisitor::VisitStmt_(op);
loop_vars_.erase(op->loop_var.get());
}

void VisitStmt_(const BlockNode* op) final {
// Do not check boundary if it's a opaque block.
cur_block_level_ += !op->iter_vars.empty();

// Step 0. Skip block iter var's domain

// Step 1. Visit read/write regions
auto fvisit_buffer_region = [this](const BufferRegion& s) {
for (const auto& range : s->region) {
this->VisitExpr(range->min);
this->VisitExpr(range->extent);
}
};
VisitArray(op->reads, fvisit_buffer_region);
VisitArray(op->writes, fvisit_buffer_region);

// Step 2. Visit match buffers
VisitArray(op->match_buffers,
[fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
fvisit_buffer_region(match_buffer_region->source);
});

// Step 3. Visit init and body
if (op->init.defined()) {
this->VisitStmt(op->init.value());
}
this->VisitStmt(op->body);

cur_block_level_ -= !op->iter_vars.empty();
}

private:
void report_error(const VarNode* var) {
// TODO(siyuan): use the error message from the parser.
LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint
<< " is used inside the current block.";
}

/*! \brief The map from outside loop vars to its corresponding block level. */
std::unordered_map<const VarNode*, size_t> loop_vars_;
/*! \brief Whether it's in assert mode. */
bool assert_mode_;
/*! \brief Current nested block stack level. */
size_t cur_block_level_{0};
/*! \brief Whether there is error. */
bool has_error_{false};
};

bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
return false;
}
// TODO(Siyuan): add more checks here.
return true;
}

TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);

} // namespace tir
} // namespace tvm
1 change: 1 addition & 0 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class StateCreator : private StmtVisitor {
for (const auto& kv : n->mod->functions) {
const BaseFunc& base_func = kv.second;
if (const auto* func = base_func.as<PrimFuncNode>()) {
VerifyWellFormed(GetRef<PrimFunc>(func));
creator.VisitStmt(func->body);
BlockInfoCollector::Collect(self, func->body);
}
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_tir_analysis_verify_well_formed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.script import tir as T


def test_pass_simple():
@T.prim_func
def element_wise(
A: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
):
B = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
# It's a opaque block , so it can use outside variables
C[i, j] = B[i, j] * 2.0

assert tvm.tir.analysis.verify_well_formed(element_wise)


def test_fail_use_out_loop_var():
@T.prim_func
def element_wise(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
):
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
# we cannot use `i` since it's defined outside the block
B[vi, vj] = A[i, vj] * 2.0

assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False)


if __name__ == "__main__":
tvm.testing.main()
9 changes: 4 additions & 5 deletions tests/python/unittest/test_tir_schedule_set_axis_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import tvm
import tvm.testing
Expand Down Expand Up @@ -76,12 +75,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
B_subregion0[()] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
C[vi, vj] = B_subregion1[()] + 1.0


Expand All @@ -92,12 +91,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion0[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
C[vi, vj] = B_subregion1[()] + T.float32(1)


Expand Down
15 changes: 6 additions & 9 deletions tests/python/unittest/test_tir_schedule_set_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
Expand Down Expand Up @@ -59,12 +60,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
B_subregion0[()] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
C[vi, vj] = B_subregion1[()] + 1.0


Expand All @@ -75,12 +76,12 @@ def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
B_subregion0_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
B_subregion0_shared[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
B_subregion1_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
C[vi, vj] = B_subregion1_shared[()] + T.float32(1)


Expand Down Expand Up @@ -128,8 +129,4 @@ def test_set_scope_subregion():


if __name__ == "__main__":
test_set_scope()
test_set_scope_fail_on_output_buffer()
test_set_scope_fail_on_index_out_of_bound()
test_set_scope_fail_on_invalid_scope()
test_set_scope_subregion()
tvm.testing.main()

0 comments on commit a919ece

Please sign in to comment.