forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][transform][python] add sugared python abstractions for transfo…
…rm dialect (llvm#75073) This adds Python abstractions for the different handle types of the transform dialect The abstractions allow for straightforward chaining of transforms by calling their member functions. As an initial PR for this infrastructure, only a single transform is included: `transform.structured.match`. With a future `tile` transform abstraction an example of the usage is: ```Python def script(module: OpHandle): module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32]) ``` to generate the following IR: ```mlir %0 = transform.structured.match interface{TilingInterface} in %arg0 %tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32] ``` These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.
- Loading branch information
1 parent
f0b44ce
commit 681eacc
Showing
7 changed files
with
291 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations | ||
from typing import Callable, Optional, Sequence | ||
|
||
from .... import ir | ||
from ....dialects import transform | ||
from ....dialects.transform import structured | ||
|
||
|
||
class Handle(ir.Value): | ||
""" | ||
Base class for wrappers around different types of transform handle with | ||
methods to chain further transforms. | ||
The fields `children` and `parent` are used to capture the relation of | ||
handles statically in order to enable further analysis. The payload | ||
operation of a child handle is nested into a region of the payload operation | ||
of the corresponding parent handle. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
v: ir.Value, | ||
*, | ||
parent: Optional[Handle] = None, | ||
children: Optional[Sequence[Handle]] = None, | ||
): | ||
super().__init__(v) | ||
self.parent = parent | ||
self.children = children if children is not None else [] | ||
|
||
|
||
@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) | ||
@ir.register_value_caster(transform.OperationType.get_static_typeid()) | ||
class OpHandle(Handle): | ||
""" | ||
Wrapper around a transform operation handle with methods to chain further | ||
transforms. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
v: ir.Value, | ||
*, | ||
parent: Optional[Handle] = None, | ||
children: Optional[Sequence[Handle]] = None, | ||
): | ||
super().__init__(v, parent=parent, children=children) | ||
|
||
def match_ops( | ||
self, | ||
ops: str | ||
| ir.OpView | ||
| structured.MatchInterfaceEnum | ||
| Sequence[str | ir.OpView], | ||
) -> OpHandle: | ||
""" | ||
Emits a `transform.structured.MatchOp`. | ||
Returns a handle to payload ops that match the given names, types, or | ||
interface. If only a single type is given, the value wrapped by the | ||
resulting handle is populated with the respective type. | ||
""" | ||
# Handle interface. | ||
if isinstance(ops, structured.MatchInterfaceEnum) or ( | ||
isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ | ||
): | ||
if isinstance(ops, str): | ||
ops = structured.MatchInterfaceEnum[ops] | ||
match_op = structured.MatchOp( | ||
transform.AnyOpType.get(), | ||
self, | ||
interface=ops, | ||
) | ||
|
||
# Handle op name(s), either given directly as string or given as op. | ||
else: | ||
if isinstance(ops, str): | ||
op_type = transform.OperationType.get(ops) | ||
op_names = [ops] | ||
elif isinstance(ops, Sequence): | ||
op_type = transform.AnyOpType.get() | ||
op_names = [ | ||
op if isinstance(op, str) else op.OPERATION_NAME for op in ops | ||
] | ||
else: | ||
op_type = transform.OperationType.get(ops.OPERATION_NAME) | ||
op_names = [ops.OPERATION_NAME] | ||
match_op = structured.MatchOp.match_op_names( | ||
op_type, | ||
self, | ||
op_names, | ||
) | ||
|
||
handle = OpHandle(match_op.results_, parent=self) | ||
self.children.append(handle) | ||
return handle | ||
|
||
|
||
def insert_transform_script( | ||
block_or_insertion_point: ir.Block | ir.InsertionPoint, | ||
script: Callable[[OpHandle], None], | ||
dump_script: bool = False, | ||
) -> None: | ||
""" | ||
Inserts the transform script of the schedule into the module. The script | ||
should accept an instance of OpHandle as argument, which will be called with | ||
the block arg of the newly created named_sequence op. | ||
Example: | ||
This python code | ||
``` | ||
module = ir.Module.create() | ||
def test_match_ops_single(module: OpHandle): | ||
module.match_ops(scf.ForOp) | ||
insert_transform_script(module.body, script) | ||
``` | ||
generates the following IR: | ||
``` | ||
module { | ||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) { | ||
^bb0(%arg0: !transform.any_op): | ||
%0 = transform.structured.match ops{["scf.for"]} in %arg0 | ||
: (!transform.any_op) -> !transform.op<"scf.for"> | ||
} | ||
} | ||
``` | ||
""" | ||
if isinstance(block_or_insertion_point, ir.Block): | ||
context = block_or_insertion_point.owner.context | ||
insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point) | ||
else: | ||
context = block_or_insertion_point.block.owner.context | ||
insertion_point = block_or_insertion_point | ||
|
||
with context, ir.Location.unknown(context): | ||
with insertion_point: | ||
named_sequence_op = transform.NamedSequenceOp( | ||
"__transform_main", [transform.AnyOpType.get()], [] | ||
) | ||
with ir.InsertionPoint(named_sequence_op.body): | ||
script(named_sequence_op.bodyTarget) | ||
transform.YieldOp([]) | ||
|
||
if dump_script: | ||
print(named_sequence_op) |
Oops, something went wrong.