Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Construct the inverse in SuggestIndexMap #12797

Merged
merged 5 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> final_indices;

/*!
* \brief The inverse index map.
*
* When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
* Otherwise, the inverse index map will be computed on the fly.
* It is the user's responsibility to ensure the correctness of the pre-defined inverse index
* map.
*
* \note ObjectRef is used here instead of IndexMap to avoid circular reference.
*/
Optional<ObjectRef> inverse_index_map;

/*!
* \brief Default constructor
*
Expand Down Expand Up @@ -133,6 +145,7 @@ class IndexMapNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
v->Visit("inverse_index_map", &inverse_index_map);
}

bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
Expand All @@ -153,15 +166,24 @@ class IndexMapNode : public Object {

class IndexMap : public ObjectRef {
public:
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
/*!
* \brief The constructor
* \param initial_indices Variables representing the indices prior to remapping
* \param final_indices Expressions defining the indices after remapping.
* \param inverse_index_map The optional pre-defined inverse index map
*/
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map = NullOpt);

/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
* \param inverse_index_map The optional pre-defined inverse index map
* \return The created index map
*/
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
Optional<IndexMap> inverse_index_map = NullOpt);

/*! \brief Generate the inverse mapping.
*
Expand Down
46 changes: 40 additions & 6 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ class IndexMap(Object):
Variables representing the indices prior to remapping.
final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.
"""

initial_indices: List[Var]
Expand All @@ -281,11 +287,19 @@ class IndexMap(Object):
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)
def __init__(self, initial_indices, final_indices, inverse_index_map):
if isinstance(inverse_index_map, Callable):
inverse_index_map = IndexMap.from_func(inverse_index_map)
self.__init_handle_by_constructor__(
_ffi_api.IndexMap, initial_indices, final_indices, inverse_index_map
)

@staticmethod
def from_func(mapping_function: Callable, ndim: Optional[int] = None):
def from_func(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
):
"""Create an index map from a function

Parameters
Expand All @@ -305,14 +319,23 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
mapping_function does not use variadic arguments, ndim is
optional.

inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.

Returns
-------
index_map: IndexMap

Returns an IndexMap representing the `mapping_function`.

"""
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
index_map, axis_separators = IndexMap.from_func_with_separators(
mapping_function, ndim, inverse_index_map
)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
"may not return IndexMap.AXIS_SEPARATOR. "
Expand All @@ -321,7 +344,11 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
return index_map

@staticmethod
def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
def from_func_with_separators(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
):
"""Create an index map from a function

Parameters
Expand All @@ -341,6 +368,13 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
mapping_function does not use variadic arguments, ndim is
optional.

inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.

Returns
-------
ret: Tuple[IndexMap, List[int]]
Expand Down Expand Up @@ -401,7 +435,7 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
f"Instead received {val} of type {type(val)}."
)

return IndexMap(initial_indices, final_indices), axis_separators
return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators

def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
Expand Down
47 changes: 40 additions & 7 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,23 @@
namespace tvm {
namespace tir {

IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map) {
auto n = make_object<IndexMapNode>();
n->initial_indices = std::move(initial_indices);
n->final_indices = std::move(final_indices);
n->inverse_index_map = std::move(inverse_index_map);
data_ = std::move(n);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
Optional<IndexMap> inverse_index_map) {
Array<Var> initial_indices;
initial_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
}
return IndexMap(initial_indices, func(initial_indices));
return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
Expand Down Expand Up @@ -114,6 +117,10 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
if ((*this)->inverse_index_map.defined()) {
// return the pre-defined inverse index map if exists.
return Downcast<IndexMap>((*this)->inverse_index_map.value());
}
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
Expand Down Expand Up @@ -232,7 +239,14 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
return output;
}

String IndexMapNode::ToPythonString() const {
/*!
* \brief Auxilarry function to comvert an index map to lambda expression in Python.
* \param initial_indices The initial indices in the index map.
* \param final_indices The final indices in the index map.
* \return The lambda expression string.
*/
std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices) {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
Expand All @@ -259,10 +273,28 @@ String IndexMapNode::ToPythonString() const {
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
if (i != 0) {
oss << " ";
}
oss << Substitute(final_indices[i], var_remap);
oss << ", ";
oss << ",";
}
oss << ")";
return oss.str();
}

String IndexMapNode::ToPythonString() const {
std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices);
if (!inverse_index_map.defined()) {
return String(lambda_expr);
}
// Also convert the inverse index map.
IndexMap inverse = Downcast<IndexMap>(inverse_index_map.value());
std::string inverse_lambda_expr =
IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
std::ostringstream oss;
oss << "tvm.tir.IndexMap.from_func(" << lambda_expr
<< ", inverse_index_map=" << inverse_lambda_expr << ")";
return String(oss.str());
}

Expand All @@ -275,8 +307,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IndexMapNode);

TVM_REGISTER_GLOBAL("tir.IndexMap")
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
return IndexMap(initial_indices, final_indices);
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map) {
return IndexMap(initial_indices, final_indices, inverse_index_map);
});

TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
Expand Down
49 changes: 43 additions & 6 deletions src/tir/schedule/analysis/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,25 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
}
return a.lower_factor > b.lower_factor;
});
// Compute the inverse permutation by argsort
std::vector<int> inverse_order = order;
vinx13 marked this conversation as resolved.
Show resolved Hide resolved
std::sort(inverse_order.begin(), inverse_order.end(),
[&order](int _a, int _b) -> bool { return order[_a] < order[_b]; });
// Step 5. Create the indexing mapping
auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
split_exprs = std::move(split_exprs), //
order = std::move(order), //
shape = buffer->shape, //
&split_exprs, //
&order, //
& shape = buffer->shape, //
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
analyzer //
](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), shape.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
}
// Step 5.1: Fuse all indices into a flattened one
PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
int ndim = split_exprs.size();
// Step 5.1. Split the flattened index according to `split_exprs`
// Step 5.2. Split the flattened index according to `split_exprs`
std::vector<PrimExpr> split;
split.reserve(ndim);
for (int i = ndim - 1; i >= 0; --i) {
Expand All @@ -190,15 +195,47 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
index = floordiv(index, extent);
}
std::reverse(split.begin(), split.end());
// Step 5.2. Reorder the indexing pattern according to `order`
// Step 5.3. Reorder the indexing pattern according to `order`
Array<PrimExpr> results;
results.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
results.push_back(split[order[i]]);
}
return results;
};
return IndexMap::FromFunc(ndim, f_alter_layout);
// Step 6: Create the inverse index mapping.
auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape,
analyzer](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), split_exprs.size());
// Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3.
// After the inverse permutation, indices[i] corresponds to split_exprs[i]
Array<Var> inv_permuted_indices;
inv_permuted_indices.reserve(indices.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
const Var& index = indices[inverse_order[i]];
inv_permuted_indices.push_back(index);
analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent)));
}

// Step 6.2: Fuse all the indices. This is the inverse of Step 5.2.
PrimExpr flattened_index = make_const(indices[0]->dtype, 0);
int64_t stride = 1;
for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) {
flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index;
stride *= split_exprs[i].extent;
}
// Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1.
Array<PrimExpr> result;
result.reserve(shape.size());
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i]));
flattened_index = floordiv(flattened_index, shape[i]);
result.push_back(index);
}
return Array<PrimExpr>(result.rbegin(), result.rend());
};
IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse);
return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map);
}

TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,47 @@ def test_suggest_index_map_bijective():
assert index_map.is_equivalent_to(expected_index_map)


def test_suggest_index_map_winograd():
"""use case in winograd conv where the indices are complicated"""
fused_outer, i3_3_fused, i4_0, i4_1 = _make_vars("fused_outer", "i3_3_fused", "i4_0", "i4_1")
eps = floordiv(fused_outer, 336) * 2 + floordiv(floormod(fused_outer, 16), 8)
nu = floordiv(floormod(fused_outer, 336), 112) * 2 + floordiv(floormod(fused_outer, 8), 4)
co = floormod(fused_outer, 4) * 32 + i3_3_fused
ci = (i4_0 * 32) + i4_1
buffer = decl_buffer(shape=[6, 6, 128, 128])
index_map = suggest_index_map(
buffer=buffer,
indices=[eps, nu, co, ci],
loops=_make_loops(
loop_vars=[fused_outer, i3_3_fused, i4_0, i4_1],
extents=[1008, 32, 4, 32],
),
predicate=True,
)
expected_index_map = IndexMap.from_func(
lambda i0, i1, i2, i3: (
floordiv(i0, 2),
floordiv(i1, 2),
floormod(i0, 2),
floormod(((i1 * 4) + floordiv(i2, 32)), 8),
floormod(i2, 32),
floordiv(i3, 32),
floormod(i3, 32),
)
)
assert index_map.is_equivalent_to(expected_index_map)
inverse_index_map = index_map.inverse(buffer.shape)
expected_inverse_index_map = IndexMap.from_func(
lambda i0, i1, i2, i3, i4, i5, i6: (
((i0 * 2) + i2),
((i1 * 2) + floordiv(((i3 * 32) + i4), 128)),
floormod(((i3 * 32) + i4), 128),
((i5 * 32) + i6),
)
)
assert inverse_index_map.is_equivalent_to(expected_inverse_index_map)


@tvm.script.ir_module
class DenseVNNIModule:
@T.prim_func
Expand Down