diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 952625bcb6e46..7b8dae2dbc872 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -23,45 +23,27 @@ class BatchFCOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Input"), true, - platform::errors::InvalidArgument( - "X(Input) of Batch Fully Connected should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Out(Output) of Batch Fully Connected should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("W"), true, - platform::errors::InvalidArgument( - "W(Input) of Batch Fully Connected should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "BatchFCOp"); + OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "BatchFCOp"); + OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchFCOp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BatchFCOp"); auto input_dims = ctx->GetInputDim("Input"); auto w_dims = ctx->GetInputDim("W"); + auto batchcount = ctx->Attrs().Get("batchcount"); - PADDLE_ENFORCE_EQ(input_dims.size(), 3, + int feature_dim = input_dims[1] / batchcount; + PADDLE_ENFORCE_EQ(feature_dim, w_dims[0], platform::errors::InvalidArgument( - "Input of BatchFCOp should have 3D.")); - PADDLE_ENFORCE_EQ(w_dims.size(), 3, platform::errors::InvalidArgument( - "W of BatchFCOp should have 3D.")); - PADDLE_ENFORCE_EQ( - input_dims[0], w_dims[0], - platform::errors::InvalidArgument( - "Input.dim[0] and W.dim[0] of BatchFCOp should be same.")); - PADDLE_ENFORCE_EQ( - input_dims[2], w_dims[1], - platform::errors::InvalidArgument( - "Input.dim[2] and W.dim[1] of BatchFCOp should be same.")); + "Input.dim[1]/batchcount and W.dim[0] of BatchFCOp " + "should be same.")); auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0], - platform::errors::InvalidArgument( - "Bias.dim[0] should be same as input.dim[0].")); - PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[2], + PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], platform::errors::InvalidArgument( - "Bias.dim[1] should be same as input.dim[2].")); + "Bias.dim[1] should be same as W.dim[1].")); - ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], w_dims[2]}); + ctx->SetOutputDim("Out", {input_dims[0], w_dims[1]}); ctx->ShareLoD("Input", /*->*/ "Out"); } @@ -107,6 +89,7 @@ class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(Tensor) Input tensor of batch_fc_op operator."); AddInput("W", "(Tensor) Input tensor of batch_fc_op operator."); AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator."); + AddAttr("batchcount", "(int64_t) the batchcount"); AddOutput("Out", "Output tensor of batch_fc_op operator."); AddComment(R"DOC( BatchFC Operator. 
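The reworked `InferShape` above replaces the old 3-D (`slot_pairs_num x ins_num x dim`) layout with 2-D tensors sliced along the feature dimension by the new `batchcount` attribute. For reference, a minimal NumPy sketch of the computation the reshaped operator appears to implement, mirroring `np_cal_batchfc` in the updated unit test; the helper name and shape comments are illustrative, not part of the patch:

```python
import numpy as np

# Shapes implied by the new checks:
#   Input: [ins_num, in_feat * batchcount]   W:   [in_feat, out_feat * batchcount]
#   Bias:  [1,       out_feat * batchcount]  Out: [ins_num, out_feat * batchcount]
def batch_fc_reference(inp, w, bias, batchcount):
    ins_num = inp.shape[0]
    in_feat = inp.shape[1] // batchcount    # must equal w.shape[0]
    out_feat = w.shape[1] // batchcount     # bias.shape[1] must equal w.shape[1]
    out = np.zeros((ins_num, w.shape[1]), dtype=inp.dtype)
    for b in range(batchcount):
        # each batch slice is an independent FC: X_block @ W_block
        x_blk = inp[:, b * in_feat:(b + 1) * in_feat]
        w_blk = w[:, b * out_feat:(b + 1) * out_feat]
        out[:, b * out_feat:(b + 1) * out_feat] = x_blk @ w_blk
    return out + bias  # the single bias row is broadcast over all instances
```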
@@ -136,8 +119,6 @@ class BatchFCGradOpMaker : public framework::SingleGradOpMaker { op->SetAttrMap(this->Attrs()); } }; -DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInferer, - "Bias"); } // namespace operators } // namespace paddle @@ -147,8 +128,7 @@ REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker, ops::BatchFCGradOpMaker, ops::BatchFCGradOpMaker); -REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp, - ops::BatchFCGradOpNoNeedBufferVarsInferer); +REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp); REGISTER_OP_CPU_KERNEL( batch_fc, ops::BatchFCKernel, diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu index 414eeef2a6f70..bcdb1912c1d0b 100644 --- a/paddle/fluid/operators/batch_fc_op.cu +++ b/paddle/fluid/operators/batch_fc_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/batch_fc_op.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" @@ -33,75 +34,114 @@ static inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } +// add the same row vector to all matrix rows template -__global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num, - int out_dim, const T* bias) { - CUDA_KERNEL_LOOP(idx, slot_pairs_num * ins_num * out_dim) { - int block_len = ins_num * out_dim; - int slot_index = idx / block_len; - int out_dim_index = (idx % block_len) % out_dim; - T temp = data[idx] + bias[slot_index * out_dim + out_dim_index]; - data[idx] = temp; - } +__global__ void kernel_vec_mat_row_add(const int N, const unsigned int rown, + const unsigned int coln, T* matrix, + const T* vector) { + CUDA_KERNEL_LOOP(i, N) { matrix[i] += vector[i % coln]; } } template -void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num, - int out_dim, const T* bias) { - add_bias_kernel<<>>(data, slot_pairs_num, - ins_num, out_dim, bias); +void vec_mat_row_add(cudaStream_t stream, const unsigned int rown, + const unsigned int coln, T* matrix, const T* vector) { + int N = rown * coln; + kernel_vec_mat_row_add<<>>( + N, rown, coln, matrix, vector); } +// calculate col sum of a mat template -__global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num, - int ins_num, int out_dim, T* db_data) { - CUDA_KERNEL_LOOP(idx, slot_pairs_num * out_dim) { - int row = idx / out_dim; - int col = idx % out_dim; - T temp = static_cast(0); - for (int i = 0; i < ins_num; ++i) { - int select_indx = ((row + 1) * i + 1) * col; - temp += dout_data[select_indx]; +__global__ void kernel_add_col_sum_mat(const unsigned int rown, + const unsigned int coln, const T* matrix, + T* vector) { + CUDA_KERNEL_LOOP(i, coln) { + for (unsigned int j = 0; j < rown; j++) { + // vector[i] += matrix[i * rown + j]; + vector[i] += matrix[j * coln + i]; } - db_data[idx] += temp; } } template -void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num, - int ins_num, int out_dim, T* db_data) { - add_bias_grad_kernel<<>>(dout_data, slot_pairs_num, ins_num, - out_dim, db_data); +void col_sum_mat(cudaStream_t stream, const unsigned int rown, + const unsigned int coln, const T* matrix, T* vector, + const int alpha) { + kernel_add_col_sum_mat<<>>( + rown, coln, matrix, vector); +} + +template +__global__ void kernel_transpose_split_col(const unsigned int row, + const 
unsigned int col, + const unsigned int num_block, + const T* source, T* dest) { + CUDA_KERNEL_LOOP(i, row * col) { + int len = col / num_block; + int dest_row = i / len; + int dest_col = i % len; + int block_row = dest_row % row; + int block_idx = dest_row / row; + int sou_col = block_idx * len + dest_col; + dest[i] = source[block_row * col + sou_col]; + } +} + +template +void transpose_split_col(cudaStream_t stream, const unsigned int rown, + const unsigned int coln, const unsigned int num_block, + const T* source, T* dest) { + kernel_transpose_split_col<<>>(rown, coln, num_block, source, dest); +} + +template +__global__ void kernel_transpose_split_row(const unsigned int row, + const unsigned int col, + const unsigned int num_block, + const T* source, T* dest) { + CUDA_KERNEL_LOOP(i, row * col) { + int len = row / num_block; + int dest_row = i / (col * num_block); + int dest_col = i % (col * num_block); + int block_idx = dest_col / col; + int block_col = dest_col % col; + dest[i] = source[(block_idx * len + dest_row) * col + block_col]; + } +} + +template +void transpose_split_row(cudaStream_t stream, const unsigned int rown, + const unsigned int coln, const unsigned int num_block, + const T* source, T* dest) { + kernel_transpose_split_row<<>>(rown, coln, num_block, source, dest); } template class BatchFCCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - // X.dim = slot_pairs_num * ins_num * in_dim - // W.dim = slot_pairs_num * in_dim * out_dim - // b.dim = slot_pairs_num * out_dim - // output.dim = slot_pairs_num * ins_num * out_dim auto* input = ctx.Input("Input"); auto* w = ctx.Input("W"); auto* bias = ctx.Input("Bias"); auto* output = ctx.Output("Out"); + auto batchcount = ctx.Attr("batchcount"); + auto input_dims = input->dims(); auto w_dims = w->dims(); - auto slot_pairs_num = input_dims[0]; - auto ins_num = input_dims[1]; - auto in_dim = input_dims[2]; - auto out_dim = w_dims[2]; + auto ins_num = input_dims[0]; + auto in_feat = input_dims[1] / batchcount; + auto out_feat = w_dims[1] / batchcount; // get data ptr const T* in_data = input->data(); const T* w_data = w->data(); const T* bias_data = bias->data(); - output->Resize({slot_pairs_num, ins_num, out_dim}); + output->Resize({ins_num, w_dims[1]}); T* out_data = output->mutable_data(ctx.GetPlace()); + // initialize auto out_eigen = framework::EigenVector::Flatten(*output); auto& dev_ctx = ctx.template device_context(); @@ -109,19 +149,39 @@ class BatchFCCUDAKernel : public framework::OpKernel { .eigen_device(); out_eigen.device(place) = out_eigen.constant(static_cast(0)); + math::Transpose trans; + + Tensor out_help; + out_help = + ctx.AllocateTmpTensor({w_dims[1], ins_num}, dev_ctx); + trans(dev_ctx, *output, &out_help, {1, 0}); + + Tensor input_help; + input_help = ctx.AllocateTmpTensor( + {input_dims[1], ins_num}, dev_ctx); + trans(dev_ctx, *input, &input_help, {1, 0}); + + Tensor w_help; + w_help = ctx.AllocateTmpTensor({w_dims[1], w_dims[0]}, + dev_ctx); + trans(dev_ctx, *w, &w_help, {1, 0}); + CBLAS_TRANSPOSE transA = CblasNoTrans; CBLAS_TRANSPOSE transB = CblasNoTrans; T alpha = 1; T beta = 0; - int64_t strideA = ins_num * in_dim; - int64_t strideB = in_dim * out_dim; + int64_t strideA = out_feat * in_feat; + int64_t strideB = in_feat * ins_num; auto blas = math::GetBlas(dev_ctx); - blas.BatchedGEMM(transA, transB, ins_num, out_dim, in_dim, alpha, in_data, - w_data, beta, out_data, slot_pairs_num, strideA, strideB); - 
add_bias(ctx.cuda_device_context().stream(), out_data, slot_pairs_num, - ins_num, out_dim, bias_data); + blas.BatchedGEMM(transA, transB, out_feat, ins_num, in_feat, alpha, + w_help.data(), input_help.data(), beta, + out_help.data(), batchcount, strideA, strideB); + + trans(dev_ctx, out_help, output, {1, 0}); + vec_mat_row_add(ctx.cuda_device_context().stream(), ins_num, w_dims[1], + output->data(), bias->data()); } }; @@ -132,21 +192,23 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel { auto* input = ctx.Input("Input"); auto* w = ctx.Input("W"); auto* dout = ctx.Input(framework::GradVarName("Out")); + auto batchcount = ctx.Attr("batchcount"); auto* dx = ctx.Output(framework::GradVarName("Input")); auto* dw = ctx.Output(framework::GradVarName("W")); auto* db = ctx.Output(framework::GradVarName("Bias")); auto input_dims = input->dims(); + auto dout_dims = dout->dims(); auto w_dims = w->dims(); - auto slot_pairs_num = input_dims[0]; - auto ins_num = input_dims[1]; - auto in_dim = input_dims[2]; - auto out_dim = w_dims[2]; + + auto dout_coln = dout_dims[1]; + auto ins_num = dout_dims[0]; auto& dev_ctx = ctx.template device_context(); auto& place = *ctx.template device_context() .eigen_device(); + auto stream = ctx.cuda_device_context().stream(); // initialize dx->mutable_data(ctx.GetPlace()); auto dx_eigen = framework::EigenVector::Flatten(*dx); @@ -156,32 +218,72 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel { auto dw_eigen = framework::EigenVector::Flatten(*dw); dw_eigen.device(place) = dw_eigen.constant(static_cast(0)); - // get data ptr - const T* x_data = input->data(); - const T* w_data = w->data(); - const T* dout_data = dout->data(); - T* dx_data = dx->data(); - T* dw_data = dw->data(); - db->mutable_data(ctx.GetPlace()); auto db_eigen = framework::EigenVector::Flatten(*db); db_eigen.device(place) = db_eigen.constant(static_cast(0)); - T* db_data = db->data(); - add_bias_grad(ctx.cuda_device_context().stream(), dout_data, - slot_pairs_num, ins_num, out_dim, db_data); + // get bias grad + col_sum_mat(stream, ins_num, dout_coln, dout->data(), db->data(), 1); + + Tensor dout_help; + dout_help = ctx.AllocateTmpTensor( + {dout_dims[0] * batchcount, dout_dims[1] / batchcount}, dev_ctx); + dout_help.mutable_data(ctx.GetPlace()); + + Tensor input_help; + input_help = ctx.AllocateTmpTensor( + {input_dims[0] * batchcount, input_dims[1] / batchcount}, dev_ctx); + input_help.mutable_data(ctx.GetPlace()); + + Tensor w_help; + w_help = ctx.AllocateTmpTensor( + {w_dims[0] * batchcount, w_dims[1] / batchcount}, dev_ctx); + w_help.mutable_data(ctx.GetPlace()); + + Tensor dx_help; + dx_help = ctx.AllocateTmpTensor( + {input_dims[0] * batchcount, input_dims[1] / batchcount}, dev_ctx); + auto dx_help_eigen = framework::EigenVector::Flatten(dx_help); + dx_help_eigen.device(place) = dx_help_eigen.constant(static_cast(0)); + + Tensor dw_help; + dw_help = ctx.AllocateTmpTensor( + {w_dims[0] * batchcount, w_dims[1] / batchcount}, dev_ctx); + auto dw_help_eigen = framework::EigenVector::Flatten(dw_help); + + dx_help_eigen.device(place) = dx_help_eigen.constant(static_cast(0)); + transpose_split_col(stream, dout_dims[0], dout_dims[1], batchcount, + dout->data(), dout_help.data()); + transpose_split_col(stream, input_dims[0], input_dims[1], batchcount, + input->data(), input_help.data()); + transpose_split_col(stream, w_dims[0], w_dims[1], batchcount, w->data(), + w_help.data()); + + // dx = dout_data * y^T auto blas = math::GetBlas(dev_ctx); T alpha = 1; T beta = 0; // dx = 
dout_data * y^T - blas.BatchedGEMM(CblasNoTrans, CblasTrans, ins_num, in_dim, out_dim, alpha, - dout_data, w_data, beta, dx_data, slot_pairs_num, - ins_num * out_dim, out_dim * in_dim); + blas.BatchedGEMM(CblasNoTrans, CblasTrans, dout_dims[0], w_dims[0], + dout_dims[1] / batchcount, alpha, dout_help.data(), + w_help.data(), beta, dx_help.data(), batchcount, + dout_dims[0] * dout_dims[1] / batchcount, + w_dims[0] * dout_dims[1] / batchcount); + + transpose_split_row(stream, dout_dims[0] * batchcount, w_dims[0], + batchcount, dx_help.data(), dx->data()); + // dy = x^T * dout_data - blas.BatchedGEMM(CblasTrans, CblasNoTrans, in_dim, out_dim, ins_num, alpha, - x_data, dout_data, beta, dw_data, slot_pairs_num, - in_dim * ins_num, ins_num * out_dim); + blas.BatchedGEMM(CblasTrans, CblasNoTrans, input_dims[1] / batchcount, + dout_dims[1] / batchcount, input_dims[0], alpha, + input_help.data(), dout_help.data(), beta, + dw_help.data(), batchcount, + input_dims[0] * input_dims[1] / batchcount, + input_dims[0] * dout_dims[1] / batchcount); + + transpose_split_row(stream, w_dims[0] * batchcount, w_dims[1] / batchcount, + batchcount, dw_help.data(), dw->data()); } }; diff --git a/paddle/fluid/operators/cross_norm_hadamard.cu.h b/paddle/fluid/operators/cross_norm_hadamard.cu.h new file mode 100644 index 0000000000000..7d5074d6e5f18 --- /dev/null +++ b/paddle/fluid/operators/cross_norm_hadamard.cu.h @@ -0,0 +1,236 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "cub/cub.cuh" +#include "paddle/fluid/operators/math/math_function.h" + +#define NORM_POS(idx, row, col) (((idx)*block_cols + (col)) * ins_num + (row)) +#define SCALE_MEAN_POS(idx, col) ((idx)*block_cols + (col)) +#define INPUT_POS(idx, row, col) \ + (((embed_dim * (idx)) + (col)) * ins_num + (row)) + +#define INPUT_POS_FF(idx, row, col) \ + (embed_dim * (idx) + (col) + (row)*input_cols) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +static inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} +namespace paddle { +namespace operators { + +template +__global__ void nncross_normforward_multi(int len, int n, int embed_dim, + int ins_num, const T* inputs, + T* norm_output, const T* mean, + const T* scale) { + CUDA_KERNEL_LOOP(i, len) { + int norm_cols = n * (embed_dim * 3 + 1); + int block_cols = embed_dim * 3 + 1; + int input_cols = embed_dim * 2 * n; + + int row = i / norm_cols; + int col_global = i % norm_cols; + int block_idx = col_global / block_cols; + int col = col_global % block_cols; + + if (col < embed_dim) { + norm_output[i] = + (inputs[INPUT_POS_FF(block_idx * 2, row, col)] - mean[col_global]) * + scale[col_global]; + } else if (col < embed_dim * 2) { + col -= embed_dim; + norm_output[i] = (inputs[INPUT_POS_FF(block_idx * 2 + 1, row, col)] - + mean[col_global]) * + scale[col_global]; + } else if (col < embed_dim * 3) { + col -= 2 * embed_dim; + norm_output[i] = (inputs[INPUT_POS_FF(block_idx * 2, row, col)] * + inputs[INPUT_POS_FF(block_idx * 2 + 1, row, col)] - + mean[col_global]) * + scale[col_global]; + } + } +} + +template +__global__ void nncross_normforward_multi_sim(int len, int N, int embed_dim, + int ins_num, const T* inputs, + T* norm_output, const T* mean, + const T* scale) { + CUDA_KERNEL_LOOP(i, len) { + int block_cols = embed_dim * 3 + 1; + int row = i / N; + int col_global = (i % N + 1) * block_cols - 1; + int block_idx = col_global / block_cols; + int input_cols = embed_dim * 2 * N; + + T sum = 0; + for (int j = 0; j < embed_dim; ++j) { + sum += inputs[INPUT_POS_FF(block_idx * 2, row, j)] * + inputs[INPUT_POS_FF(block_idx * 2 + 1, row, j)]; + } + norm_output[row * block_cols * N + col_global] = + (sum - mean[col_global]) * scale[col_global]; + } +} + +template +__global__ void nncross_normbackpropagate_multi(int len, int N, int embed_dim, + int ins_num, const T* inputs, + const T* norm_grad, T* grads, + const T* mean, const T* scale) { + CUDA_KERNEL_LOOP(i, len) { + int row = i % ins_num; + int col_global = i / ins_num; + int a_idx = col_global / embed_dim; + int col = col_global % embed_dim; + int block_cols = embed_dim * 3 + 1; + + // grad 0 + grads[i] += norm_grad[NORM_POS(a_idx / 2, row, col)] * + scale[SCALE_MEAN_POS(a_idx / 2, col)]; + // grad 1 + grads[i] += norm_grad[NORM_POS(a_idx / 2, row, (embed_dim * 2 + col))] * + scale[SCALE_MEAN_POS(a_idx / 2, (embed_dim * 2 + col))] * + inputs[INPUT_POS((1 + (a_idx / 2) * 4 - a_idx), row, col)]; + // grad 2 + grads[i] += norm_grad[NORM_POS(a_idx / 2, row, (embed_dim * 3))] * + scale[SCALE_MEAN_POS(a_idx / 2, (embed_dim * 3))] * + inputs[INPUT_POS((1 + (a_idx / 2) * 4 - a_idx), row, col)]; + } +} + +template +__global__ void kernel_mean_scale(int N, const T* summary, T* mean, T* scale) { + CUDA_KERNEL_LOOP(i, N) { + mean[i] = summary[i + N] / summary[i]; + scale[i] = sqrt(summary[i] / summary[i + 2 * N]); + } +} + 
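`kernel_mean_scale` above derives the per-column mean and scale from the three summary rows (element count, running sum, running squared sum) stored contiguously in `summary`. A hedged NumPy sketch of the same arithmetic, consistent with the reference values computed in the new unit test; the function name is illustrative:

```python
import numpy as np

# summary: [3, cols] -- row 0 = count, row 1 = sum, row 2 = squared sum (sketch only)
def mean_scale_reference(summary):
    count, total, squared = summary[0], summary[1], summary[2]
    mean = total / count               # mean[i]  = summary[i + N] / summary[i]
    scale = np.sqrt(count / squared)   # scale[i] = sqrt(summary[i] / summary[i + 2 * N])
    return mean, scale
```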
+template +__global__ void kernel_normbackwardsummary_x0(int len, int row, T* in_val, + T* sum_grad, const T* means, + const T* scale, + const T squared_sum_epsilon) { + CUDA_KERNEL_LOOP(i, len) { + in_val[i] = in_val[i] / scale[i / row] + means[i / row]; + } +} + +template +__global__ void kernel_normbackwardsummary_plus_mean( + int len, int row, T* in_val, T* sum_grad, const T* means, const T* scale, + const T squared_sum_epsilon) { + CUDA_KERNEL_LOOP(i, len) { + in_val[i] = (in_val[i] - means[i / row]) * (in_val[i] - means[i / row]); + } +} + +template +__global__ void kernel_normbackwardsummary_place_sum(int len, T* buf1, T* buf2, + T* out_val, int row, + T squared_sum_epsilon) { + CUDA_KERNEL_LOOP(i, len) { + out_val[3 * i] = 1; + out_val[3 * i + 1] = buf1[i] / row; + out_val[3 * i + 2] = buf2[i] / row + squared_sum_epsilon; + } +} + +template +void nncross_norm_ff(int N, int embed_dim, int ins_num, const T* inputs, + T* norm_output, const T* summary, T* mean, T* scale, + cudaStream_t stream) { + int norm_cols = N * (embed_dim * 3 + 1); + kernel_mean_scale<<>>( + norm_cols, summary, mean, scale); + nncross_normforward_multi<<>>(norm_cols * ins_num, N, embed_dim, + ins_num, inputs, norm_output, mean, + scale); + nncross_normforward_multi_sim<<>>(N * ins_num, N, embed_dim, ins_num, + inputs, norm_output, mean, scale); +} + +template +void nncross_norm_bp(int N, int embed_dim, int ins_num, const T* inputs, + T* norm_output, const T* norm_grad, T* grads, T* sum_grad, + T* summary, const T* mean, const T* scale, + const T squared_sum_epsilon, cudaStream_t stream, + T* sum_grad_buf, int* sum_offset, T* sum_grad_buf2, + const framework::ExecutionContext& ctx) { + auto& dev_ctx = ctx.template device_context(); + int norm_cols = N * (embed_dim * 3 + 1); + int intput_cols = N * embed_dim; + + kernel_normbackwardsummary_x0<<>>( + norm_cols * ins_num, ins_num, norm_output, sum_grad, mean, scale, + squared_sum_epsilon); + + size_t temp_storage_bytes; + cub::DeviceSegmentedReduce::Sum(NULL, temp_storage_bytes, norm_output, + sum_grad_buf, norm_cols, sum_offset, + sum_offset + 1); + auto temp_cub_buf = memory::Alloc(dev_ctx, temp_storage_bytes); + T* cub_buf = reinterpret_cast(temp_cub_buf->ptr()); + + cub::DeviceSegmentedReduce::Sum(cub_buf, temp_storage_bytes, norm_output, + sum_grad_buf, norm_cols, sum_offset, + sum_offset + 1, stream); + + kernel_normbackwardsummary_plus_mean<<>>( + norm_cols * ins_num, ins_num, norm_output, sum_grad, mean, scale, + squared_sum_epsilon); + cub::DeviceSegmentedReduce::Sum(cub_buf, temp_storage_bytes, norm_output, + sum_grad_buf2, norm_cols, sum_offset, + sum_offset + 1, stream); + kernel_normbackwardsummary_place_sum<<>>( + norm_cols, sum_grad_buf, sum_grad_buf2, sum_grad, ins_num, + squared_sum_epsilon); + + nncross_normbackpropagate_multi<<>>( + intput_cols * ins_num * 2, N, embed_dim, ins_num, inputs, norm_grad, + grads, mean, scale); +} + +template +__global__ void KernelUpdateParam(int C, const T* d_summary, T* summary, + const float decay_rate) { + CUDA_KERNEL_LOOP(i, C) { + summary[i] = summary[i] * decay_rate + d_summary[i]; + } +} + +template +void update_norm_param(cudaStream_t stream, int C, const T* d_summary, + T* summary, const float decay_rate) { + KernelUpdateParam<<>>( + C, d_summary, summary, decay_rate); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.cc b/paddle/fluid/operators/cross_norm_hadamard_op.cc new file mode 100644 index 0000000000000..fc3158fe1e99c --- /dev/null +++ 
b/paddle/fluid/operators/cross_norm_hadamard_op.cc @@ -0,0 +1,168 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cross_norm_hadamard_op.h" +#include + +namespace paddle { +namespace operators { + +class CrossNormHadamardOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "X", "CrossNormHadamard"); + OP_INOUT_CHECK(ctx->HasInput("SummaryInput"), "Input", "SummaryInput", + "CrossNormHadamard"); + + OP_INOUT_CHECK(ctx->HasOutput("CudaMeans"), "Output", "CudaMeans", + "CrossNormHadamard"); + OP_INOUT_CHECK(ctx->HasOutput("CudaScales"), "Output", "CudaScales", + "CrossNormHadamard"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CrossNormHadamard"); + + auto fields_num = ctx->Attrs().Get("fields_num"); + auto embed_dim = ctx->Attrs().Get("embed_dim"); + + auto cols = (embed_dim * 3 + 1) * fields_num; + auto input_dims = ctx->GetInputDim("Input"); + auto summary_dims = ctx->GetInputDim("SummaryInput"); + + PADDLE_ENFORCE_EQ( + cols, summary_dims[1], + platform::errors::InvalidArgument("Input(SummaryInput) should be [%d]," + "but now it is [%d]", + cols, summary_dims[0])); + + PADDLE_ENFORCE_EQ(embed_dim * 2 * fields_num, input_dims[1], + platform::errors::InvalidArgument( + "Input(Input) should be [%d]," + "but now it is [%d]", + embed_dim * 2 * fields_num, input_dims[1])); + ctx->SetOutputDim("Out", {input_dims[0], cols}); + ctx->SetOutputDim("CudaMeans", {1, cols}); + ctx->SetOutputDim("CudaScales", {1, cols}); + + ctx->ShareLoD("Input", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class CrossNormHadamardGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Input"), true, + platform::errors::InvalidArgument("Input should not be null")); + PADDLE_ENFORCE_EQ(ctx->HasInput("SummaryInput"), true, + platform::errors::InvalidArgument( + "Input(SummaryInput) should not be null")); + + ctx->SetOutputDim(framework::GradVarName("Input"), + ctx->GetInputDim("Input")); + ctx->SetOutputDim(framework::GradVarName("SummaryInput"), + ctx->GetInputDim("SummaryInput")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +class CrossNormHadamardOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override 
{ + AddInput("Input", + "(Tensor) Input tensor of cross_norm_hadamard_op operator."); + AddInput("SummaryInput", + "(Tensor) Input tensor of cross_norm_hadamard_op operator."); + AddAttr("fields_num", "(int64_t) the fields_num").SetDefault(2); + AddAttr("embed_dim", "(int64_t) the embed_dim").SetDefault(1); + AddAttr( + "summary_decay_rate", + "(float, default 0.9999999) The decay rate when update the summary") + .SetDefault(0.9999999); + AddAttr("epsilon", "") + .SetDefault(1e-4) + .AddCustomChecker([](const float& epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); + AddOutput("Out", "Output tensor of cross_norm_hadamard_op operator."); + AddOutput("CudaMeans", "Output tensor of cross_norm_hadamard_op operator."); + AddOutput("CudaScales", + "Output tensor of cross_norm_hadamard_op operator."); + + AddComment(R"DOC( +CrossNormHadamard Operator. +Notice: It currently supports GPU device. +This Op exists in contrib, which means that it is not shown to the public. +)DOC"); + } +}; + +template +class CrossNormHadamardGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("cross_norm_hadamard_grad"); + + op->SetInput("Input", this->Input("Input")); + op->SetInput("SummaryInput", this->Input("SummaryInput")); + op->SetInput("Out", this->Output("Out")); + op->SetInput("CudaMeans", this->Output("CudaMeans")); + op->SetInput("CudaScales", this->Output("CudaScales")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + op->SetOutput("SummaryInput", this->Input("SummaryInput")); + op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); + op->SetOutput(framework::GradVarName("SummaryInput"), + this->InputGrad("SummaryInput")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + cross_norm_hadamard, ops::CrossNormHadamardOp, + ops::CrossNormHadamardOpMaker, + ops::CrossNormHadamardGradOpMaker, + ops::CrossNormHadamardGradOpMaker); + +REGISTER_OPERATOR(cross_norm_hadamard_grad, ops::CrossNormHadamardGradOp); + +REGISTER_OP_CPU_KERNEL( + cross_norm_hadamard, + ops::CrossNormHadamardKernel, + ops::CrossNormHadamardKernel); diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.cu b/paddle/fluid/operators/cross_norm_hadamard_op.cu new file mode 100644 index 0000000000000..2ca2ca9e5bcf5 --- /dev/null +++ b/paddle/fluid/operators/cross_norm_hadamard_op.cu @@ -0,0 +1,208 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/cross_norm_hadamard.cu.h" +#include "paddle/fluid/operators/cross_norm_hadamard_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +#include "paddle/fluid/framework/fleet/box_wrapper.h" +namespace paddle { +namespace operators { +using framework::Tensor; +using platform::PADDLE_CUDA_NUM_THREADS; + +template +class CrossNormHadamardCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* summary_input = ctx.Input("SummaryInput"); + auto* Out = ctx.Output("Out"); + auto* cuda_means = ctx.Output("CudaMeans"); + auto* cuda_scales = ctx.Output("CudaScales"); + + auto fields_num = ctx.Attr("fields_num"); + auto embed_dim = ctx.Attr("embed_dim"); + + auto cols = (embed_dim * 3 + 1) * fields_num; + auto input_dims = input->dims(); + auto rows = input_dims[0]; + + auto& place = *ctx.template device_context() + .eigen_device(); + auto& dev_ctx = ctx.template device_context(); + auto stream = ctx.cuda_device_context().stream(); + + Out->Resize({rows, cols}); + T* out_data = Out->mutable_data(ctx.GetPlace()); + auto out_eigen = framework::EigenVector::Flatten(*Out); + out_eigen.device(place) = out_eigen.constant(static_cast(0)); + + cuda_means->Resize({1, cols}); + cuda_scales->Resize({1, cols}); + cuda_means->mutable_data(ctx.GetPlace()); + cuda_scales->mutable_data(ctx.GetPlace()); + + nncross_norm_ff(fields_num, embed_dim, rows, input->data(), + Out->data(), summary_input->data(), + cuda_means->data(), cuda_scales->data(), stream); + } +}; + +template +class CrossNormHadamardOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* summary_input = ctx.Input("SummaryInput"); + auto* out = ctx.Input("Out"); + auto* means = ctx.Input("CudaMeans"); + auto* scales = ctx.Input("CudaScales"); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto fields_num = ctx.Attr("fields_num"); + auto embed_dim = ctx.Attr("embed_dim"); + const float epsilon = ctx.Attr("epsilon"); + const float dr = ctx.Attr("summary_decay_rate"); + + auto* input_grad = ctx.Output(framework::GradVarName("Input")); + auto* summary_grad = + ctx.Output(framework::GradVarName("SummaryInput")); + + auto& dev_ctx = ctx.template device_context(); + auto& place = *ctx.template device_context() + .eigen_device(); + auto stream = ctx.cuda_device_context().stream(); + + // initialize + input_grad->mutable_data(ctx.GetPlace()); + auto input_grad_eigen = framework::EigenVector::Flatten(*input_grad); + input_grad_eigen.device(place) = + input_grad_eigen.constant(static_cast(0)); + + summary_grad->mutable_data(ctx.GetPlace()); + auto summary_grad_eigen = framework::EigenVector::Flatten(*summary_grad); + summary_grad_eigen.device(place) = + summary_grad_eigen.constant(static_cast(0)); + + auto cols = (embed_dim * 3 + 1) * fields_num; + auto input_dims = input->dims(); + auto rows = input_dims[0]; + + // temperary tensor + math::Transpose trans; + + Tensor input_help; + input_help = ctx.AllocateTmpTensor( + 
{fields_num * 2 * embed_dim, rows}, dev_ctx); + trans(dev_ctx, *input, &input_help, {1, 0}); + + Tensor summary_help; + summary_help = ctx.AllocateTmpTensor({cols, 3}, dev_ctx); + trans(dev_ctx, *summary_input, &summary_help, {1, 0}); + + Tensor out_help; + out_help = ctx.AllocateTmpTensor({cols, rows}, dev_ctx); + trans(dev_ctx, *out, &out_help, {1, 0}); + + Tensor out_grad_help; + out_grad_help = + ctx.AllocateTmpTensor({cols, rows}, dev_ctx); + trans(dev_ctx, *out_grad, &out_grad_help, {1, 0}); + + Tensor input_grad_help; + input_grad_help = ctx.AllocateTmpTensor( + {fields_num * 2 * embed_dim, rows}, dev_ctx); + trans(dev_ctx, *input_grad, &input_grad_help, {1, 0}); + + Tensor summary_grad_help; + summary_grad_help = + ctx.AllocateTmpTensor({cols, 3}, dev_ctx); + trans(dev_ctx, *summary_grad, &summary_grad_help, {1, 0}); + + std::vector sum_offset(cols + 1, rows); + for (int i = 2; i < sum_offset.size(); ++i) { + sum_offset[i] += sum_offset[i - 1]; + } + sum_offset[0] = 0; + + auto tmp_array = memory::Alloc(dev_ctx, sum_offset.size() * sizeof(int)); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + tmp_array->ptr(), platform::CPUPlace(), + reinterpret_cast(sum_offset.data()), + sum_offset.size() * sizeof(int), dev_ctx.stream()); + int* g_sum_offset = reinterpret_cast(tmp_array->ptr()); + + auto temp_grad_buf1 = memory::Alloc(dev_ctx, sizeof(T) * cols); + T* g_grad_buf1 = reinterpret_cast(temp_grad_buf1->ptr()); + auto temp_grad_buf2 = memory::Alloc(dev_ctx, sizeof(T) * cols); + T* g_grad_buf2 = reinterpret_cast(temp_grad_buf2->ptr()); + + nncross_norm_bp(fields_num, embed_dim, rows, input_help.data(), + out_help.data(), out_grad_help.data(), + input_grad_help.data(), summary_grad_help.data(), + summary_help.data(), means->data(), + scales->data(), epsilon, stream, g_grad_buf1, + g_sum_offset, g_grad_buf2, ctx); + + trans(dev_ctx, input_grad_help, input_grad, {1, 0}); + trans(dev_ctx, summary_grad_help, summary_grad, {1, 0}); + + int C = 3 * cols; + T* summary_input_data = + ctx.Output("SummaryInput")->mutable_data(ctx.GetPlace()); + + bool need_sync_stats = true; + if (need_sync_stats) { +#if defined(PADDLE_WITH_NCCL) + auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(summary_grad->data()), + reinterpret_cast(summary_grad->data()), C, + platform::ToNCCLDataType(input->type()), ncclSum, comm->comm(), + stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU, and need_sync_stats connot be " + "supported on windows now.")); +#endif + } + update_norm_param(stream, C, summary_grad->data(), summary_input_data, + dr); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using GPUCtx = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(cross_norm_hadamard, + ops::CrossNormHadamardCUDAKernel, + ops::CrossNormHadamardCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(cross_norm_hadamard_grad, + ops::CrossNormHadamardOpCUDAKernel, + ops::CrossNormHadamardOpCUDAKernel); diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.h b/paddle/fluid/operators/cross_norm_hadamard_op.h new file mode 100644 index 0000000000000..28b4b19df2e88 --- /dev/null +++ b/paddle/fluid/operators/cross_norm_hadamard_op.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class CrossNormHadamardKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::Unimplemented("BatchFC only supports GPU now.")); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index caabf77c0c4da..764e2b119bcf5 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -25,6 +25,7 @@ from paddle.fluid.layers import utils from ... import unique_name from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.fluid.param_attr import ParamAttr from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from paddle.fluid.framework import Variable, convert_np_dtype_to_dtype_ from paddle.fluid.layers import slice, reshape @@ -35,7 +36,8 @@ 'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool', 'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat', 'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc', - '_pull_box_extended_sparse', 'fused_seqpool_cvm' + '_pull_box_extended_sparse', 'fused_seqpool_cvm', + 'cross_norm_layer_hadamard' ] @@ -1301,7 +1303,13 @@ def rank_attention(input, return output -def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None): +def batch_fc(input, + param_size, + param_attr, + bias_size, + bias_attr, + batchcount, + act=None): """ **Batch FC layer** This Op can calculate BatchFC. This is similar to matmul op, @@ -1321,16 +1329,16 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None): Examples: .. 
code-block:: python import paddle.fluid as fluid - - input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32") + batchcount = 3 + input = fluid.data(name="input", shape=[16, 2 * batchcount], dtype="float32") out = fluid.contrib.layers.batch_fc(input=input, - param_size=[16, 3, 10], + param_size=[2, 3 * batchcount], param_attr= fluid.ParamAttr(learning_rate=1.0, name="w_0", initializer= fluid.initializer.Xavier(uniform=False)), - bias_size=[16, 10], + bias_size=[1, 3 * batchcount], bias_attr= fluid.ParamAttr(learning_rate=1.0, name="b_0", @@ -1342,10 +1350,10 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None): helper = LayerHelper("batch_fc", **locals()) check_type(input, 'input', (Variable), 'batch_fc') input_shape = input.shape - assert input_shape[0] == param_size[0] - assert input_shape[2] == param_size[1] - assert param_size[2] == bias_size[1] - assert input_shape[0] == bias_size[0] + #assert input_shape[0] == param_size[0] + #assert input_shape[2] == param_size[1] + #assert param_size[2] == bias_size[1] + #assert input_shape[0] == bias_size[0] dtype = helper.input_dtype() check_dtype(dtype, 'input', ['float32', 'float64'], 'batch_fc') @@ -1360,6 +1368,7 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None): inputs={"Input": input, "W": w, "Bias": b}, + attrs={'batchcount': batchcount}, outputs={"Out": pre_act}) return helper.append_activation(pre_act) @@ -1411,8 +1420,15 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'): return outs, outs_extend -def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True, - need_filter=False, show_coeff=0.2, clk_coeff=1.0, threshold=0.96): +def fused_seqpool_cvm(input, + pool_type, + cvm, + pad_value=0.0, + use_cvm=True, + need_filter=False, + show_coeff=0.2, + clk_coeff=1.0, + threshold=0.96): """ **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. :attr:`input`. 
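Note that the patch comments out the old 3-D shape asserts in the Python `batch_fc` wrapper without adding replacements. A hedged sketch of equivalent checks for the new 2-D, `batchcount`-sliced layout, mirroring the C++ `InferShape` above (illustrative only, not part of the patch):

```python
# input: [ins_num, in_feat * batchcount], param: [in_feat, out_feat * batchcount],
# bias:  [1, out_feat * batchcount]
assert input_shape[1] % batchcount == 0
assert input_shape[1] // batchcount == param_size[0]  # feature_dim == W.dim[0]
assert bias_size[1] == param_size[1]                  # Bias.dim[1] == W.dim[1]
```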
@@ -1430,8 +1446,8 @@ def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True, if pool_type.upper() != 'SUM': raise ValueError( - "fused_seqpool_cvm only support SUM pooling now, and your type is: " + - pool_type) + "fused_seqpool_cvm only support SUM pooling now, and your type is: " + + pool_type) check_type(input, 'input', list, 'fused_seqpool_cvm') if isinstance(input, list): @@ -1463,3 +1479,62 @@ def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True, return outs + +def cross_norm_layer_hadamard(input, + fields_num, + embed_dim, + param_dict={}, + summary_decay_rate=0.9999999, + epsilon=1e-04, + name=None): + """ + **Cross Norm Layer Hadamard** + """ + helper = LayerHelper('cross_norm_hadamard', **locals()) + dtype = helper.input_dtype() + + assert fields_num * 2 * embed_dim == input.shape[1] + summary_len = (embed_dim * 3 + 1) * fields_num + param_shape = [3, summary_len] + batch_size_default = 1e4 + batch_sum_default = 0.0 + batch_square_sum_default = 1e4 + + if param_dict and isinstance(param_dict, dict): + batch_size_default = param_dict.get("batch_size", 1e4) + batch_sum_default = param_dict.get("batch_sum", 0.0) + batch_square_sum_default = param_dict.get("batch_square", 1e4) + + np_layer = np.zeros(param_shape) + np_layer[0, :] = batch_size_default + np_layer[1, :] = batch_sum_default + np_layer[2, :] = batch_square_sum_default + summary_input = helper.create_parameter( + attr=ParamAttr( + name=name + '.cross_summary', + initializer=NumpyArrayInitializer(np_layer), + trainable=True), + shape=param_shape, + dtype=input.dtype) + + means = helper.create_variable(dtype=dtype, stop_gradient=True) + scales = helper.create_variable(dtype=dtype, stop_gradient=True) + + out = helper.create_variable(dtype=dtype) + helper.append_op( + type="cross_norm_hadamard", + inputs={"Input": input, + "SummaryInput": summary_input}, + outputs={ + "Out": out, + "CudaMeans": means, + "CudaScales": scales, + "SummaryInput": summary_input + }, + attrs={ + "fields_num": fields_num, + "embed_dim": embed_dim, + "epsilon": epsilon, + "summary_decay_rate": summary_decay_rate + }) + return out diff --git a/python/paddle/fluid/tests/unittests/test_batch_fc_op.py b/python/paddle/fluid/tests/unittests/test_batch_fc_op.py index 56631d8d3b4ad..a060f85fa89d9 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_fc_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_fc_op.py @@ -22,39 +22,44 @@ import paddle.fluid.core as core -def np_cal_batchfc(input, w, bias): - slot_pairs_num, batch_size, in_dim = input.shape - _, _, out_dim = w.shape - res = np.zeros((slot_pairs_num, batch_size, out_dim)) - for slot in range(slot_pairs_num): - res[slot, :] = np.dot(input[slot, :], w[slot, :]) - for slot in range(slot_pairs_num): - for bindx in range(out_dim): - res[slot, :, bindx] += bias[slot, bindx] +def np_cal_batchfc(input, w, bias, batchcount): + ins_num, _ = input.shape + in_feat, w_col = w.shape + out_feat = w_col / batchcount + + res = np.zeros((ins_num, w_col)) + for batch in range(batchcount): + res[:, batch * out_feat:batch * out_feat + out_feat] = np.dot( + input[:, in_feat * batch:in_feat * batch + in_feat], + w[:, out_feat * batch:out_feat * batch + out_feat]) + + for col in range(w_col): + res[:, col] = res[:, col] + bias[0, col] return res class TestBatchFCOp(OpTest): def config(self): - self.slot_pairs_num = 10 - self.batch_size = 5 - self.in_dim = 10 - self.out_dim = 12 + self.batchcount = 10 + self.in_feat = 10 + self.out_feat = 10 + self.ins_num = 2 self.dtype = 
"float64" def setUp(self): self.config() - self.input = np.random.random((self.slot_pairs_num, self.batch_size, - self.in_dim)).astype(self.dtype) - self.w = np.random.random((self.slot_pairs_num, self.in_dim, - self.out_dim)).astype(self.dtype) - self.bias = np.random.random((self.slot_pairs_num, - self.out_dim)).astype(self.dtype) + self.input = np.random.random( + (self.ins_num, self.in_feat * self.batchcount)).astype(self.dtype) + self.w = np.random.random( + (self.in_feat, self.out_feat * self.batchcount)).astype(self.dtype) + self.bias = np.random.random( + (1, self.out_feat * self.batchcount)).astype(self.dtype) self.op_type = "batch_fc" - np_out = np_cal_batchfc(self.input, self.w, self.bias) + np_out = np_cal_batchfc(self.input, self.w, self.bias, self.batchcount) np_out = np_out.astype(self.dtype) self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias} self.outputs = {"Out": np_out} + self.attrs = {"batchcount": self.batchcount} def test_check_output_gpu(self): if core.is_compiled_with_cuda(): @@ -66,41 +71,5 @@ def test_check_grad_gpu(self): core.CUDAPlace(0), ["Bias", "W", "Input"], "Out") -class TestBatchFCOp1(OpTest): - def config(self): - self.slot_pairs_num = 10 - self.batch_size = 5 - self.in_dim = 10 - self.out_dim = 12 - self.dtype = "float64" - - def setUp(self): - self.config() - self.input = np.random.random((self.slot_pairs_num, self.batch_size, - self.in_dim)).astype(self.dtype) - self.w = np.random.random((self.slot_pairs_num, self.in_dim, - self.out_dim)).astype(self.dtype) - self.bias = np.random.random((self.slot_pairs_num, - self.out_dim)).astype(self.dtype) - self.op_type = "batch_fc" - np_out = np_cal_batchfc(self.input, self.w, self.bias) - np_out = np_out.astype(self.dtype) - self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias} - self.outputs = {"Out": np_out} - - def test_check_output_cpu(self): - try: - self.check_output_with_place(place=core.CPUPlace()) - except: - print("do not support cpu test, skip") - - def test_check_grad_cpu(self): - try: - self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"], - "Out") - except: - print("do not support cpu test, skip") - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cross_norm_hadamard_op.py b/python/paddle/fluid/tests/unittests/test_cross_norm_hadamard_op.py new file mode 100644 index 0000000000000..f6c9da5e3dd93 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cross_norm_hadamard_op.py @@ -0,0 +1,142 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + + +class TestCrossNormHadamardOp(OpTest): + """ + test forward and backward + """ + + def setUp(self): + self.op_type = 'cross_norm_hadamard' + + ins_num = 100 + embed_dim = 2 + fields_num = 5 + tp = np.float64 + + input_a = np.random.random([ins_num, embed_dim]).astype(tp) + input_b = np.random.random([ins_num, embed_dim]).astype(tp) + input = np.concatenate((input_a, input_b), axis=1) + input_multi = input_a * input_b + input_sim = np.sum(input_multi, axis=1, keepdims=True) + + np_res = np.concatenate( + (input_a, input_b, input_multi, input_sim), axis=1) + + for _ in range(fields_num - 1): + input_a = np.random.random([ins_num, embed_dim]).astype(tp) + input_b = np.random.random([ins_num, embed_dim]).astype(tp) + input = np.concatenate((input, input_a, input_b), axis=1) + input_multi = input_a * input_b + input_sim = np.sum(input_multi, axis=1, keepdims=True) + + np_res = np.concatenate( + (np_res, input_a, input_b, input_multi, input_sim), axis=1) + + summary_input = np.zeros( + [3, (embed_dim * 3 + 1) * fields_num]).astype(tp) + summary_input[0, :] = 1e4 + summary_input[1, :] = 0.0 + summary_input[2, :] = 1e4 + + np_mean = summary_input[1, :] / summary_input[0, :] + np_scale = np.sqrt(summary_input[0, :] / summary_input[2, :]) + + self.inputs = {"Input": input, "SummaryInput": summary_input} + self.outputs = { + "Out": np_res, + "CudaMeans": np_mean, + "CudaScales": np_scale + } + self.attrs = {"fields_num": fields_num, "embed_dim": embed_dim} + + def test_check_output_gpu(self): + if core.is_compiled_with_cuda(): + self.check_output_with_place(core.CUDAPlace(0)) + + def test_check_grad_gpu(self): + if core.is_compiled_with_cuda(): + self.check_grad_with_place(core.CUDAPlace(0), ["Input"], "Out") + + +class TestRankAttentionOpCpu(OpTest): + def setUp(self): + """ + """ + self.op_type = 'cross_norm_hadamard' + + ins_num = 100 + embed_dim = 2 + fields_num = 5 + tp = np.float64 + + input_a = np.random.random([ins_num, embed_dim]).astype(tp) + input_b = np.random.random([ins_num, embed_dim]).astype(tp) + input = np.concatenate((input_a, input_b), axis=1) + input_multi = input_a * input_b + input_sim = np.sum(input_multi, axis=1, keepdims=True) + + np_res = np.concatenate( + (input_a, input_b, input_multi, input_sim), axis=1) + + for _ in range(fields_num - 1): + input_a = np.random.random([ins_num, embed_dim]).astype(tp) + input_b = np.random.random([ins_num, embed_dim]).astype(tp) + input = np.concatenate((input, input_a, input_b), axis=1) + input_multi = input_a * input_b + input_sim = np.sum(input_multi, axis=1, keepdims=True) + + np_res = np.concatenate( + (np_res, input_a, input_b, input_multi, input_sim), axis=1) + + summary_input = np.zeros( + [3, (embed_dim * 3 + 1) * fields_num]).astype(tp) + summary_input[0, :] = 1e4 + summary_input[1, :] = 0.0 + summary_input[2, :] = 1e4 + + np_mean = summary_input[1, :] / summary_input[0, :] + np_scale = np.sqrt(summary_input[0, :] / summary_input[2, :]) + + self.inputs = {"Input": input, "SummaryInput": summary_input} + self.outputs = { + "Out": np_res, + "CudaMeans": np_mean, + "CudaScales": np_scale + } + self.attrs = {"fields_num": fields_num, "embed_dim": embed_dim} + + def test_check_output_cpu(self): + try: + self.check_output_with_place(place=core.CPUPlace()) + except: + print("do not support cpu test, skip") + + def test_check_grad_cpu(self): + try: + 
self.check_grad_with_place(core.CPUPlace(), ["Input"], "Out") + except: + print("do not support cpu test, skip") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index ad091b7e5fee9..586e24d4d38e4 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3205,15 +3205,15 @@ def test_partial_sum(self): def test_batch_fc(self): with self.static_graph(): - input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32") + input = fluid.data(name="input", shape=[16, 2 * 3], dtype="float32") out = fluid.contrib.layers.batch_fc( input=input, - param_size=[16, 3, 10], + param_size=[2, 3 * 4], param_attr=fluid.ParamAttr( learning_rate=1.0, name="w_0", initializer=fluid.initializer.Xavier(uniform=False)), - bias_size=[16, 10], + bias_size=[1, 3 * 4], bias_attr=fluid.ParamAttr( learning_rate=1.0, name="b_0", @@ -3237,6 +3237,22 @@ def test_rank_attention(self): max_rank=3) return (out) + def test_cross_norm_layer_hadamard(self): + with self.static_graph(): + input = fluid.data(name="input", shape=[None, 2], dtype="float32") + param_attr = { + "batch_size": 1e4, + "batch_sum": 0.0, + "batch_square": 1e4 + } + out = fluid.contrib.layers.cross_norm_layer_hadamard( + input=input, + param_dict=param_attr, + fields_num=1, + embed_dim=1, + name="cross") + return (out) + def test_roi_pool(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph():
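For reference, the forward pass of the new `cross_norm_hadamard` op, as exercised by `TestCrossNormHadamardOp` above, concatenates each field pair's two embeddings, their Hadamard product, and that product's row sum, then normalizes every column with the summary-derived mean and scale. A minimal NumPy sketch under those assumptions (the helper name is illustrative, not part of the patch):

```python
import numpy as np

def cross_norm_hadamard_reference(inp, summary, fields_num, embed_dim):
    # inp:     [ins_num, fields_num * 2 * embed_dim]
    # summary: [3, (3 * embed_dim + 1) * fields_num] -- count / sum / squared-sum rows
    mean = summary[1] / summary[0]
    scale = np.sqrt(summary[0] / summary[2])
    blocks = []
    for f in range(fields_num):
        a = inp[:, (2 * f) * embed_dim:(2 * f + 1) * embed_dim]
        b = inp[:, (2 * f + 1) * embed_dim:(2 * f + 2) * embed_dim]
        prod = a * b                              # Hadamard product of the pair
        sim = prod.sum(axis=1, keepdims=True)     # similarity column
        blocks.append(np.concatenate((a, b, prod, sim), axis=1))
    raw = np.concatenate(blocks, axis=1)  # [ins_num, (3 * embed_dim + 1) * fields_num]
    return (raw - mean) * scale           # per-column normalization
```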