Skip to content

新IR Pass 推理单测基础设施搭建

Yuanle Liu edited this page Oct 24, 2023 · 2 revisions

新IR Pass 推理单测基础设施搭建

https://github.com/PaddlePaddle/Paddle/pull/58053

https://github.com/PaddlePaddle/Paddle/pull/58252

这里暂时先不进行基础设施的搭建,首先进行新ir下网络的构建以及conv2d_bn_fuse pass的测试。

static 组网后转成pir

这种方式先使用原有的静态图(static)方式组网,然后使用pir.translate_to_new_ir将原program转化成新ir下的program。

在运行python代码之前,设置环境变量

export FLAGS_enable_new_ir_in_executor=1

测试代码如下所示:

# Build a conv2d + batch_norm network with the legacy static-graph API,
# run it once, translate the program to the new IR (PIR), and apply the
# conv2d_bn_fuse pass, printing the PIR program before and after the pass.
import numpy as np

import paddle
from paddle import pir
from paddle import base
from paddle.base import core

paddle.enable_static()

place = paddle.CPUPlace()

new_scope = paddle.static.Scope()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
    with paddle.static.program_guard(
        main_program, startup_program
    ):
        # NOTE(review): with data_format='NHWC' the trailing 28 of this shape
        # ends up as the channel axis (the conv filter is printed as
        # 32x28x3x3) — confirm this is intended rather than NCHW.
        x = paddle.static.data(
            name='x', shape=[3, 1, 28, 28], dtype='float32'
        )
        conv1_1 = paddle.static.nn.conv2d(
            input=x,
            filter_size=3,
            num_filters=32,
            stride=1,
            padding=1,
            act=None,
            bias_attr=False,
            data_format='NHWC',
        )
        bn = paddle.static.nn.batch_norm(
            input=conv1_1, act=None, data_layout='NHWC'
        )

# Passes to apply, the input fed to 'x', and the variables to fetch.
pass_names = ['conv2d_bn_fuse']
feeds = {"x": np.random.random((3, 1, 28, 28)).astype("float32")}
fetch_list = [bn]

# Run the legacy program once: startup program first, then the main program.
executor = paddle.static.Executor(place)
executor.run(startup_program)
out = executor.run(
    program=main_program,
    feed=feeds,
    fetch_list=fetch_list,
)

# Translate the legacy program description into a PIR program.
newir_program = pir.translate_to_new_ir(main_program.desc)
print('--------print new ir program--------')
print(newir_program)

print("Try to run the new ir program")
# Left commented out on purpose: passing the pir.Program to executor.run
# fails with AttributeError (the pir.Program object has no '_graph').
# out = executor.run(
#     program=newir_program,
#     feed=feeds,
#     fetch_list=[bn],
# )

# Register every requested pass by name on a PIR pass manager.
pm = pir.PassManager()
for name in pass_names:
    pm.add_pass(name)

# here I didn't know how to make pir.program run by the executor
pm.run(newir_program)
print('--------print new ir program fusion--------')
print(newir_program)

输出结果如下所示,代码中分别输出了conv2d_bn_fuse pass运行前后的新ir program:

--------print new ir program--------
{
 (%0) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"conv2d_0.w_0",stop_gradient:[false]} : () -> pd_op.tensor<32x28x3x3xf32>
 (%1) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"batch_norm_0.b_0",stop_gradient:[false]} : () -> pd_op.tensor<32xf32>
 (%2) = "builtin.get_parameter" () {parameter_name:"batch_norm_0.w_1"} : () -> pd_op.tensor<32xf32>
 (%3) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"batch_norm_0.w_0",stop_gradient:[false]} : () -> pd_op.tensor<32xf32>
 (%4) = "builtin.get_parameter" () {parameter_name:"batch_norm_0.w_2"} : () -> pd_op.tensor<32xf32>
 (%5) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persisable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[3,1,28,28],stop_gradient:[true]} : () -> pd_op.tensor<3x1x28x28xf32>
 (%6) = "pd_op.conv2d" (%5, %0) {data_format:"NHWC",dilations:[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:[false],padding_algorithm:"EXPLICIT",paddings:[(Int32)1,(Int32)1],stop_gradient:[false],strides:[(Int32)1,(Int32)1]} : (pd_op.tensor<3x1x28x28xf32>, pd_op.tensor<32x28x3x3xf32>) -> pd_op.tensor<3x1x28x32xf32>
 (%7, %8, %9, %10, %11, %12) = "pd_op.batch_norm_" (%6, %2, %4, %3, %1) {data_layout:"NHWC",epsilon:(Float)1e-05,is_persisable:[false,true,true,false,false,false],is_test:false,momentum:(Float)0.9,stop_gradient:[false,true,true,true,true,true],trainable_statistics:false,use_global_stats:false} : (pd_op.tensor<3x1x28x32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>) -> pd_op.tensor<3x1x28x32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<-1xf32>
}

Try to run the new ir program
--------print new ir program fusion--------
{
 (%0) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"conv2d_0.w_0",stop_gradient:[false]} : () -> pd_op.tensor<32x28x3x3xf32>
 (%1) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"batch_norm_0.b_0",stop_gradient:[false]} : () -> pd_op.tensor<32xf32>
 (%2) = "builtin.get_parameter" () {parameter_name:"batch_norm_0.w_1"} : () -> pd_op.tensor<32xf32>
 (%3) = "builtin.get_parameter" () {is_persisable:[true],parameter_name:"batch_norm_0.w_0",stop_gradient:[false]} : () -> pd_op.tensor<32xf32>
 (%4) = "builtin.get_parameter" () {parameter_name:"batch_norm_0.w_2"} : () -> pd_op.tensor<32xf32>
 (%5) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persisable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[3,1,28,28],stop_gradient:[true]} : () -> pd_op.tensor<3x1x28x28xf32>
 (%6) = "pd_op.conv2d" (%5, %0) {data_format:"NHWC",dilations:[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:[false],padding_algorithm:"EXPLICIT",paddings:[(Int32)1,(Int32)1],stop_gradient:[false],strides:[(Int32)1,(Int32)1]} : (pd_op.tensor<3x1x28x28xf32>, pd_op.tensor<32x28x3x3xf32>) -> pd_op.tensor<3x1x28x32xf32>
 (%7, %8, %9, %10, %11, %12) = "pd_op.batch_norm_" (%6, %2, %4, %3, %1) {data_layout:"NHWC",epsilon:(Float)1e-05,is_persisable:[false,true,true,false,false,false],is_test:false,momentum:(Float)0.9,stop_gradient:[false,true,true,true,true,true],trainable_statistics:false,use_global_stats:false} : (pd_op.tensor<3x1x28x32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>) -> pd_op.tensor<3x1x28x32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<32xf32>, pd_op.tensor<-1xf32>
}

由于pass里匹配的是pd_op.batch_norm op,而转换后的ir program中对应的op为pd_op.batch_norm_(带下划线的inplace版本),因此fuse pass实际上并没有完成匹配,pass运行前后的ir完全相同,没有进行任何修改。

将代码中print("Try to run the new ir program")之后被注释掉的executor.run调用取消注释后,即可尝试用executor直接运行转换得到的新ir program。

print("Try to run the new ir program")
# Uncomment the call below to hand the translated PIR program to the executor.
# out = executor.run(
#     program=newir_program,
#     feed=feeds,
#     fetch_list=[bn],
# )

程序会报错:

Try to run the new ir program
Traceback (most recent call last):
  File "static_to_pir.py", line 55, in <module>
    out = executor.run(
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 1632, in run
    res = self._run_impl(
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 1767, in _run_impl
    if _can_use_interpreter_core(program, self.place):
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 727, in _can_use_interpreter_core
    program._graph, compiler.CompiledProgram
AttributeError: 'paddle.base.libpaddle.pir.Program' object has no attribute '_graph'

直接构建新ir下的网络

测试代码如下所示:

# Build a conv2d + batch_norm network directly under the new IR (PIR) using
# the paddle.nn layer API inside IrGuard, then try to run it with an executor.
import numpy as np

import paddle
from paddle import pir
from paddle import base
from paddle.base import core

paddle.enable_static()

# Device for the executor below; without this definition the script stops
# with a NameError on `place` before it ever reaches executor.run.
place = paddle.CPUPlace()

with paddle.pir_utils.IrGuard():
    x = paddle.static.data(
        name='x', shape=[3, 1, 28, 28], dtype='float32'
    )
    conv2d = paddle.nn.Conv2D(
        in_channels=1,
        out_channels=32,
        kernel_size=3,
        padding=1,
        data_format='NCHW',
        bias_attr=False,
    )
    bn = paddle.nn.BatchNorm2D(num_features=32, data_format='NCHW')
    result1 = conv2d(x)
    result2 = bn(result1)

    # NOTE(review): no startup program is run before this call, so the layer
    # parameters are presumably never initialized in the scope — likely the
    # cause of "Cannot find parameter_0 in scope"; TODO confirm.
    executor = base.Executor(place)
    out = executor.run(
        feed={"x": np.random.random((3, 1, 28, 28)).astype("float32")},
        fetch_list=[result2],
    )

上面的代码在paddle.pir_utils.IrGuard()下,使用动态图api(paddle.nn)进行网络搭建,但是这种方式在executor运行时也会出现问题,报错信息如下:

Traceback (most recent call last):
  File "pir_net.py", line 29, in <module>
    out = executor.run(
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 1622, in run
    res = self._run_pir_impl(
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 1925, in _run_pir_impl
    program, new_exe = self._executor_cache.get_pir_program_and_executor(
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 1026, in get_pir_program_and_executor
    new_exe = _StandaloneExecutor(place, plan, scope)
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 781, in __init__
    self._new_exe = self._create_new_executor()
  File "/root/miniconda3/envs/paddle/lib/python3.8/site-packages/paddle/base/executor.py", line 805, in _create_new_executor
    new_exe = core.StandaloneExecutor(self._place, self._plan, self._scope)
RuntimeError: (NotFound) Cannot find parameter_0 in scope.
  [Hint: var should not be null.] (at /home/Paddle/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc:63)