Skip to content

Commit

Permalink
update yaml file name and label (#5275)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql authored Sep 14, 2022
1 parent 115f9e2 commit e8efed2
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions docs/dev_guides/api_contributing_guides/new_cpp_op_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

| **内容** | **新增文件位置** |
| -------------- | ------------------------------------------------------------ |
| 算子描述及定义 | 前向算子:[paddle/phi/api/yaml/api.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/api.yaml) <br/>反向算子:[paddle/phi/api/yaml/backward.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/backward.yaml) |
| 算子描述及定义 | 前向算子:[paddle/phi/api/yaml/ops.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/ops.yaml) <br/>反向算子:[paddle/phi/api/yaml/backward.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/backward.yaml) |
| 算子 InferMeta | [paddle/phi/infermeta](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/phi/infermeta) 目录下的相应文件中 |
| 算子 Kernel | [paddle/phi/kernels](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/phi/kernels) 目录下的如下文件:(一般情况)<br/>xxx_kernel.h<br/>xxx_kernel.cc<br/>xxx_grad_kernel.h<br/>xxx_grad_kernel.cc |
| Python API | [python/paddle](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle) 目录下的相应子目录中的 .py 文件,遵循相似功能的 API 放在同一文件夹的原则 |
Expand All @@ -49,12 +49,12 @@

### 3.1 算子 Yaml 文件配置

`paddle/phi/api/yaml/api.yaml``paddle/phi/api/yaml/backward.yaml` 文件中对算子进行描述及定义,在框架编译时会根据 YAML 文件中的配置自动生成 C++ 端的相关代码接口以及内部实现(详见下文 [8.1 Paddle 基于 Yaml 配置自动生成算子代码的逻辑解读](#paddleyaml) 小节的介绍),下面主要以 [paddle.trace](../../api/paddle/trace_cn.html#trace) 为例介绍算子的 Yaml 配置规则:
`paddle/phi/api/yaml/ops.yaml``paddle/phi/api/yaml/backward.yaml` 文件中对算子进行描述及定义,在框架编译时会根据 YAML 文件中的配置自动生成 C++ 端的相关代码接口以及内部实现(详见下文 [8.1 Paddle 基于 Yaml 配置自动生成算子代码的逻辑解读](#paddleyaml) 小节的介绍),下面主要以 [paddle.trace](../../api/paddle/trace_cn.html#trace) 为例介绍算子的 Yaml 配置规则:

[paddle/phi/api/yaml/api.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/api.yaml) 中 trace 相关配置:
[paddle/phi/api/yaml/ops.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/ops.yaml) 中 trace 相关配置:

```yaml
- api : trace
- op : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor(out)
infer_meta :
Expand All @@ -67,7 +67,7 @@
[paddle/phi/api/yaml/backward.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/backward.yaml) 中 trace 相关配置:
```yaml
- backward_api : trace_grad
- backward_op : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
output : Tensor(x_grad)
Expand All @@ -80,7 +80,7 @@
no_need_buffer : x
```
`api.yaml` 和 `backward.yaml` 分别对算子的前向和反向进行配置,首先 `api.yaml` 中前向算子的配置规则如下:
`ops.yaml` 和 `backward.yaml` 分别对算子的前向和反向进行配置,首先 `ops.yaml` 中前向算子的配置规则如下:

<table>
<thead>
Expand Down Expand Up @@ -146,7 +146,7 @@
</tr>
<tr>
<td>optional</td>
<td>指定输入 Tensor 为可选输入,用法可参考 dropout 中 seed_tensor(python/paddle/utils/code_gen/legacy_api.yaml 中)</td>
<td>指定输入 Tensor 为可选输入,用法可参考 dropout 中 seed_tensor(python/paddle/utils/code_gen/legacy_ops.yaml 中)</td>
</tr>
<tr>
<td>inplace</td>
Expand Down Expand Up @@ -181,12 +181,12 @@ b. 如果是实现自定义的 C++ API,需要在'paddle/phi/api/lib/api_custom
</thead>
<tbody>
<tr>
<td>backward_api</td>
<td>backward_op</td>
<td>反向算子名称,一般命名方式为:前向算子名称+'_grad',二阶算子则为前向算子名称+'_double_grad'</td>
</tr>
<tr>
<td>forward</td>
<td>对应前向算子的名称、参数、返回值,需要与 api.yaml 中前向算子配置一致</td>
<td>对应前向算子的名称、参数、返回值,需要与 ops.yaml 中前向算子配置一致</td>
</tr>
<tr>
<td>args</td>
Expand Down Expand Up @@ -960,7 +960,7 @@ PADDLE_ENFORCE_EQ(比较对象 A, 比较对象 B, 错误提示信息)
注册方式为在算子的 YAML 配置中添加`inplace`配置项,格式如:`(x -> out)`,详见[YAML 配置规则](new_cpp_op_cn.html#yaml)。示例:

```yaml
- api : reshape
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
...
Expand All @@ -975,7 +975,7 @@ PADDLE_ENFORCE_EQ(比较对象 A, 比较对象 B, 错误提示信息)
- 如果反向不需要前向的某些输入或输出参数,则无需在 args 中设置。
- 如果有些反向算子需要依赖前向算子的输入或输出变量的的 Shape 或 LoD,但不依赖于变量中 Tensor 的内存 Buffer 数据,且不能根据其他变量推断出该 Shape 和 LoD,则可以通过 `no_need_buffer` 对该变量进行配置,详见[YAML 配置规则](new_cpp_op_cn.html#yaml)。示例:
```yaml
- backward_api : trace_grad
- backward_op : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
output : Tensor(x_grad)
Expand Down Expand Up @@ -1090,7 +1090,7 @@ Paddle 支持动态图和静态图两种模式,在 YAML 配置文件中完成
<center><img src="https://github.com/PaddlePaddle/docs/blob/develop/docs/dev_guides/api_contributing_guides/images/code_gen_by_yaml.png?raw=true" width="700px" ></center>
如前文所述,算子开发时通过 YAML 配置文件对算子进行描述及定义,包括前向 [paddle/phi/api/yaml/api.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/api.yaml) 和反向 [paddle/phi/api/yaml/backward.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/backward.yaml)。动态图和静态图两种模式的执行流程不同,具体如下所示:
如前文所述,算子开发时通过 YAML 配置文件对算子进行描述及定义,包括前向 [paddle/phi/api/yaml/ops.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/ops.yaml) 和反向 [paddle/phi/api/yaml/backward.yaml](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/api/yaml/backward.yaml)。动态图和静态图两种模式的执行流程不同,具体如下所示:
- 动态图中自动生成的代码包括从 Python API 到计算 Kernel 间的各层调用接口实现,从底层往上分别为:
- **C++ API**:一套与 Python API 参数对齐的 C++ 接口(只做逻辑计算,不支持自动微分),内部封装了底层 kernel 的选择和调用等逻辑,供上层灵活使用。
Expand Down

0 comments on commit e8efed2

Please sign in to comment.