Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cinn(py-dsl): parse compute of python dsl #57731

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/cinn/pybind/ir/ir_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ void BindIrContext(py::module *m) {
.def_static("MakeThenContext",
[]() { return IRContext(new ThenContextNode()); });

m->def("link_to_parent_context", &pybind::LinkToParentContext);

py::class_<IRBuilder> ir_builder(*m, "IRBuilder");
ir_builder.def(py::init<>())
.def("EnterWithContext", &IRBuilder::EnterWithContext)
Expand Down
3 changes: 2 additions & 1 deletion python/cinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .version import full_version as __version__
from .runtime.cinn_jit import to_cinn_llir
import os

cinndir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -189,4 +191,3 @@
reduce_mul,
reduce_sum,
)
from .version import full_version as __version__
17 changes: 17 additions & 0 deletions python/cinn/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .compiler import compile

__all__ = ["compile"]
38 changes: 38 additions & 0 deletions python/cinn/compiler/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..runtime import CinnLowerLevelIrJit
from .compute_code_generator import ComputeCodeGenerator


def ast_to_llir(fn, inputs_signature):
function_name = fn.__name__
# 1. Parse CINN Compute
llir_compute_generator = ComputeCodeGenerator(
fn, function_name, inputs_signature
)
cinn_llir_func = llir_compute_generator.parse()
return cinn_llir_func


def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
if isinstance(fn, CinnLowerLevelIrJit):
llir_func = ast_to_llir(fn, jit_inputs_signature)
else:
raise Exception("Current Only support compile from CinnLowerLevelIrJit")

if just_convert:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this code?

It looks like just return llir_func

Copy link
Contributor Author

@6clc 6clc Oct 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later PR will insert code here to convert llir_func to the RunTime Module of CINN. As follows

    if just_convert:
        return llir_func

    rt_module = llir_to_runtime_module(
        llir_func, kwargs["target"], fn.__name__, kwargs["arg_names"]
    )

    return rt_module

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please send me the PR immediately after this PR.

return llir_func
return llir_func
245 changes: 245 additions & 0 deletions python/cinn/compiler/compute_code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import contextlib

from cinn import ir

from .expr_executor import ExprExecutor, exec_assign
from .utils import VariableTable, is_node_parsed_in_schedule


class ComputeCodeGenerator(ast.NodeVisitor):
"""
Convert python ast to CINN Lower Level IR,
containing only the semantics of the compute part
"""

def __init__(self, fn, function_name, inputs_signature):
self.fn = fn
self.function_name = function_name
self.inputs_signature = inputs_signature
self.cinn_llir_func = None
self.variables_table = VariableTable()
self.extra_scope = {"range": ir.sequential}

def parse(self):
ast_node = self.fn.parse()
with ir.IRBuilder() as builder, self.variables_table:
for k, v in self.fn.scope.items():
self.variables_table.add(k, v)
for k, v in self.extra_scope.items():
self.variables_table.add(k, v)
self.visit(ast_node)
return builder.get()

def visit_FunctionDef(self, node) -> None:
"""
Parse CINN Low Level IR FunctionDef.

Args:
node(ast.FunctionDef): The ast FunctionDef Node
"""
with ir.LowerFuncContext(self.function_name) as func_ctx:
arg_names = self.visit(node.args)

assert len(node.args.defaults) == 0, "Not support default args"

# 1. Construct args of function
for i, arg_name in enumerate(arg_names):
# Obj of Argument is ir::Buffer
if hasattr(self.inputs_signature[i], "dtype"):
tensor_shape = [
ir.Expr(dim) for dim in self.inputs_signature[i].shape
]
llir_value = ir._Buffer_.make(
arg_name, self.inputs_signature[i].dtype
)
ir.Arg(arg_name, llir_value)
llir_value = ir._Tensor_.make(
arg_name,
self.inputs_signature[i].dtype,
tensor_shape,
tensor_shape,
)
self.variables_table.add(arg_name, llir_value)
# Obj of Argument is ir::Var
else:
llir_value = ir.Var(arg_name)
ir.Arg(arg_name, llir_value)
llir_value = ir.Expr(llir_value)
self.variables_table.add(arg_name, llir_value)

# 2. Construct body of function
body = self.visit_compound_statement(node.body)

def visit_compound_statement(self, stmts):
for stmt in stmts:
self.visit(stmt)

def visit_arguments(self, node):
"""
Parse CINN Low Level IR Argument.
If it is not jit mode, it will get information from arg.annoatation.

Args:
node(ast.arguments): The ast argument Node

Returns:
list[string]: A list of parameter names
"""
arg_names = [arg.arg for arg in node.args]

if len(self.inputs_signature) != len(arg_names):
self.inputs_signature = []
for arg in node.args:
arg_annotation = arg.annotation
if isinstance(arg_annotation, ast.Call):
self.inputs_signature.append(
ExprExecutor(self.variables_table.get()).exec(
arg_annotation
)
)
elif isinstance(arg_annotation, int):
if (
-(2**21) <= arg_annotation
and arg_annotation <= 2**31 - 1
):
self.inputs_signature.append("i32")
elif (
2**63 <= arg_annotation
and arg_annotation <= 2**64 - 1
):
self.inputs_signature.append("u64")
else:
self.inputs_signature.append("i64")
elif isinstance(arg_annotation, float):
return self.inputs_signature.append("fp32")
else:
raise TypeError(
f'Unsupported type {type(arg_annotation)} for {arg_annotation}'
)

return arg_names

def visit_For(self, node) -> ir.Expr:
"""
parse CINN Low Level IR For.

Args:
node(ast.For): The ast For node
"""
for_ctx = ExprExecutor(self.variables_table.get()).exec(node.iter)
with self.variables_table:
with for_ctx as loop_var:
local_var_table = exec_assign(
target=node.target, source=loop_var
)
for k, v in local_var_table.items():
loop_var.rename(k)
self.variables_table.add(k, ir.Expr(v))
self.visit_compound_statement(node.body)

def visit_Assign(self, node):
"""
parse CINN Low Level IR Store.

Args:
node(ast.Assign): The ast Assign node

Returns:
ir.Expr, Points to the Expr of ir::ExprNode<Store>
"""

if isinstance(node.value, ast.Call) and is_node_parsed_in_schedule(
node.value
):
return "no compute"

assert (
len(node.targets) == 1
), "Unsupport targets is a \
list of nodes, like 'a = b = c'"
lhs = node.targets[0]

# 1 parse RHS
rhs_expr = ExprExecutor(self.variables_table.get()).exec(node.value)

# 2 parse LHS
# 2.1 Type of arg is Tensor
if isinstance(lhs, ast.Subscript):
expr_tensor = ExprExecutor(self.variables_table.get()).exec(
lhs.value
)
if isinstance(lhs.slice, ast.Tuple):
expr_indices = []
for idx in lhs.slice.elts:
expr_indices.append(
ExprExecutor(self.variables_table.get()).exec(idx)
)
else:
expr_indices = [
ExprExecutor(self.variables_table.get()).exec(lhs.slice)
]
if not isinstance(rhs_expr, ir.Expr):
rhs_expr = ir.Expr(rhs_expr)
ir.TensorStore(expr_tensor.Expr(), rhs_expr, expr_indices)
# 2.2 Type of arg is Var
else:
local_var_table = exec_assign(target=lhs, source=rhs_expr)
if isinstance(lhs, ast.Tuple):
for k, v in local_var_table.items():
v.as_var_ref().rename(k)
self.variables_table.add(k, v)
else:
for k, v in local_var_table.items():
v[0].as_var_ref().rename(k)
self.variables_table.add(k, v[0])

def visit_If(self, node):
with self.variables_table:
with ir.IfContext(
ExprExecutor(self.variables_table.get()).exec(node.test)
):
with ir.ThenContext():
with self.variables_table:
self.visit_compound_statement(node.body)
if node.orelse:
with ir.ElseContext():
with self.variables_table:
self.visit_compound_statement(node.body)

def visit_With(self, node):
with self.variables_table:
with contextlib.ExitStack() as context_stack:
for item in node.items:
cur_ctx = ExprExecutor(self.variables_table.get()).exec(
item.context_expr
)
cur_ctx = context_stack.enter_context(cur_ctx)
if item.optional_vars is not None:
local_var_table = exec_assign(
target=item.optional_vars, source=cur_ctx
)
for k, v in local_var_table.items():
self.variables_table.add(k, v)
body = self.visit_compound_statement(node.body)

def visit_Expr(self, node):
if is_node_parsed_in_schedule(node.value):
return
res = ExprExecutor(self.variables_table.get()).exec(node.value)
if isinstance(res, ir.Expr):
ir.link_to_parent_context(res)
Loading