
Move fused_attention op to phi [migrate the forward GPU OpKernel] #51743

Merged
54 commits merged into PaddlePaddle:develop from AndSonder:fused_attention_kernel on Apr 6, 2023

Conversation

@AndSonder (Contributor) commented on Mar 16, 2023

PR types

Others

PR changes

Others

Describe

Functional migration of Fluid operators: migrate the GPU kernel (forward pass) of the fused_attention op.

@paddle-bot added the "contributor (External developers)" label on Mar 16, 2023
@AndSonder changed the title from "Move fused_attention op to phi [migrate GPU OpKernel]" to "Move fused_attention op to phi [migrate the forward GPU OpKernel]" on Mar 20, 2023

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

Contributor:

Fused operators should all be written under the phi::fusion namespace.
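
For illustration, a minimal sketch of that layout, assuming a kernel named FusedAttentionKernel with an abbreviated signature (both are placeholders, not the actual PR code):

#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace fusion {

// Hypothetical declaration for illustration only: fused kernels are declared
// inside phi::fusion rather than directly inside phi.
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
                          const DenseTensor &x,
                          /* ...remaining inputs and attributes... */
                          DenseTensor *out);

}  // namespace fusion
}  // namespace phi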

Contributor Author:

done

@@ -0,0 +1,379 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Contributor:

2022 -> 2023

Contributor Author:

done

Comment on lines 366 to 378
  phi::DataType data_type;
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::FLOAT32) {
    data_type = phi::DataType::FLOAT32;
  } else {
    data_type = phi::DataType::FLOAT64;
  }
  kernel->OutputAt(0).SetDataType(data_type);
  kernel->OutputAt(1).SetDataType(data_type);
  kernel->OutputAt(3).SetDataType(data_type);
  kernel->OutputAt(4).SetDataType(data_type);
  kernel->OutputAt(15).SetDataType(data_type);
  kernel->OutputAt(16).SetDataType(data_type);

Contributor:

Suggested change
-  phi::DataType data_type;
-  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
-      kernel_key.dtype() == phi::DataType::FLOAT32) {
-    data_type = phi::DataType::FLOAT32;
-  } else {
-    data_type = phi::DataType::FLOAT64;
-  }
-  kernel->OutputAt(0).SetDataType(data_type);
-  kernel->OutputAt(1).SetDataType(data_type);
-  kernel->OutputAt(3).SetDataType(data_type);
-  kernel->OutputAt(4).SetDataType(data_type);
-  kernel->OutputAt(15).SetDataType(data_type);
-  kernel->OutputAt(16).SetDataType(data_type);
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(15).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
+  }

Contributor Author:

done
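
For context, an override like the one suggested above normally lives in the kernel's PD_REGISTER_KERNEL body. A hedged sketch follows; the kernel name and registered dtype list are assumptions, while the output indices mirror the suggestion:

// Sketch only: FP16-only dtype override inside a PHI GPU kernel registration.
PD_REGISTER_KERNEL(fused_attention,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedAttentionKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  // For float/double inputs the default output dtypes already match, so only
  // the FP16 case needs these outputs pinned to FP32.
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(15).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
  }
}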

@@ -28,8 +28,12 @@
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask

random.seed(42)

Contributor:

Restore the original random-seed settings of this unit test.

Contributor Author:

done

Comment on lines 26 to 27
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/memory/memory.h"

Contributor:

Suggested change
-#include "paddle/fluid/framework/scope_guard.h"
-#include "paddle/fluid/memory/memory.h"
+#include "paddle/phi/core/scope_guard.h"
+#include "paddle/phi/common/memory_utils.h"

Contributor Author:

done


#include <cuda.h>

#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"

Contributor:

Suggested change
-#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

Contributor Author:

done

Contributor:

Headers from fluid must not be included under phi.

Contributor Author:

That include has been removed.

@AndSonder requested a review from From00 on April 2, 2023 03:04
namespace fusion {

template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT

Contributor:

Unify this with the AllReduce in fused_attention_op.cu so that only a single copy exists.

Contributor Author:

The AllReduce in fused_attention_op.cu has been removed.

@@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Contributor:

This file can be deleted; the backward op can likewise be migrated directly under fluid later.

Contributor Author:

done

@@ -0,0 +1,157 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Contributor:

This file can be deleted.

Contributor Author:

done

@@ -28,8 +28,12 @@
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask

random.seed(42)
default_main_program().random_seed = 42
seed = 53

Contributor:

Does changing the random seed back to 42 cause any failures?

Contributor Author:

Changed it back to 42; the unit test passes.

@AndSonder requested a review from From00 on April 3, 2023 11:23

@From00 (Contributor) left a comment:

LGTM

Because of the dependencies on the following distributed-related headers:

#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"

this op is migrated to its functional form directly under the fluid directory for now; once the distributed dependencies have been migrated, the code can then be moved into the PHI directory.

@ZzSean (Contributor) left a comment:

LGTM for CI-OP-Benchmark

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@From00 merged commit a7ec895 into PaddlePaddle:develop on Apr 6, 2023
@AndSonder deleted the fused_attention_kernel branch on April 23, 2024
Labels: contributor (External developers)
5 participants