Skip to content

Commit

Permalink
[MLIR][transform][python] add sugared python abstractions for transfo…
Browse files Browse the repository at this point in the history
…rm dialect (llvm#75073)

This adds Python abstractions for the different handle types of the
transform dialect

The abstractions allow for straightforward chaining of transforms by
calling their member functions.
As an initial PR for this infrastructure, only a single transform is
included: `transform.structured.match`.
With a future `tile` transform abstraction an example of the usage is: 
```Python
def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])
```
to generate the following IR:
```mlir
%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]
```

These abstractions are intended to enhance the usability and flexibility
of the transform dialect by providing an accessible interface that
allows for easy assembly of complex transformation chains.
  • Loading branch information
martin-luecke authored Dec 15, 2023
1 parent f0b44ce commit 681eacc
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 5 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir-c/Dialect/Transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
Expand All @@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type);

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);

MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
MlirType type);

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,13 @@ class mlir_type_subclass : public pure_subclass {
.attr("replace")(superCls.attr("__name__"), captureTypeName);
});
if (getTypeIDFunction) {
// 'get_static_typeid' method.
// This is modeled as a static method instead of a static property because
// `def_property_readonly_static` is not available in `pure_subclass` and
// we do not want to introduce the complexity that pybind uses to
// implement it.
def_staticmethod("get_static_typeid",
[getTypeIDFunction]() { return getTypeIDFunction(); });
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction())(pybind11::cpp_function(
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Bindings/Python/DialectTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyOpType =
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
mlirTransformAnyOpTypeGetTypeID);
anyOpType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand All @@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyParamType =
mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
mlirTransformAnyParamTypeGetTypeID);
anyParamType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand All @@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto anyValueType =
mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
mlirTransformAnyValueTypeGetTypeID);
anyValueType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//

auto paramType =
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
mlirTransformParamTypeGetTypeID);
paramType.def_classmethod(
"get",
[](py::object cls, MlirType type, MlirContext ctx) {
Expand Down
18 changes: 17 additions & 1 deletion mlir/lib/CAPI/Dialect/Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) {
return isa<transform::AnyOpType>(unwrap(type));
}

MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
return wrap(transform::AnyOpType::getTypeID());
}

MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
Expand All @@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) {
return isa<transform::AnyParamType>(unwrap(type));
}

MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
return wrap(transform::AnyParamType::getTypeID());
}

MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
return wrap(transform::AnyParamType::get(unwrap(ctx)));
}
Expand All @@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) {
return isa<transform::AnyValueType>(unwrap(type));
}

MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
return wrap(transform::AnyValueType::getTypeID());
}

MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
return wrap(transform::AnyValueType::get(unwrap(ctx)));
}
Expand Down Expand Up @@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
}

//===---------------------------------------------------------------------===//
// AnyOpType
// ParamType
//===---------------------------------------------------------------------===//

bool mlirTypeIsATransformParamType(MlirType type) {
return isa<transform::ParamType>(unwrap(type));
}

MlirTypeID mlirTransformParamTypeGetTypeID(void) {
return wrap(transform::ParamType::getTypeID());
}

MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings(
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
)

declare_mlir_python_sources(
MLIRPythonSources.Dialects.transform.extras
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
extras/dialects/transform/__init__.py)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
148 changes: 148 additions & 0 deletions mlir/python/mlir/extras/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import Callable, Optional, Sequence

from .... import ir
from ....dialects import transform
from ....dialects.transform import structured


class Handle(ir.Value):
"""
Base class for wrappers around different types of transform handle with
methods to chain further transforms.
The fields `children` and `parent` are used to capture the relation of
handles statically in order to enable further analysis. The payload
operation of a child handle is nested into a region of the payload operation
of the corresponding parent handle.
"""

def __init__(
self,
v: ir.Value,
*,
parent: Optional[Handle] = None,
children: Optional[Sequence[Handle]] = None,
):
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []


@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
@ir.register_value_caster(transform.OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
transforms.
"""

def __init__(
self,
v: ir.Value,
*,
parent: Optional[Handle] = None,
children: Optional[Sequence[Handle]] = None,
):
super().__init__(v, parent=parent, children=children)

def match_ops(
self,
ops: str
| ir.OpView
| structured.MatchInterfaceEnum
| Sequence[str | ir.OpView],
) -> OpHandle:
"""
Emits a `transform.structured.MatchOp`.
Returns a handle to payload ops that match the given names, types, or
interface. If only a single type is given, the value wrapped by the
resulting handle is populated with the respective type.
"""
# Handle interface.
if isinstance(ops, structured.MatchInterfaceEnum) or (
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
):
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
match_op = structured.MatchOp(
transform.AnyOpType.get(),
self,
interface=ops,
)

# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
op_type = transform.OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
op_type = transform.AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
op_type = transform.OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
self,
op_names,
)

handle = OpHandle(match_op.results_, parent=self)
self.children.append(handle)
return handle


def insert_transform_script(
block_or_insertion_point: ir.Block | ir.InsertionPoint,
script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
"""
Inserts the transform script of the schedule into the module. The script
should accept an instance of OpHandle as argument, which will be called with
the block arg of the newly created named_sequence op.
Example:
This python code
```
module = ir.Module.create()
def test_match_ops_single(module: OpHandle):
module.match_ops(scf.ForOp)
insert_transform_script(module.body, script)
```
generates the following IR:
```
module {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
^bb0(%arg0: !transform.any_op):
%0 = transform.structured.match ops{["scf.for"]} in %arg0
: (!transform.any_op) -> !transform.op<"scf.for">
}
}
```
"""
if isinstance(block_or_insertion_point, ir.Block):
context = block_or_insertion_point.owner.context
insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
else:
context = block_or_insertion_point.block.owner.context
insertion_point = block_or_insertion_point

with context, ir.Location.unknown(context):
with insertion_point:
named_sequence_op = transform.NamedSequenceOp(
"__transform_main", [transform.AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
transform.YieldOp([])

if dump_script:
print(named_sequence_op)
Loading

0 comments on commit 681eacc

Please sign in to comment.