
[CINN+PIR]Support softmax graph for PIRCompiler #58169

Merged: 6 commits into PaddlePaddle:develop, Oct 21, 2023

Conversation

@Aurelius84 Aurelius84 (Contributor) commented Oct 17, 2023

PR types

Others

PR changes

Others

Description

Pcard-67164

Verified the generation and execution of the softmax subgraph. The mechanism mainly includes:

  • Added an OpMapper mechanism to handle the effect of mutable Attributes on an Operation's Operands and Attributes; we plan to upgrade this to Pass + cinn::dialect later (see the illustrative sketch after this list)
  • Removed cinn/utils/attribute_utils.h and unified it into hlir/framework/pir/utils.h
  • Normalized the new_ir directory and the newir namespace to pir
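
As referenced in the first bullet, here is a minimal hypothetical C++ sketch of what an OpMapper-style registry could look like: a table from op name to a function that rewrites that op's operands and attributes, e.g. folding a mutable axis operand into a static attribute. All names here (OpMapper, MapperContext, the softmax entry, the axis value) are illustrative assumptions, not Paddle's actual API:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

// Toy attribute model: an attribute is either an int or a float.
using Attribute = std::variant<int, float>;

// Mutable state of one Operation while it is being lowered.
struct MapperContext {
  std::vector<std::string> operands;            // input value names
  std::map<std::string, Attribute> attributes;  // op attributes
};

// An OpMapper rewrites the operands/attributes of a single op kind.
class OpMapper {
 public:
  using MapperFn = std::function<void(MapperContext&)>;

  static OpMapper& Instance() {
    static OpMapper instance;
    return instance;
  }

  void Register(const std::string& op_name, MapperFn fn) {
    mappers_[op_name] = std::move(fn);
  }

  void Apply(const std::string& op_name, MapperContext& ctx) const {
    auto it = mappers_.find(op_name);
    if (it != mappers_.end()) it->second(ctx);
  }

 private:
  std::unordered_map<std::string, MapperFn> mappers_;
};

int main() {
  // Hypothetical example: fold a mutable "axis" operand into a static
  // attribute so the compiler sees a fixed axis instead of a runtime input.
  OpMapper::Instance().Register("softmax", [](MapperContext& ctx) {
    if (ctx.operands.size() > 1) {
      ctx.attributes["axis"] = -1;  // assume the axis input held -1
      ctx.operands.pop_back();      // drop the mutable axis operand
    }
  });

  MapperContext ctx{{"x", "axis_tensor"}, {}};
  OpMapper::Instance().Apply("softmax", ctx);
  std::cout << "operands left: " << ctx.operands.size() << "\n";  // prints 1
  return 0;
}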

Attached: the CUDA C code generated for the softmax subgraph:

#include <cstdint>

#define CINN_WITH_CUDA
#include "bfloat16.h"
#include "float16.h"
using cinn::common::bfloat16;
using cinn::common::float16;
using cinn::common::half4;
using cinn::common::half8;
using cinn::common::float8;

#include "cinn_cuda_runtime_source.cuh"

extern "C" {

__global__
void __launch_bounds__(1024) fn_fill_constant_kernel(float* __restrict__ var_72428416)
{
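  // Fill an 8192-element buffer (8 blocks x 1024 threads) with 1.0f.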
  if (((int)blockIdx.x < 8)) {
    if (((int)threadIdx.x < 1024)) {
      var_72428416[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = 1.00000000f;
    };
  };
}

__global__
void __launch_bounds__(1024) fn_reduce_max_subtract_exp_reduce_sum_divide_kernel(const float* __restrict__ var_72428416, float* __restrict__ var_71976912, float* __restrict__ var_71946928, float* __restrict__ var_71269344, float* __restrict__ var_71256320, float* __restrict__ var_71249104)
{
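  // Fused numerically stable softmax over 64 rows of 128 elements:
  // max -> subtract -> exp -> sum -> divide.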
  float _var_71256320_tmp_temp_buffer [ 1 ];
  float _var_71976912_tmp_temp_buffer [ 1 ];
  float* var_71256320_tmp = _var_71256320_tmp_temp_buffer;
  float* var_71976912_tmp = _var_71976912_tmp_temp_buffer;
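  // Step 1: per-row max; each of 64 blocks reduces one 128-element row.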
  if (((int)blockIdx.x < 64)) {
    if (((int)threadIdx.x < 128)) {
      var_71976912_tmp[0] = cinn_block_reduce_max_fp32_internal(var_72428416[((128 * (int)blockIdx.x) + (int)threadIdx.x)]);
    };
  };
  if (((int)blockIdx.x < 64)) {
    if (((int)threadIdx.x < 1)) {
      var_71976912[(int)blockIdx.x] = var_71976912_tmp[0];
    };
  };
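  // Step 2: subtract the row max from every element (numerical stability).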
  if (((int)blockIdx.x < 8)) {
    if (((int)threadIdx.x < 1024)) {
      var_71946928[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = (var_72428416[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] - var_71976912[(((int)threadIdx.x / 128) + (8 * (int)blockIdx.x))]);
    };
  };
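  // Step 3: elementwise exp of the shifted values.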
  if (((int)blockIdx.x < 8)) {
    if (((int)threadIdx.x < 1024)) {
      var_71269344[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = cinn_nvgpu_exp_fp32(var_71946928[((1024 * (int)blockIdx.x) + (int)threadIdx.x)]);
    };
  };
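  // Step 4: per-row sum of the exponentials.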
  if (((int)blockIdx.x < 64)) {
    if (((int)threadIdx.x < 128)) {
      var_71256320_tmp[0] = cinn_block_reduce_sum_fp32_internal(var_71269344[((128 * (int)blockIdx.x) + (int)threadIdx.x)]);
    };
  };
  if (((int)blockIdx.x < 64)) {
    if (((int)threadIdx.x < 1)) {
      var_71256320[(int)blockIdx.x] = var_71256320_tmp[0];
    };
  };
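  // Step 5: divide each element by its row sum to produce the softmax output.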
  if (((int)blockIdx.x < 8)) {
    if (((int)threadIdx.x < 1024)) {
      var_71249104[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = (var_71269344[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] / var_71256320[(((int)threadIdx.x / 128) + (8 * (int)blockIdx.x))]);
    };
  };
}

}
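
For reference, a self-contained C++ sketch (not part of the PR) of the computation the fused kernel performs, assuming the input is viewed as 64 rows of 128 floats; the shapes are inferred from the kernel guards and index arithmetic above:

#include <cmath>
#include <cstdio>
#include <vector>

// Reference softmax over the last dimension, matching the fused kernel:
// per-row max -> subtract -> exp -> per-row sum -> divide.
void softmax_rows(const std::vector<float>& x, std::vector<float>& y,
                  int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    const float* in = &x[r * cols];
    float* out = &y[r * cols];
    float row_max = in[0];
    for (int c = 1; c < cols; ++c) row_max = std::fmax(row_max, in[c]);
    float row_sum = 0.0f;
    for (int c = 0; c < cols; ++c) {
      out[c] = std::exp(in[c] - row_max);  // shifted exp, as in the kernel
      row_sum += out[c];
    }
    for (int c = 0; c < cols; ++c) out[c] /= row_sum;
  }
}

int main() {
  const int rows = 64, cols = 128;          // inferred from the kernel guards
  std::vector<float> x(rows * cols, 1.0f);  // fn_fill_constant writes 1.0f
  std::vector<float> y(rows * cols);
  softmax_rows(x, y, rows, cols);
  std::printf("%f\n", y[0]);  // 1/128 = 0.007812 for a constant input
  return 0;
}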

@paddle-bot paddle-bot bot commented Oct 17, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Aurelius84 Aurelius84 force-pushed the fix_cinn branch 2 times, most recently from 8013938 to 07ac605 on October 18, 2023 03:55
zhangbo9674 previously approved these changes Oct 18, 2023
phlrain previously approved these changes Oct 18, 2023
XiaoguangHu01 previously approved these changes Oct 18, 2023

@XiaoguangHu01 XiaoguangHu01 (Contributor) left a comment:

LGTM

@Aurelius84 Aurelius84 closed this Oct 20, 2023
@Aurelius84 Aurelius84 reopened this Oct 20, 2023
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Oct 20, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation Oct 20, 2023
phlrain previously approved these changes Oct 20, 2023
XiaoguangHu01 previously approved these changes Oct 20, 2023

@XiaoguangHu01 XiaoguangHu01 (Contributor) left a comment:

LGTM

@Aurelius84 Aurelius84 closed this Oct 21, 2023
@Aurelius84 Aurelius84 reopened this Oct 21, 2023
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Oct 21, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation Oct 21, 2023
@Aurelius84 Aurelius84 dismissed stale reviews from XiaoguangHu01 and phlrain via 67784b9 October 21, 2023 05:44
@risemeup1 risemeup1 merged commit 0ad330e into PaddlePaddle:develop Oct 21, 2023
28 checks passed
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Oct 24, 2023
* [CINN+PIR]Support softmax graph for PIRCompiler

fix err

fix conflict

* modify into paddle_test

* fix set_property bug

* fix test

* fix ENVIRONMENT

* fix ci
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 26, 2023
* [CINN+PIR]Support softmax graph for PIRCompiler

fix err

fix conflict

* modify into paddle_test

* fix set_property bug

* fix test

* fix ENVIRONMENT

* fix ci
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* [CINN+PIR]Support softmax graph for PIRCompiler

fix err

fix conflict

* modify into paddle_test

* fix set_property bug

* fix test

* fix ENVIRONMENT

* fix ci