Skip to content

[Hackathon 7th Paddle python pass] passes ir.py代码阅读

yinfan98 edited this page Jul 21, 2024 · 1 revision

对这项工作整体的理解

为 pass 提供一套 Python 端接口,使 pattern 变换(匹配子图及其替换子图)可以直接用 Python 编写,并能把这个 pass 注册到 C++ runtime 中。

阅读 paddle/incubate/passes/ir.py 的代码:装饰器 @RegisterPass 的功能是把用户自定义的 Python 函数(它返回 pattern 与 replace 两个子图函数)转换为 pass 描述,并注册到 C++ 端。

代码整体总结

Python IR 代码通过 RegisterPassHelper,把被装饰函数返回的 pattern 函数和 replace 函数转换成编译器使用的 protobuf 描述(各种 *_desc,最终汇总并序列化为 MultiPassDesc);并通过 VarHelper 和 OpHelper 分别表示子图中的变量和算子。

入口 RegisterPass

def decorated(python_func):
    """Register ``python_func`` as a pass and return it unchanged.

    ``python_func`` must take no parameters. It must return either a single
    ``(pattern, replace)`` pair of Python functions or an iterable of such
    pairs. The pass is registered under ``python_func.__name__``.
    ``input_specs``, ``RegisterPassHelper``, ``core`` and ``_is_pass_pair``
    are resolved from the surrounding scope.

    Raises:
        NotImplementedError: if ``python_func`` declares any parameter.
        ValueError: if the return value is neither a pass pair nor an
            iterable whose items are all pass pairs.
    """
    pass_type = python_func.__name__
    signature = inspect.signature(python_func)
    if len(signature.parameters) > 0:
        raise NotImplementedError(
            "Pass function with parameter is not supported now."
        )
    error_msg = "Return value of Pass function must be (callable, callable)."
    pass_pairs = python_func()
    if _is_pass_pair(pass_pairs):
        pass_pairs = [pass_pairs]
    else:
        try:
            pair_iter = iter(pass_pairs)
        except TypeError:
            # A non-iterable result (e.g. None) gets the same error as a
            # malformed pair instead of a bare "not iterable" TypeError.
            raise ValueError(error_msg) from None
        # Materialize once: validating a generator with all(map(...)) would
        # otherwise exhaust it before RegisterPassHelper iterates it.
        pass_pairs = list(pair_iter)
        if not all(map(_is_pass_pair, pass_pairs)):
            raise ValueError(error_msg)
    helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
    core.register_pass(pass_type, helper.SerializeMultiPassDesc)
    return python_func

这里是装饰器的入口。inspect.signature(python_func) 用于获取函数 python_func 的签名(示例见下);这里借此确保被装饰的函数不带任何参数,否则抛出 NotImplementedError。

import inspect


def example_function(a, b, c=1):
    """Placeholder used only to demonstrate ``inspect.signature``."""


signature = inspect.signature(example_function)
print(signature)
# Output: (a, b, c=1)

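_is_pass_pair 检查返回值是否为长度为 2 的 list 或 tuple,且两个元素都是 Python 函数(inspect.isfunction),即一个 pattern 函数和一个 replace 函数。注意 inspect.isfunction 只接受普通函数(包括 lambda),绑定方法、functools.partial 等其他可调用对象不会通过检查。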

def _is_pass_pair(check_pair):
    """Return True if ``check_pair`` is a 2-item list/tuple of Python functions.

    Only plain functions (``inspect.isfunction``, which includes lambdas)
    qualify. Other callables such as bound methods, builtins or
    ``functools.partial`` objects do not.
    """
    return (
        isinstance(check_pair, (list, tuple))
        and len(check_pair) == 2
        and all(inspect.isfunction(item) for item in check_pair)
    )

校验通过后,构造 RegisterPassHelper,并调用 core.register_pass,以 pass_type 为名注册 helper.SerializeMultiPassDesc。主要逻辑到此结束。

RegisterPassHelper

RegisterPassHelper 的输入是 pass 函数对列表(每对包含一个 pattern 函数和一个 replace 函数)、pass 名称(pass_type,即被装饰函数的名字)以及输入约束 input_specs。

_get_args_from_func

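读取 func 的参数名,并按 self._input_specs 中对应的约束构造参数列表 args:InputSpec 约束的参数构造为带 shape 和 dtype 的 VarHelper;ParamAttr 约束的参数构造为同名的 ParamAttr;其余参数构造为 shape 为 [-1] 的 VarHelper。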

def _get_args_from_func(self, func):
    """Create one placeholder argument per positional parameter of ``func``.

    Each parameter name is looked up in ``self._input_specs``:

    * ``paddle.static.InputSpec`` -> ``PassDesc.VarHelper`` with that spec's
      shape and dtype;
    * ``paddle.ParamAttr`` -> a new ``paddle.ParamAttr`` named after the
      parameter;
    * anything else, including no entry -> ``PassDesc.VarHelper`` with
      shape ``[-1]``.

    Returns:
        list: the placeholders, in parameter order.
    """

    def _placeholder_for(param_name):
        spec = self._input_specs.get(param_name)
        if isinstance(spec, paddle.static.InputSpec):
            return PassDesc.VarHelper(param_name, spec.shape, spec.dtype)
        if isinstance(spec, paddle.ParamAttr):
            return paddle.ParamAttr(param_name)
        return PassDesc.VarHelper(param_name, [-1])

    param_names = inspect.getfullargspec(func).args
    return [_placeholder_for(param_name) for param_name in param_names]

_prune_program_desc

_prune_program_desc 方法的主要功能是从给定的操作符描述中移除不必要的或默认的属性,以简化程序描述。

_func_to_program_desc

def _func_to_program_desc(self, func, ops):
    """Trace ``func`` into a static Program and copy its ops into ``ops``.

    Args:
        func: pattern or replace function. It is called with the placeholder
            arguments built by ``self._get_args_from_func(func)``.
        ops: repeated OpDesc field (e.g. ``pass_desc.pattern``). It receives
            a copy of every op recorded in the program's current block and
            is then passed to ``self._prune_program_desc``.

    Returns:
        tuple: ``(var_list, block_ops)``. ``var_list`` holds the input
        placeholders followed by the output variables of ``func``;
        ``block_ops`` is the op list of the traced block.

    Raises:
        ValueError: if an output of ``func`` is an OpHelper with no output
            slots or with more than one output slot.
    """
    var_list = []
    program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(program, startup_program):
        args = self._get_args_from_func(func)
        var_list.extend(args)
        outs = func(*args)
        if not isinstance(outs, (list, tuple)):
            outs = [outs]
        # Append every output variable of func after its inputs.
        for out in outs:
            if isinstance(out, PassDesc.OpHelper):
                op_outs = out.Outputs()
                if not op_outs:
                    raise ValueError(f"Operator '{out._type}' has no outputs.")
                if len(op_outs) != 1:
                    raise ValueError(
                        f"Operator '{out._type}' has multiple outputs, please specify one output variable."
                    )
                for op_out in op_outs.values():
                    var_list.extend(op_out)
            else:
                var_list.append(out)
    block = program.current_block()
    block_desc = block.desc
    # Copy each op of the current block into ops by serializing the OpDesc
    # and parsing the bytes into a newly added protobuf element.
    for i in range(block_desc.op_size()):
        ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
    self._prune_program_desc(ops)
    return var_list, block.ops
with paddle.static.program_guard(program, startup_program): 

上面这行代码的作用是:使用 PaddlePaddle 的 program_guard 上下文管理器,使在其作用域内定义的所有操作和变量都归属于指定的 program 和 startup_program。

# Excerpt of the loop in _func_to_program_desc: each op of the current block
# is serialized with serialize_to_string() and parsed into a new element
# appended to ops.
for i in range(block_desc.op_size()):
    ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
  • block_desc.op_size():返回当前block中操作符的数量。
  • block_desc.op(i):获取第 i 个操作符的描述对象。
  • serialize_to_string():将操作符描述对象序列化为字符串。
  • ops.add():在 ops 集合中添加一个新的操作符。
  • ParseFromString(...):从字符串解析出操作符的描述,恢复成一个操作符描述对象。

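这个循环遍历当前 block 中的所有操作符,把它们逐个序列化,再解析进 ops 中新增的元素。之所以要"序列化 + 反序列化",应该是因为 block_desc.op(i) 返回的是 C++ 端 OpDesc 的 Python 绑定对象,而 ops 是 Python 端 protobuf 消息的 repeated 字段。两者类型不同,通过字节串中转是最直接的转换方式;这样也顺带得到了一份深拷贝,后续 _prune_program_desc 修改 ops 时不会影响原 program。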

重要的方法 _convert_vars_to_pass_desc

def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
    """Write variable mappings and variable conditions into ``desc``.

    ``patterns`` and ``replaces`` are paired positionally with ``zip``, so
    extra items in the longer list are ignored. For each pair:

    * one ``desc.var_maps`` entry maps the pattern name to the replace name;
    * if the pattern name has an entry in ``self._input_specs``, a kEQ
      condition requires its ``shape`` attribute to equal ``pattern.shape``;
    * if the pattern is exactly a ``PassDesc.VarHelper`` (subclasses
      excluded), the conditions of its attributes are collected, each
      element before its nested ``_elements``.
    """

    def _add_element_conditions(conditions, elements):
        # Pre-order walk: an element's own condition precedes its children's.
        for elem in elements:
            if elem._condition:
                conditions.append(elem._condition)
            _add_element_conditions(conditions, elem._elements)

    for pattern_var, replace_var in zip(patterns, replaces):
        var_map = desc.var_maps.add()
        var_map.pattern_var = pattern_var.name
        var_map.replace_var = replace_var.name
        conditions = desc.var_attr_conditions
        if pattern_var.name in self._input_specs:
            shape_condition = conditions.add()
            pattern_var.Attr("shape")._to_pass_desc_attr(shape_condition.attr)
            expected = shape_condition.condition_value
            expected.name = ""
            expected.type = framework_pb2.AttrType.LONGS
            expected.longs.extend(pattern_var.shape)
            shape_condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
        if pattern_var.__class__ is PassDesc.VarHelper:
            _add_element_conditions(conditions, pattern_var._attrs.values())

这里为每一对 (pattern 变量, replace 变量) 在 desc.var_maps 中添加一条映射,并把相关条件添加到 desc.var_attr_conditions 中。

# Excerpt of _convert_vars_to_pass_desc: when the pattern variable has an
# entry in self._input_specs, add a condition requiring its "shape"
# attribute to equal (kEQ) pattern.shape, stored as a LONGS value.
if pattern.name in self._input_specs:
    condition = conditions.add()
    pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
    condition.condition_value.name = ""
    condition.condition_value.type = framework_pb2.AttrType.LONGS
    condition.condition_value.longs.extend(pattern.shape)
    condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ

当 pattern 变量在 input_specs 中有约束时,把它的 shape 属性转换为 PassDesc 中的 Attr 描述,并添加一个条件:该 shape 必须等于(kEQ)pattern.shape,条件值的类型为 LONGS。

另一个重要的方法 _convert_ops_to_pass_desc

def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
    """Record attribute mappings of replace operators into ``desc``.

    Each attribute in the ``_attrs`` of every ``PassDesc.OpHelper`` in
    ``replaces`` appends one entry to ``desc.op_attr_maps``. The mapped
    source attribute is written to ``pattern_attr`` and the replace
    attribute itself to ``replace_attr``. Items of ``replaces`` that are
    not OpHelpers are skipped.

    Args:
        patterns: pattern ops, passed to ``_mapped`` when it is a function so
            the mapping can be resolved lazily.
        replaces: replace ops.
        desc: the PassDesc message being filled.
    """
    op_helpers = (item for item in replaces if isinstance(item, PassDesc.OpHelper))
    for op_helper in op_helpers:
        for replace_attr in op_helper._attrs.values():
            source = replace_attr._mapped
            # Deferred mappings (e.g. from MappedPattern) resolve against patterns.
            if inspect.isfunction(source):
                source = source(patterns)
            attr_map = desc.op_attr_maps.add()
            source._to_pass_desc_attr(attr_map.pattern_attr)
            replace_attr._to_pass_desc_attr(attr_map.replace_attr)
            if source._operation is not None:
                attr_map.operation.CopyFrom(source._operation)

方法 _convert_ops_to_pass_desc:

  • 遍历替换列表:处理其中的 PassDesc.OpHelper 对象。
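  • 转换属性映射:如果 _mapped 是函数(例如由 MappedPattern 生成),先用 pattern 的 ops 调用它,得到对应的 AttrHelper;然后把映射来源写入 attr_map.pattern_attr,把替换侧属性自身写入 attr_map.replace_attr。
  • 处理操作对象:如果映射来源带有操作对象(_operation),将其复制到 attr_map.operation 中。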

SerializeMultiPassDesc

def SerializeMultiPassDesc(self):
    """Convert every (pattern, replace) pair into one serialized MultiPassDesc.

    Conversion has to run in static mode. If the caller is in dynamic mode,
    static mode is enabled for the conversion and dynamic mode is restored
    afterwards, also when a conversion step raises.

    Returns:
        bytes: ``MultiPassDesc.SerializeToString()`` of the converted passes.
    """
    switch_static_mode = paddle.in_dynamic_mode()
    if switch_static_mode:
        paddle.enable_static()
    try:
        multi_pass_desc = pass_desc_pb2.MultiPassDesc()
        multi_pass_desc.pass_type = self._pass_type
        # Traverse all pass pairs and convert them to PassDesc data.
        # Here need to add cache in the future.
        for pattern, replace in self._pass_pairs:
            pass_desc = multi_pass_desc.pass_descs.add()
            # Convert ProgramDescs of pattern and replace subgraphs.
            pattern_vars, pattern_ops = self._func_to_program_desc(
                pattern, pass_desc.pattern
            )
            replace_vars, replace_ops = self._func_to_program_desc(
                replace, pass_desc.replace
            )
            self._convert_vars_to_pass_desc(
                pattern_vars, replace_vars, pass_desc
            )
            self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc)
    finally:
        # Restore the caller's mode even on failure; otherwise a broken pass
        # would leave the whole process switched to static mode.
        if switch_static_mode:
            paddle.disable_static()
    return multi_pass_desc.SerializeToString()

主要for循环逻辑:

  • for pattern, replace in self._pass_pairs:遍历所有的转换规则对(pattern 和 replace)。
  • pass_desc = multi_pass_desc.pass_descs.add():在 multi_pass_desc 中添加一个新的 PassDesc 对象。
  • self._func_to_program_desc(pattern, pass_desc.pattern):将 pattern 转换为 ProgramDesc,并保存到 pass_desc.pattern 中。
  • self._func_to_program_desc(replace, pass_desc.replace):将 replace 转换为 ProgramDesc,并保存到 pass_desc.replace 中。
  • self._convert_vars_to_pass_desc(pattern_vars, replace_vars, pass_desc):将变量转换为 PassDesc。
  • self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc):将操作(ops)转换为 PassDesc。

PassDesc / AttrHelper

_to_pass_desc_attr

def _to_pass_desc_attr(self, pass_desc_attr):
    """Fill a PassDesc ``Attr`` message describing this attribute.

    If ``self._obj`` is a ``PassDesc.VarHelper``, the role is kVariable and
    ``var_name`` is set. Otherwise the role is kOperator and ``op_index`` is
    taken from ``self._obj._index``. ``operation`` and ``element_index`` are
    written only when they are not None.
    """
    role_type = pass_desc_pb2.PassDesc.RoleType
    target = self._obj
    if not isinstance(target, PassDesc.VarHelper):
        pass_desc_attr.role = role_type.kOperator
        pass_desc_attr.op_index = target._index
    else:
        pass_desc_attr.role = role_type.kVariable
        pass_desc_attr.var_name = target.name
    pass_desc_attr.name = self._name
    optional_fields = {
        "operation": self._operation_type,
        "element_index": self._element_index,
    }
    for field_name, field_value in optional_fields.items():
        if field_value is not None:
            setattr(pass_desc_attr, field_name, field_value)

AttrHelper -> PassDesc

_to_op_desc_attr

def _to_op_desc_attr(self, value, op_desc_attr):
    """Store constant ``value`` in an OpDesc ``Attr`` message.

    ``op_desc_attr.name`` is always cleared first. Only ``int`` values
    (``isinstance`` check, so ``bool`` qualifies too) are supported; they are
    stored as type INT in field ``i``.

    Raises:
        NotImplementedError: for any non-int ``value``.
    """
    op_desc_attr.name = ""
    if not isinstance(value, int):
        raise NotImplementedError("Unimplemented transform operation.")
    op_desc_attr.type = framework_pb2.AttrType.INT
    op_desc_attr.i = value

AttrHelper -> OP

_set_with_condition

def _set_with_condition(self, type, value):
    """Build an AttrCondition comparing this attribute with ``value``.

    Args:
        type: condition type stored in ``condition.type``.
        value: another ``PassDesc.AttrHelper``, which is written to
            ``condition_attr``, or a constant, which is converted by
            ``self._to_op_desc_attr`` into ``condition_value``.

    The finished condition is stored on ``self._condition``. Nothing is
    returned.
    """
    new_condition = pass_desc_pb2.PassDesc.AttrCondition()
    self._to_pass_desc_attr(new_condition.attr)
    new_condition.type = type
    compares_to_attr = isinstance(value, PassDesc.AttrHelper)
    if compares_to_attr:
        value._to_pass_desc_attr(new_condition.condition_attr)
    else:
        self._to_op_desc_attr(value, new_condition.condition_value)
    # Attach this attribute's operation to the condition when one is set.
    if self._operation:
        new_condition.operation.CopyFrom(self._operation)
    self._condition = new_condition

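设置条件:构造一个 AttrCondition,左侧是当前属性。右侧若是另一个 AttrHelper,则写入 condition_attr;否则作为常量写入 condition_value(目前只支持 int)。最后保存到 self._condition。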

MappedPattern

def MappedPattern(
    self, var=None, op=None, index=0, name=None, element_index=None
):
    """Map this attribute to an attribute found in the pattern subgraph.

    Exactly one of ``var`` or ``op`` must be given. The mapping is stored in
    ``self._mapped`` as a function that is resolved later against the list
    of pattern ops.

    Args:
        var: pattern variable to map to (not implemented yet).
        op: operator type to look for among the pattern ops.
        index: which occurrence (0-based) of ``op`` to use.
        name: attribute name on the selected operator.
        element_index: optional element index inside that attribute.

    Raises:
        ValueError: if both or neither of ``var`` and ``op`` are given. The
            stored resolver raises ValueError when ``op`` has at most
            ``index`` occurrences, and NotImplementedError for ``var``.
    """
    # Compare against None rather than truthiness: var may be a
    # Variable-like object whose bool() is not meaningful.
    if var is not None and op is not None:
        raise ValueError("Only mapped one of which var or op.")
    if var is None and op is None:
        # Without this check an op mapper for type None would be built and
        # only fail much later with a misleading index error.
        raise ValueError("One of var or op must be specified.")

    def mapped_var(pattern_ops):
        raise NotImplementedError(
            "Mapping to variable is not implemented."
        )

    def mapped_op(pattern_ops):
        ops = [o for o in pattern_ops if o._type == op]
        if len(ops) <= index:
            raise ValueError(
                f"Index '{index}' of operator '{op}' is incorrect."
            )
        return PassDesc.AttrHelper(
            ops[index], name, element_index=element_index
        )

    self._mapped = mapped_op if var is None else mapped_var

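MappedPattern 为当前 AttrHelper 设置映射来源,var 与 op 只能指定其一。映射到算子时,会保存一个延迟求值的函数:转换时从 pattern 的 ops 中找出第 index 个类型为 op 的算子,并返回它名为 name 的属性对应的 AttrHelper。映射到变量(var)目前尚未实现。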

VarHelper

VarHelper 继承自 paddle.static.Variable,用来表示 pattern/replace 子图中的变量。

Attr

def Attr(self, name):
    """Return the AttrHelper for attribute ``name``, creating it on first use.

    Helpers are cached in ``self._attrs``, so repeated calls with the same
    name return the same object. A cached value of None is treated as
    missing and replaced.
    """
    cached = self._attrs.get(name)
    if cached is not None:
        return cached
    created = PassDesc.AttrHelper(self, name)
    self._attrs[name] = created
    return created

Attr(name) 返回该对象名为 name 的属性对应的 AttrHelper:首次访问时创建并缓存到 self._attrs 中,之后再次访问返回同一个对象。

OpHelper

_to_readable_code

class OpHelper:
    """Operator placeholder used in pass subgraphs (only one method shown)."""

    def _to_readable_code(self, skip_op_callstack=True):
        """Return a one-line, human-readable rendering of this operator.

        The format is ``{k=v, ...} = <type>(inputs={k=v, ...}, {k=v, ...})``,
        built from ``_outputs``, ``_type``, ``_inputs`` and ``_attrs``.
        ``skip_op_callstack`` is only type-checked and does not change the
        output.
        """
        assert isinstance(skip_op_callstack, bool), f"skip_op_callstack parameter's type is error, expect bool, received {type(skip_op_callstack)}"

        def _braced(mapping):
            # Render a mapping as "{k1=v1, k2=v2}".
            pairs = (f"{key}={val}" for key, val in mapping.items())
            return "{" + ", ".join(pairs) + "}"

        rendered_outputs = _braced(self._outputs)
        rendered_inputs = _braced(self._inputs)
        rendered_attrs = _braced(self._attrs)
        return f"{rendered_outputs} = {self._type}(inputs={rendered_inputs}, {rendered_attrs})"

生成形如 {outputs} = type(inputs={inputs}, {attrs}) 的可读字符串。

__call__

调用 OpHelper(即 __call__)时,会把输入设置到对应的算子描述 _desc 上,执行后再设置好它的输出变量。