Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Custom Op] New custom operator extension mechanism #30690

Merged
merged 39 commits into from
Feb 10, 2021

Conversation

chenwhql
Copy link
Contributor

@chenwhql chenwhql commented Jan 25, 2021

PR types

New features

PR changes

OPs

Describe

New custom operator extension mechanism

This PR implements a simpler way to extend custom operator for paddle. the main features include:

  1. The programming method is more intuitive, the operation object is Tensor, no need to learn other concepts in paddle;
  2. Does not rely the function or class symbol in the underlying paddle framework.

Original mechanism: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/07_new_op/custom_op.html

Design

image

TODO

  • 目前的PR内容已经很多了,且作为整体功能比较难拆分
  • 为了避免PR过大review困难,以及便于之后并行推进剩余工作,本PR仅合入基础功能,后续会通过一些小PR继续完善TODO项,完善点包括:
    • 添加各种Attribute的输入支持
    • 添加各种必要的DISPATCH宏
    • 完善Tensor API对复数等特殊类型的支持
    • 添加更多的不同类型的自定义Op单测示例
      • 反向仅使用部分前向的Op
      • 需要提升输入类型的Op,比如复数计算
      • 需要注册Double Grad的Op
    • 接入预测
    • 多处细节完善等等

Basic steps

Method1: setup build & install
  1. write custom_op.cc and custom_op.cu
  2. write setup.py
  3. python setup.py build & install
  4. import module & call
Mehtod2: jit compile
  1. write custom_op.cc and custom_op.cu
  2. method = paddle.utils.cpp_extension.load()
  3. call method

Simple Demo

relu_op.cu

#include "paddle/extension.h"

template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
                                         data_t* y,
                                         const int num) {
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
    y[i] = max(x[i], static_cast<data_t>(0.));
  }
}

template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* dy,
                                          const data_t* y,
                                          data_t* dx,
                                          const int num) {
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
    dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
  }
}

std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kGPU);
  out.Reshape(x.shape());

  int numel = x.size();
  int block = 512;
  int grid = (numel + block - 1) / block;
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        relu_cuda_forward_kernel<data_t><<<grid, block>>>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
      }));

  return {out};
}

std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
  grad_x.Reshape(x.shape());

  int numel = out.size();
  int block = 512;
  int grid = (numel + block - 1) / block;
  PD_DISPATCH_FLOATING_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block>>>(
            grad_out.data<data_t>(),
            out.data<data_t>(),
            grad_x.mutable_data<data_t>(x.place()),
            numel);
      }));

  return {grad_x};
}

relu_op.cc

#include <iostream>
#include <vector>

#include "paddle/extension.h"

std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out);

std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
  // TODO(chenweihang): Check Input
  return relu_cuda_forward(x);
}

std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                         const paddle::Tensor& out,
                                         const paddle::Tensor& grad_out) {
  // TODO(chenweihang): Check Input
  return relu_cuda_backward(x, out, grad_out);
}

std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
  return {x_shape};
}

std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
  return {x_dtype};
}

PD_BUILD_OPERATOR("relu2")
  .Inputs({"X"})
  .Outputs({"Out"})
  .SetKernelFn(PD_KERNEL(ReluForward))
  .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
  .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
  .SetBackwardOp("relu2_grad")
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));

setup.py

from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup

setup(
    name='custom_relu_lib',
    ext_modules=[
        CUDAExtension(
            name='custom_relu_op',
            sources=['relu_op.cc', "relu_op.cu"])
    ])

test_setup.py

import os
import numpy as np

import paddle
import paddle.static as static
import custom_relu_op

def test_relu2_dynamic(device, dtype):
    paddle.set_device(device)
    
    x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
    t = paddle.to_tensor(x)
    t.stop_gradient = False

    out = custom_relu_op.relu2(t)
    out.stop_gradient = False
    print(out.numpy())

    out.backward()

def test_relu2_static(device, dtype):
    paddle.enable_static()
    paddle.set_device(device)
    
    with static.scope_guard(static.Scope()):
        with static.program_guard(static.Program()):
            x = static.data(name='X', shape=[None, 8], dtype=dtype)
            x.stop_gradient = False
            out = custom_relu_op.relu2(x)
            static.append_backward(out)
            print(static.default_main_program())
            
            exe = static.Executor()
            exe.run(static.default_startup_program())
            
            x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
            out, = exe.run(
                static.default_main_program(),
                feed={'X': x},
                fetch_list=[out.name])
            print(out)

if __name__ == '__main__':
    # dynamic graph mode
    test_relu2_dynamic("gpu", "float32")
    test_relu2_dynamic("gpu", "float64")

    # static graph mode
    test_relu2_static("gpu", "float32")
    test_relu2_static("gpu", "float64")
    

test_jit.py

import os
import numpy as np

import paddle
import paddle.static as static
from paddle.utils.cpp_extension import load

relu2 = load(
    name='custom_relu_jit_lib',
    sources=['relu_op.cc', 'relu_op.cu'])

# 剩余代码和test_setup.py示例一致,调用接口处替换下即可

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle/fluid/framework/custom_operator.cc Outdated Show resolved Hide resolved
paddle/fluid/framework/custom_operator.cc Outdated Show resolved Hide resolved
paddle/fluid/framework/custom_operator.cc Outdated Show resolved Hide resolved
chenwhql and others added 8 commits February 3, 2021 12:54
* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op
…t more data type (#2)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor
chalsliu
chalsliu previously approved these changes Feb 9, 2021
JiabinYang
JiabinYang previously approved these changes Feb 9, 2021
Copy link
Contributor

@JiabinYang JiabinYang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

#pragma once

// All paddle apis in C++ frontend
#include "paddle/fluid/extension/include/all.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件看起来会引用一些 platform 的内部头文件,在预测发布时,需要评估下这样暴露头文件是否会有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里暴露的头文件是paddle底层的数据类型,包括:

#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

为了支持完备的自定义Op,Op支持的数据类型是需要暴露的

如线下讨论的,这里可以按如下两种方式解决:

  1. 预测接入这几个底层头文件,但这几个底层头文件也需要明确,不能再include其他框架内的上层头文件了(可以保证);
  2. 预测通过宏控制暂时不支持这些数据类型,仅在训练时使用。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenwhql
Copy link
Contributor Author

chenwhql commented Feb 9, 2021

TODO in next PR:提供一个cmake function,用于返回自定义Op依赖的头文件,方便预测发布时能够一条指令引入新增的头文件

Shixiaowei02
Shixiaowei02 previously approved these changes Feb 9, 2021

using Tensor = paddle::Tensor;

#define DISABLE_COPY_AND_ASSIGN(classname) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PD_DISABLE_COPY_AND_ASSIGN

引用后可能与其他项目冲突

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Copy link
Contributor

@lanxianghit lanxianghit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang): Support Complex later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否需要考虑A100上提供的TF32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前paddle底层还没有这个类型的实现,有待后续引入

return {x_dtype};
}

PD_BUILD_OPERATOR("relu2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最好不用“relu2”这样的名字,一方面不推荐用数字后缀命名,另一方面还有个专门操作relu6,这里的数字后缀就有歧义了,不过作为单测内容也可以不严格要求。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx, 下次commit全部替换一下

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Superjomn Superjomn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit f649442 into PaddlePaddle:develop Feb 10, 2021
@chenwhql chenwhql changed the title New custom operator extension mechanism [Custom Op] New custom operator extension mechanism Feb 18, 2021
chenwhql added a commit to chenwhql/Paddle that referenced this pull request Feb 22, 2021
* initial commit: simple demo

* polish copyright format

* add grap op simple demo

* adapt uncertain number of argument

* change trait marco name

* add place & dtype support for add kernel

* add dispath and infershape func

* poish code & add notes

* add dynamic_loader dep for paddle_framework

* add new custom op test dir

* polish impl details

* add unittest for new custom op

* fix failed unittest

* Costum op (#1)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* Remove ShareData from user && Change CustomTensor to Tensor && Support more data type (#2)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* refactor register design & add test

* change op_funtion to op_meta_info

* split op meta info into .h and .cc

* move get methods into friend class

* move OpMetaInfoHelper into framework space

* move CustomTensorUtils into framework space

* change pybind api name

* move PD C API into op meta info

* add register custom op api

* remove inference cmake change

* refactor copy to api && change Reshape to lowercase && support more dtype && add more test (#3)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* support multi dtype

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* fix copy to error

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* polish detail & error message

* polish test details

* Add cast api && Change copy related api to copy_to && add more test (#4)

* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* support multi dtype

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* fix copy to error

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add type cast

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* add more error log

* add more error log

* polish code

* used for test

* remove test comment

* remove test comment

* fix uint8 type error

* fix lost uint8 type error

* add test for coverage

* polish details by reviewer comments

* add prefix for DISABLE_COPY_AND_ASSIGN

Co-authored-by: Jiabin Yang <360788950@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants