
[Perf] Polish UniformRandom And Split it into ScheduleBlock #1357

Merged
merged 4 commits into PaddlePaddle:develop on May 9, 2023

Conversation

Aurelius84
Collaborator

@Aurelius84 Aurelius84 commented Apr 19, 2023

1. Current performance data

[image: current performance data]

2. API-level verification

Before optimization

Total time: fn_xxx + gen_seq + seed_pesudo = 180 + 81 + 12 = 273 us
[image: profiling result before optimization]

After optimization

Total time: 192 us
[image: profiling result after optimization]

3. Follow-up optimization opportunities

3.1 Move initialization of the state variable outside the kernel

Nvidia's documentation explicitly calls out the following performance pitfalls, giving developers clear guidance for writing high-performance kernels:

  • curand_init() is slower than curand() or curand_uniform().
  • curand_init() is also slower with a large offset than with a small one.
  • Saving and restoring the state is much faster than recreating the starting state every time.

The original text: "Calls to curand_init() are slower than calls to curand() or curand_uniform(). Large offsets to curand_init() take more time than smaller offsets. It is much faster to save and restore random generator state than to recalculate the starting state repeatedly."

For the third point above, Nvidia suggests keeping the state in global memory; a sample is shown below:

__global__ void example(curandState *global_state)
{
    curandState local_state;
    local_state = global_state[threadIdx.x];
    for(int i = 0; i < 10000; i++) {
        unsigned int x = curand(&local_state);
        ...
    }
    global_state[threadIdx.x] = local_state;
}

The prerequisite for this pattern is that initialization of the state variable happens outside the kernel, as in the sketch below.
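As a minimal sketch of that prerequisite (illustrative only, not code from this PR; the kernel name setup_states and its parameters are assumptions), the expensive curand_init() call can be hoisted into a one-time setup kernel whose output then feeds the save/load pattern above:

#include <curand_kernel.h>

// One-time setup kernel: pays the curand_init() cost once, outside any hot
// kernel. Each thread seeds its own state with a distinct subsequence so
// the per-thread random streams stay statistically independent.
__global__ void setup_states(curandState *global_state,
                             unsigned long long seed,
                             int n_states)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id < n_states) {
        curand_init(seed, /*subsequence=*/id, /*offset=*/0, &global_state[id]);
    }
}

After one launch of setup_states, hot kernels only load and store global_state as in the Nvidia example above and never call curand_init() again.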

3.2 Use curand_uniform4 to reduce the number of API calls

The device API in this PR generates only a single float/double random number per call. Nvidia also provides device APIs that generate 2 or 4 values per call:

__device__ float4
curand_uniform4 (curandStatePhilox4_32_10_t *state);

__device__ float4
curand_normal4 (curandStatePhilox4_32_10_t *state);
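As a rough sketch of how the 4-wide API could cut generator calls by 4x (illustrative only, not this PR's code; fill_uniform4, its output layout, and the assumption that states holds one pre-initialized entry per thread are all hypothetical):

#include <curand_kernel.h>

// Each thread draws four uniform floats with a single curand_uniform4()
// call and writes four outputs, quartering the number of generator calls
// compared with calling curand_uniform() once per element.
__global__ void fill_uniform4(curandStatePhilox4_32_10_t *states,
                              float *out, int n)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t local = states[tid];  // load state once

    int base = tid * 4;
    if (base + 3 < n) {
        float4 r = curand_uniform4(&local);  // one call, four values in (0, 1]
        out[base + 0] = r.x;
        out[base + 1] = r.y;
        out[base + 2] = r.z;
        out[base + 3] = r.w;
    }

    states[tid] = local;  // save state back for the next launch
}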

Appendix: generated CUDA source code:

#include <cstdint>

#define CINN_WITH_CUDA
#include "float16.h"
using cinn::common::float16;
using cinn::common::half4;
using cinn::common::half8;
using cinn::common::float8;

#include "cinn_cuda_runtime_source.cuh"

extern "C" {

__global__
void __launch_bounds__(1024) fn_fill_constant_0_uniform_random_1_broadcast_to_16_greater_equal_2_cast_5_cast_3_elementwise_mul_4_scale_6_11_kernel(const float16* __restrict__ eager_tmp_0, uint8_t* __restrict__ var_11, float16* __restrict__ var_15)
{
  bool _var_5_temp_buffer [ 1 ];
  bool* var_5 = _var_5_temp_buffer;
  if (((int)blockIdx.x < 24576)) {
    if (((int)threadIdx.x < 1024)) {
    {
      var_5[0] = (cinn_nvgpu_uniform_random_fp32(0) >= 0.100000001f);
      var_15[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = ((float16)1.1104f * (eager_tmp_0[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] * ((float16)(var_5[0]))));
      var_11[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = ((uint8_t)(var_5[0]));
    }
    };
  };
}

}

Related question: Why can't templates be within extern "C" blocks?
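The short answer, sketched below (illustrative declarations, not from this PR): extern "C" gives each function a single unmangled symbol, while a template must emit a distinct mangled symbol per instantiation, so the two are incompatible.

extern "C" {

// OK: a non-template function has exactly one symbol, so C linkage works.
__device__ float uniform_fp32(unsigned int seed);

// Not allowed: a template instantiates one function per T, and each
// instantiation needs its own mangled name, which C linkage cannot express.
// Uncommenting this fails to compile (clang: "templates must have C++ linkage").
// template <typename T>
// __device__ T uniform(unsigned int seed);

}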

@paddle-bot

paddle-bot bot commented Apr 19, 2023

Thanks for your contribution!

fix codegen

refine unittest and add float64 kernel
Member

@zhhsplendid zhhsplendid left a comment


LGTM


@zhhsplendid zhhsplendid merged commit 658615e into PaddlePaddle:develop on May 9, 2023
zhhsplendid pushed a commit to PaddlePaddle/Paddle that referenced this pull request May 9, 2023
[CINN]Adjust Bert unittest loss ground truth, see: PaddlePaddle/CINN#1357
BiynXu added a commit to BiynXu/CINN that referenced this pull request May 11, 2023
Aurelius84 added a commit to Aurelius84/CINN that referenced this pull request May 11, 2023
lanxianghit pushed a commit that referenced this pull request May 12, 2023
jiahy0825 pushed a commit to jiahy0825/CINN that referenced this pull request May 25, 2023
…dle#1357)

This PR requires coordinated changes on both the Paddle and CINN sides for the joint build tests, so it is force-merged into CINN for now; CI will pass once the corresponding Paddle PR is merged.
jiahy0825 pushed a commit to jiahy0825/CINN that referenced this pull request May 25, 2023