Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fused rms spmd #7830

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
#include "layer_norm_cuda.h" // NOLINT
#include "paddle/extension.h"

#ifdef CUSTOM_OP_WITH_SPMD
#include "paddle/phi/api/ext/spmd_infer.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

#define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor")

static void GetRowsCols(const std::vector<int64_t> &shape,
Expand Down Expand Up @@ -214,14 +219,22 @@ PD_BUILD_OP(fused_rms_norm)
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(RMSLnFwd))
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype))
#ifdef CUSTOM_OP_WITH_SPMD
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd))
#endif
;

PD_BUILD_GRAD_OP(fused_rms_norm)
.Inputs({"x", "scale", "invvar", paddle::Grad("y")})
.Outputs({paddle::Grad("x"), paddle::Grad("scale")})
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(RMSLnBwd))
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape));
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
#ifdef CUSTOM_OP_WITH_SPMD
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
#endif
;


// https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679