-
Notifications
You must be signed in to change notification settings - Fork 1
[Hackathon 7th Paddle python pass] passes ir.py代码阅读
目标:提供一套 Python 端接口,使 pattern 变换可以在 Python 端编写,并能把这个 pattern 注册到 C++ runtime 中。
阅读 `paddle/incubate/passes/ir.py` 的代码。其中装饰器 `@RegisterPass` 的功能,是把用户自定义的 Python 函数转换为 Python 侧的 pattern 描述,并将其注册到 C++ 侧(见下文的 `core.register_pass`)。
Python IR 代码通过 `RegisterPassHelper`,把 pass 函数返回的 pattern 与 replace 函数(连同它们的输入参数)转换为编译器使用的 desc 描述,并通过 `VarHelper` 和 `OpHelper` 分别为变量和算子提供辅助实现。
def decorated(python_func):
    """Register ``python_func`` as a pass and return it unchanged.

    ``python_func`` must take no parameters. It must return either a single
    ``(pattern, replace)`` pair of plain Python functions, or a non-empty
    collection of such pairs. The pairs are wrapped in a
    ``RegisterPassHelper``. Its ``SerializeMultiPassDesc`` method is then
    registered via ``core.register_pass`` under the function's ``__name__``.

    Raises:
        NotImplementedError: if ``python_func`` declares any parameter.
        ValueError: if the return value is neither a pass pair nor a
            non-empty iterable of pass pairs.
    """
    pass_type = python_func.__name__
    if inspect.signature(python_func).parameters:
        raise NotImplementedError(
            "Pass function with parameter is not supported now."
        )
    error_message = "Return value of Pass function must be (callable, callable)."
    pass_pairs = python_func()
    if _is_pass_pair(pass_pairs):
        pass_pairs = [pass_pairs]
    else:
        # Materialize exactly once. Validating a generator with all(map(...))
        # would consume it, and RegisterPassHelper would then receive an
        # exhausted iterator and register an empty pass.
        try:
            pass_pairs = list(pass_pairs)
        except TypeError:
            raise ValueError(error_message) from None
        if not pass_pairs or not all(map(_is_pass_pair, pass_pairs)):
            raise ValueError(error_message)
    helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
    core.register_pass(pass_type, helper.SerializeMultiPassDesc)
    return python_func
这里是装饰器的入口。`inspect.signature(python_func)` 用于获取函数 `python_func` 的签名。此处用它确保 pass 函数不接受任何参数:只要声明了参数,就会抛出 `NotImplementedError`。`inspect.signature` 的用法示例如下:
import inspect


def example_function(a, b, c=1):
    """Toy function: two required parameters and one with a default value."""


signature = inspect.signature(example_function)
print(str(signature))
输出:`(a, b, c=1)`
`_is_pass_pair` 检查返回值是否为一个 (pattern, replace) 函数对,即:它是长度为 2 的 list 或 tuple,且两个元素都是普通 Python 函数(由 `inspect.isfunction` 判断)。
def _is_pass_pair(check_pair):
if isinstance(check_pair, (list, tuple)):
if len(check_pair) == 2:
if all(map(inspect.isfunction, check_pair)):
return True
return False
校验通过后,构造 `RegisterPassHelper`,并调用 `core.register_pass` 完成注册,装饰器的主要逻辑到此结束。
`RegisterPassHelper` 接收三样输入:pass 函数对、pass 名字(pass_type)和输入约束 `input_specs`。
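`_get_args_from_func` 提取 func 的参数名,再根据 `input_specs` 中对应的配置构造参数列表 args:
- 配置为 `paddle.static.InputSpec` 时,生成带形状和 dtype 的变量;
- 配置为 `paddle.ParamAttr` 时,生成 `ParamAttr`;
- 未配置时,默认生成形状为 `[-1]` 的变量。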
def _get_args_from_func(self, func):
args = []
arg_specs = inspect.getfullargspec(func)
for arg_name in arg_specs.args:
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
PassDesc.VarHelper(
arg_name, input_spec.shape, input_spec.dtype
)
)
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(PassDesc.VarHelper(arg_name, [-1]))
return args
`_prune_program_desc` 方法的主要功能是从给定的操作符描述中移除不必要的或默认的属性,以简化程序描述。下面的 `_func_to_program_desc` 会在最后调用它。该方法先在 `program_guard` 下执行 func,收集输入和输出变量,再把生成的算子描述复制到 ops 中。
def _func_to_program_desc(self, func, ops):
    """Trace ``func`` into a fresh static Program and copy its op descs.

    Args:
        func: pattern or replace function. It is called with the
            placeholders produced by ``self._get_args_from_func``.
        ops: destination for the traced operators. Every op desc of the
            traced block is appended to it via ``ops.add()`` and
            ``ParseFromString``, so it is presumably a repeated protobuf
            field (e.g. ``pass_desc.pattern``) — confirm against callers.

    Returns:
        tuple: ``(collected, block_ops)``. ``collected`` holds the input
        placeholders followed by the output variables. ``block_ops`` is
        ``current_block().ops`` of the traced Program.

    Raises:
        ValueError: if an ``OpHelper`` output does not expose exactly one
            output slot.
    """
    collected = []
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        func_args = self._get_args_from_func(func)
        collected.extend(func_args)
        results = func(*func_args)
        if not isinstance(results, (list, tuple)):
            results = [results]
        # Record every output. An OpHelper contributes the variables of its
        # (single) output slot; anything else is recorded as-is.
        for result in results:
            if not isinstance(result, PassDesc.OpHelper):
                collected.append(result)
                continue
            slots = result.Outputs()
            if len(slots) != 1:
                raise ValueError(
                    f"Operator '{result._type}' has multiple outputs, please specify one output variable."
                )
            for slot_vars in slots.values():
                collected.extend(slot_vars)
    # Copy each op desc of the traced block into ``ops`` through a
    # serialize/parse round-trip, then prune the copies.
    block_desc = main_prog.current_block().desc
    for op_idx in range(block_desc.op_size()):
        ops.add().ParseFromString(block_desc.op(op_idx).serialize_to_string())
    self._prune_program_desc(ops)
    return collected, main_prog.current_block().ops
with paddle.static.program_guard(program, startup_program):
它的作用是借助 PaddlePaddle 的 `program_guard` 上下文管理器限定作用域,使在该上下文中定义的所有算子和变量都归属于指定的 program 和 startup_program。
for i in range(block_desc.op_size()):
ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
- block_desc.op_size():返回当前block中操作符的数量。
- block_desc.op(i):获取第 i 个操作符的描述对象。
- serialize_to_string():将操作符描述对象序列化为字符串。
- ops.add():在 ops 集合中添加一个新的操作符。
- ParseFromString(...):从字符串解析出操作符的描述,恢复成一个操作符描述对象。
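这个循环遍历当前 block 中的所有操作符,将它们逐个序列化,再解析追加到 ops 中。为什么要使用序列化 + 反序列化?我认为是出于深拷贝和数据完整性的考量。`block_desc.op(i)` 提供的是 `serialize_to_string()`,而 ops 中的新元素使用的是 protobuf 的 `ParseFromString()`,两者看起来是不同的对象类型,无法直接赋值。通过序列化字符串中转,可以完整复制出一份独立的数据。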
def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
def _add_element_conditions(conditions, elements):
for element in elements:
if element._condition:
conditions.append(element._condition)
_add_element_conditions(conditions, element._elements)
for pattern, replace in zip(patterns, replaces):
# Convert maps of inputs and outputs.
var_map = desc.var_maps.add()
var_map.pattern_var = pattern.name
var_map.replace_var = replace.name
conditions = desc.var_attr_conditions
# Convert shape condition.
if pattern.name in self._input_specs:
condition = conditions.add()
pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.LONGS
condition.condition_value.longs.extend(pattern.shape)
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
# Convert attr conditions.
if PassDesc.VarHelper == pattern.__class__:
for attr in pattern._attrs.values():
_add_element_conditions(conditions, [attr])
这里为每一对 (pattern, replace) 变量添加一个 var_map 条目,并向 `var_attr_conditions` 中追加约束条件(condition)。
if pattern.name in self._input_specs:
condition = conditions.add()
pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.LONGS
condition.condition_value.longs.extend(pattern.shape)
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
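当该变量在 `input_specs` 中有配置时,这段代码先把它的 shape 属性转换为 PassDesc 中的 Attr 描述,再添加一个 kEQ 条件:要求匹配到的变量 shape 等于 `pattern.shape`(以 LONGS 类型的 condition_value 存储)。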
def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
for replace in replaces:
if isinstance(replace, PassDesc.OpHelper):
for attr in replace._attrs.values():
# Convert attr maps.
mapped = attr._mapped
if inspect.isfunction(mapped):
mapped = mapped(patterns)
attr_map = desc.op_attr_maps.add()
mapped._to_pass_desc_attr(attr_map.pattern_attr)
attr._to_pass_desc_attr(attr_map.replace_attr)
if mapped._operation is not None:
attr_map.operation.CopyFrom(mapped._operation)
方法 _convert_ops_to_pass_desc:
- 遍历替换列表:处理其中的 PassDesc.OpHelper 对象。
- 转换属性映射:将每个属性的映射(_mapped)和自身转换为desc的Attr。
- 处理操作对象:如果映射后的属性带有操作对象(`_operation` 不为 None),就将其复制到 `attr_map.operation` 中。
def SerializeMultiPassDesc(self):
    """Serialize every ``(pattern, replace)`` pair into a MultiPassDesc.

    If Paddle is in dynamic mode, it is switched to static mode for the
    conversion. One PassDesc is built per pair in ``self._pass_pairs``, and
    dynamic mode is restored afterwards. Restoration also happens when a
    conversion step raises, so a failing pass never leaves the process
    stuck in static mode.

    Returns:
        bytes: the serialized ``pass_desc_pb2.MultiPassDesc``.
    """
    switch_static_mode = paddle.in_dynamic_mode()
    if switch_static_mode:
        paddle.enable_static()
    try:
        multi_pass_desc = pass_desc_pb2.MultiPassDesc()
        multi_pass_desc.pass_type = self._pass_type
        # Traverse all pass pairs and convert them to PassDesc data.
        # Here need to add cache in the future.
        for pattern, replace in self._pass_pairs:
            pass_desc = multi_pass_desc.pass_descs.add()
            # Convert ProgramDescs of pattern and replace subgraphs.
            pattern_vars, pattern_ops = self._func_to_program_desc(
                pattern, pass_desc.pattern
            )
            replace_vars, replace_ops = self._func_to_program_desc(
                replace, pass_desc.replace
            )
            self._convert_vars_to_pass_desc(
                pattern_vars, replace_vars, pass_desc
            )
            self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc)
    finally:
        # Restore the caller's execution mode even if conversion failed.
        if switch_static_mode:
            paddle.disable_static()
    return multi_pass_desc.SerializeToString()
主要for循环逻辑:
- for pattern, replace in self._pass_pairs:遍历所有的转换规则对(pattern 和 replace)。
- pass_desc = multi_pass_desc.pass_descs.add():在 multi_pass_desc 中添加一个新的 PassDesc 对象。
- self._func_to_program_desc(pattern, pass_desc.pattern):将 pattern 转换为 ProgramDesc,并保存到 pass_desc.pattern 中。
- self._func_to_program_desc(replace, pass_desc.replace):将 replace 转换为 ProgramDesc,并保存到 pass_desc.replace 中。
- self._convert_vars_to_pass_desc(pattern_vars, replace_vars, pass_desc):将变量转换为 PassDesc。
- self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc):将操作(ops)转换为 PassDesc。
def _to_pass_desc_attr(self, pass_desc_attr):
    """Describe this attribute reference in a PassDesc ``Attr`` message.

    If the owning object is a ``PassDesc.VarHelper``, the reference is
    recorded with role kVariable and the variable's name. Otherwise it is
    treated as an operator and recorded with role kOperator and the
    owner's ``_index``. The attribute name is always written. ``operation``
    and ``element_index`` are written only when they are not None.
    """
    owner = self._obj
    role_type = pass_desc_pb2.PassDesc.RoleType
    if not isinstance(owner, PassDesc.VarHelper):
        pass_desc_attr.role = role_type.kOperator
        pass_desc_attr.op_index = owner._index
    else:
        pass_desc_attr.role = role_type.kVariable
        pass_desc_attr.var_name = owner.name
    pass_desc_attr.name = self._name
    operation_type = self._operation_type
    if operation_type is not None:
        pass_desc_attr.operation = operation_type
    element_index = self._element_index
    if element_index is not None:
        pass_desc_attr.element_index = element_index
`AttrHelper._to_pass_desc_attr`:把 AttrHelper 转换为 PassDesc 中的 Attr 描述。变量属性记录 var_name,算子属性记录 op_index。
def _to_op_desc_attr(self, value, op_desc_attr):
op_desc_attr.name = ""
if isinstance(value, int):
op_desc_attr.type = framework_pb2.AttrType.INT
op_desc_attr.i = value
else:
raise NotImplementedError("Unimplemented transform operation.")
`AttrHelper._to_op_desc_attr`:把字面量条件值转换为 OpDesc 中的 Attr。目前只支持 int,其他类型会抛出 `NotImplementedError`。
def _set_with_condition(self, type, value):
    """Build an AttrCondition for this attribute and store it on ``self``.

    Args:
        type: value assigned to the condition's ``type`` field, e.g. a
            ``pass_desc_pb2.PassDesc.ConditionType`` member such as kEQ.
        value: either another ``PassDesc.AttrHelper``, recorded as
            ``condition_attr``, or a literal, encoded into
            ``condition_value`` by ``self._to_op_desc_attr``.

    The finished condition is stored in ``self._condition``.
    """
    cond = pass_desc_pb2.PassDesc.AttrCondition()
    self._to_pass_desc_attr(cond.attr)
    cond.type = type
    compares_to_attr = isinstance(value, PassDesc.AttrHelper)
    if compares_to_attr:
        value._to_pass_desc_attr(cond.condition_attr)
    else:
        self._to_op_desc_attr(value, cond.condition_value)
    # Attach this attribute's operation, if one has been set.
    if self._operation:
        cond.operation.CopyFrom(self._operation)
    self._condition = cond
`_set_with_condition`:构造 AttrCondition 并保存到 `self._condition`。比较对象既可以是另一个 AttrHelper,也可以是字面量。
def MappedPattern(
    self, var=None, op=None, index=0, name=None, element_index=None
):
    """Map this attribute onto an attribute of a pattern-side op (or var).

    The mapping is resolved lazily. ``self._mapped`` is set to a resolver
    function that is later called with the list of pattern ops.

    Args:
        var: pattern variable to map to. Mapping to a variable is not
            implemented yet; its resolver raises NotImplementedError when
            called.
        op: type name of the pattern op to map to.
        index: which op of type ``op`` to use, in pattern order. Negative
            values count from the end, as in list indexing.
        name: attribute name on the selected pattern op.
        element_index: optional element index forwarded to the resulting
            ``PassDesc.AttrHelper``.

    Raises:
        ValueError: if both ``var`` and ``op`` are given, or neither is.
            The op resolver also raises ValueError when ``index`` is out of
            range for the matching pattern ops.
    """
    # Test "was it given" with `is not None`; the former all([var, op])
    # relied on truthiness, so falsy-but-given values slipped through.
    if var is not None and op is not None:
        raise ValueError("Only mapped one of which var or op.")
    if var is None and op is None:
        raise ValueError("Either var or op must be given to MappedPattern.")

    def mapped_var(pattern_ops):
        raise NotImplementedError(
            "Mapping to variable is not implemented."
        )

    def mapped_op(pattern_ops):
        ops = [o for o in pattern_ops if o._type == op]
        # Reject every out-of-range index, including too-negative ones that
        # would otherwise escape as a bare IndexError.
        if not -len(ops) <= index < len(ops):
            raise ValueError(
                f"Index '{index}' of operator '{op}' is incorrect."
            )
        return PassDesc.AttrHelper(
            ops[index], name, element_index=element_index
        )

    self._mapped = mapped_op if var is None else mapped_var
`MappedPattern`:把 replace 侧的 AttrHelper 映射到 pattern 中第 index 个类型为 op 的算子的属性上。映射到变量(var)尚未实现。
`VarHelper`:继承自 `paddle.static.Variable` 的辅助类。
def Attr(self, name):
    """Return the AttrHelper for attribute ``name``, creating it on first use.

    Helpers are cached in ``self._attrs``, so repeated calls with the same
    name return the same object.
    """
    cached = self._attrs.get(name)
    if cached is not None:
        return cached
    created = PassDesc.AttrHelper(self, name)
    self._attrs[name] = created
    return created
`VarHelper.Attr(name)`:获取属性 name 对应的 AttrHelper。首次访问时创建,并缓存到 `self._attrs` 中。
class OpHelper:
    def _to_readable_code(self, skip_op_callstack=True):
        """Return a one-line, human-readable rendering of this operator.

        The format is ``{out=...} = <type>(inputs={in=...}, {attr=...})``.
        It is built from ``self._outputs``, ``self._type``, ``self._inputs``
        and ``self._attrs``, each dict in its iteration order.

        Args:
            skip_op_callstack: must be a ``bool``; it is otherwise unused here.
        """
        assert isinstance(skip_op_callstack, bool), f"skip_op_callstack parameter's type is error, expect bool, received {type(skip_op_callstack)}"

        def _braced(mapping):
            # Render a dict as "{k1=v1, k2=v2}".
            body = ", ".join(f"{key}={val}" for key, val in mapping.items())
            return "{" + body + "}"

        outputs_part = _braced(self._outputs)
        inputs_part = _braced(self._inputs)
        attrs_part = _braced(self._attrs)
        return f"{outputs_part} = {self._type}(inputs={inputs_part}, {attrs_part})"
`_to_readable_code`:生成算子的可读字符串表示,格式为 `{输出} = 类型(inputs={输入}, {属性})`。
调用 OpHelper 时,它会被设定到特定的 `_desc` 上并执行,随后再设定好它的输出。