diff --git a/examples/language_model/bert/static_ipu/README.md b/examples/language_model/bert/static_ipu/README.md new file mode 100644 index 000000000000..cb0185035b7b --- /dev/null +++ b/examples/language_model/bert/static_ipu/README.md @@ -0,0 +1,204 @@ +# Paddle-BERT with Graphcore IPUs + +## Overview + +This project enabled BERT-Base pre-training and SQuAD fine-tuning task using [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) on Graphcore [IPU-POD16](https://www.graphcore.ai/products/mk2/ipu-pod16). + +## File Structure + +| File | Description | +| ------------------------ | ------------------------------------------------------------------ | +| `README.md` | How to run the model. | +| `run_pretrain.py` | The algorithm script to run pretraining tasks (phase1 and phase2). | +| `run_squad.py` | The algorithm script to run SQuAD finetune and validation task. | +| `modeling.py` | The algorithm script to build the Bert-Base model. | +| `dataset_ipu.py` | The algorithm script to load input data in pretraining. | +| `custom_ops/` | The folder contains custom ops that will be used. | +| `run_pretrain.sh` | Test script to run pretrain phase 1. | +| `run_pretrain_phase2.sh` | Test script to run pretrain phase 2. | +| `run_squad.sh` | Test script to run SQuAD finetune. | +| `run_squad_infer.sh` | Test script to run SQuAD validation. | + +## Dataset + +- Pretraining dataset + + Wikipedia dataset is used to do pretraining. Please refer to the Wikipedia dataset generator provided by [Nvidia](https://github.com/NVIDIA/DeepLearningExamples.git) to generate pretraining dataset. + + The sequence length used in pretraining phase1 and phase2 are: 128 and 384. Following steps are provided for dataset generation. + + ```bash + # Here we use a specific commmit, the latest commit should also be fine. + git clone https://github.com/NVIDIA/DeepLearningExamples.git + git checkout 88eb3cff2f03dad85035621d041e23a14345999e + + cd DeepLearningExamples/PyTorch/LanguageModeling/BERT + + # Modified the parameters `--max_seq_length 512` to `--max_seq_length 384` at line 50 and + # `--max_predictions_per_seq 80` to `--max_predictions_per_seq 56` at line 51. + vim data/create_datasets_from_start.sh + + # Build docker image + bash scripts/docker/build.sh + + # Use NV's docker to download and generate hdf5 file. This may requires GPU available. + # You can Remove `--gpus $NV_VISIBLE_DEVICES` to avoid GPU requirements. + bash scripts/docker/launch.sh + + # generate dataset with wiki_only + bash data/create_datasets_from_start.sh wiki_only + ``` + +- SQuAD v1.1 dataset + + SQuAD v1.1 dataset will be downloaded automatically. You don't have to download manually. + + +## Quick Start Guide + +### Prepare Project Environment + +- Create docker image + +```bash +# clone paddle repo +git clone https://github.com/paddlepaddle/Paddle.git +cd Paddle + +# build docker image +docker build -t paddlepaddle/paddle:latest-dev-ipu -f tools/dockerfile/Dockerfile.ipu . +``` + +- Create docker container + +```bash +# clone paddlenlp repo +git clone https://github.com/paddlepaddle/paddlenlp.git +cd paddlenlp/examples/language_model/bert/static_ipu + +# create docker container +# the ipuof configuration file need to be pre-generated and mounted to docker container +# the environment variable IPUOF_CONFIG_PATH should point to the ipuof configuration file +# more information on ipuof configuration is available at https://docs.graphcore.ai/projects/vipu-admin/en/latest/cli_reference.html?highlight=ipuof#ipuof-configuration-file +docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \ +--device=/dev/infiniband/ --ipc=host \ +--name paddle-bert-base \ +-v ${IPUOF_CONFIG_PATH}:/ipu.conf \ +-e IPUOF_CONFIG_PATH=/ipu.conf \ +-v ${PWD}:/workdir \ +-w /home -it paddlepaddle/paddle:latest-dev-ipu bash +``` + +All of later processes are required to be executed in the container. + +- Compile and installation + +```bash +# clone paddle repo +git clone https://github.com/paddlepaddle/Paddle.git +cd Paddle + +mkdir build && cd build + +# run cmake +cmake .. -DWITH_IPU=ON -DWITH_PYTHON=ON -DPY_VERSION=3.7 -DWITH_MKL=ON \ + -DPOPLAR_DIR=/opt/poplar -DPOPART_DIR=/opt/popart -DCMAKE_BUILD_TYPE=Release + +# compile +make paddle_python -j$(nproc) + +# install paddle package +pip install -U python/dist/paddlepaddle-0.0.0-cp37-cp37m-linux_x86_64.whl + +# go to workdir +cd /workdir +``` + +### Execution + +- Run pretraining phase1 (sequence_length = 128) + +```bash +./run_pretrain.sh +``` + +- Run pretraining phase2 (sequence_length = 384) + +```bash +./run_pretrain_phase2.sh +``` + +- Run SQuAD finetune task + +```bash +./run_squad.sh +``` + +- Run SQuAD validation + +```bash +./run_squad_infer.sh +``` + +#### Parameters + +- `task` The type of the NLP model. +- `input_files` The directory of the input data. +- `output_dir` The directory of the trained models. +- `is_training` Training or inference. +- `seq_len` The sequence length. +- `vocab_size` Size of the vocabulary. +- `max_predictions_per_seq` The max number of the masked token each sentence. +- `max_position_embeddings` The length of the input mask. +- `num_hidden_layers` The number of encoder layers. +- `hidden_size` The size of the hidden state of the transformer layers size. +- `ignore_index` The ignore index for the masked position. +- `hidden_dropout_prob` The dropout probability for fully connected layer in embedding and encoder +- `attention_probs_dropout_prob` The dropout probability for attention layer in encoder. +- `learning_rate` The learning rate for training. +- `weight_decay` The weight decay. +- `beta1` The Adam/Lamb beta1 value +- `beta2` The Adam/Lamb beta2 value +- `adam_epsilon` Epsilon for Adam optimizer. +- `max_steps` The max training steps. +- `warmup_steps` The warmup steps used to update learning rate with lr_schedule. +- `scale_loss` The loss scaling. +- `accl1_type` set accl1 type to FLOAT or FLOAT16 +- `accl2_type` set accl2 type to FLOAT or FLOAT16 +- `weight_decay_mode` decay or l2 regularization +- `optimizer_state_offchip` The store location of the optimizer tensors +- `logging_steps` The gap steps of logging. +- `save_steps` Save the paddle model every n steps. +- `epochs` the iteration of the whole dataset. +- `batch_size` total batch size (= batches_per_step \* num_replica \* grad_acc_factor \* micro_batch_size). +- `micro_batch_size` The batch size of the IPU graph. +- `batches_per_step` The number of batches per step with pipelining. +- `seed` The random seed. +- `num_ipus` The number of IPUs. +- `ipu_enable_fp16` Enable FP16 or not. +- `num_replica` The number of the graph replication. +- `enable_grad_acc` Enable gradiant accumulation or not. +- `grad_acc_factor` Update the weights every n batches. +- `available_mem_proportion` The available proportion of memory used by conv or matmul. +- `shuffle` Shuffle Dataset. +- `wandb` Enable logging to Weights and Biases. +- `enable_engine_caching` Enable engine caching or not. +- `enable_load_params` Load paddle params or not. +- `tf_checkpoint` Path to Tensorflow Checkpoint to initialise the model. + +## Result + +| Task | Metric | Result | +| ------ | -------- | ------- | +| Phase1 | MLM Loss | 1.6064 | +| | NSP Loss | 0.0272 | +| | MLM Acc | 0.6689 | +| | NSP Acc | 0.9897 | +| | tput | 11700 | +| Phase2 | MLM Loss | 1.5029 | +| | NSP Loss | 0.02444 | +| | MLM Acc | 0.68555 | +| | NSP Acc | 0.99121 | +| | tput | 3470 | +| SQuAD | EM | 79.9053 | +| | F1 | 87.6396 | diff --git a/examples/language_model/bert/static_ipu/custom_ops/custom_checkpointoutput.cc b/examples/language_model/bert/static_ipu/custom_ops/custom_checkpointoutput.cc new file mode 100644 index 000000000000..edc7eec8fbf3 --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/custom_checkpointoutput.cc @@ -0,0 +1,41 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/extension.h" + +namespace { +std::vector> InferShape(std::vector x_shape) { + return {x_shape}; +} + +std::vector InferDtype(paddle::DataType x_dtype) { + return {x_dtype}; +} + +std::vector OpForward(const paddle::Tensor &x) { return {x}; } + +std::vector OpBackward(const paddle::Tensor &x) { return {x}; } +} + +PD_BUILD_OP(checkpointoutput) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype)) + .SetKernelFn(PD_KERNEL(OpForward)); + +PD_BUILD_GRAD_OP(checkpointoutput) + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(OpBackward)); diff --git a/examples/language_model/bert/static_ipu/custom_ops/custom_detach.cc b/examples/language_model/bert/static_ipu/custom_ops/custom_detach.cc new file mode 100644 index 000000000000..2796fd07d60d --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/custom_detach.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/extension.h" + +namespace { +std::vector> +InferShape(std::vector x_shape) { + return {x_shape}; +} + +std::vector InferDtype(paddle::DataType x_dtype) { + return {x_dtype}; +} + +std::vector OpForward(const paddle::Tensor &x) { return {x}; } + +std::vector OpBackward(const paddle::Tensor &x) { return {x}; } +} + +PD_BUILD_OP(detach) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype)) + .SetKernelFn(PD_KERNEL(OpForward)); + +PD_BUILD_GRAD_OP(detach) + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(OpBackward)); diff --git a/examples/language_model/bert/static_ipu/custom_ops/custom_identity.cc b/examples/language_model/bert/static_ipu/custom_ops/custom_identity.cc new file mode 100644 index 000000000000..1997d0e896c1 --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/custom_identity.cc @@ -0,0 +1,41 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/extension.h" + +namespace { +std::vector> InferShape(std::vector x_shape) { + return {x_shape}; +} + +std::vector InferDtype(paddle::DataType x_dtype) { + return {x_dtype}; +} + +std::vector OpForward(const paddle::Tensor &x) { return {x}; } + +std::vector OpBackward(const paddle::Tensor &x) { return {x}; } +} + +PD_BUILD_OP(identity) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype)) + .SetKernelFn(PD_KERNEL(OpForward)); + +PD_BUILD_GRAD_OP(identity) + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(OpBackward)); diff --git a/examples/language_model/bert/static_ipu/custom_ops/custom_nll_loss.cc b/examples/language_model/bert/static_ipu/custom_ops/custom_nll_loss.cc new file mode 100644 index 000000000000..de874425a8dd --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/custom_nll_loss.cc @@ -0,0 +1,55 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/extension.h" + +namespace { +std::vector> +InferShape(std::vector x_shape, std::vector y_shape, + const int &reduction, const int &ignoreIndex, + const bool &inputIsLogProbability) { + // 0: sum, 1: mean, 2: none + if (reduction == 2) { + return {y_shape}; + } else { + return {{1}}; + } +} + +std::vector InferDtype(paddle::DataType x_dtype, + paddle::DataType y_dtype) { + return {x_dtype}; +} + +std::vector OpForward(const paddle::Tensor &x, + const paddle::Tensor &y) { + return {x}; +} + +std::vector OpBackward(const paddle::Tensor &x) { return {x}; } +} + +PD_BUILD_OP(custom_nll_loss) + .Inputs({"X", "Y"}) + .Outputs({"Out"}) + .Attrs({"reduction: int", "ignoreIndex: int", + "inputIsLogProbability: bool"}) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype)) + .SetKernelFn(PD_KERNEL(OpForward)); + +PD_BUILD_GRAD_OP(custom_nll_loss) + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(OpBackward)); diff --git a/examples/language_model/bert/static_ipu/custom_ops/custom_shape_infer.cc b/examples/language_model/bert/static_ipu/custom_ops/custom_shape_infer.cc new file mode 100644 index 000000000000..74e144d8d7e6 --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/custom_shape_infer.cc @@ -0,0 +1,37 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +auto splitShapeInferenceFun = [](popart::ShapeInferenceContext &ctx) { + auto numOutputs = ctx.getNumOutputs(); + auto type = ctx.inType(0); + auto shape = ctx.inShape(0); + auto axis = ctx.getAttribute("axis"); + auto split = ctx.getAttribute>("split"); + + for (int i = 0; i < numOutputs; i++) { + shape[axis] = split.at(i); + ctx.outInfo(i) = {type, shape}; + } +}; + +#if POPART_VERSION_MAJOR == 2 +#if POPART_VERSION_MINOR == 3 +// for version 2.3, need to register a shape inference function for Split op +static popart::RegisterShapeInferenceFunction + splitRegister11(popart::Onnx::Operators::Split_11, splitShapeInferenceFun); +#endif +#endif \ No newline at end of file diff --git a/examples/language_model/bert/static_ipu/custom_ops/disable_attn_dropout_bwd_pattern.cc b/examples/language_model/bert/static_ipu/custom_ops/disable_attn_dropout_bwd_pattern.cc new file mode 100644 index 000000000000..803ae20c658b --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/disable_attn_dropout_bwd_pattern.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.cc" + +// Tests have found that disabling dropout in the backwards pass of the attention, just before the softmax, +// can improve SQuAD fine-tuning. This pattern finds that op replaces it with an identity op. +class DisableAttnDropoutBwdPattern : public popart::PreAliasPattern { +public: + bool matches(popart::Op *op) const override { + int check_levels = 2; + + if (!op->isConvertibleTo()) { + return false; + } + + // Is dropout enabled? If ratio is 0, we don't need to apply the pattern. + auto dropoutGradOp = dynamic_cast(op); + if (dropoutGradOp->getRatio() == 0.f) { + return false; + } + + // The specific attention DropoutGradOp we want to cull sits between a matmul and a softmax, + // so we'll look through producers and consumers and see if we can find them. + auto grad = op->input->tensor(popart::DropoutGradOp::getGradInIndex()); + + // The MatMulPattern converts the MatMulLhsGradOp to a MatMulOp + // There doesn't seem to be a way to check if a pattern is enabled from inside another pattern. + // The IR holds the patterns object, but it’s inaccessible for checking the status of individual patterns. + // Check both, with the most likely first. + bool hasMatMulProducer = search_producers_for(grad, check_levels) != nullptr; + if (!hasMatMulProducer) { + hasMatMulProducer |= search_producers_for(grad, check_levels) != nullptr; + } + + return hasMatMulProducer && search_consumers_for(grad) != nullptr; + } + + std::vector touches(popart::Op *) const override { return {}; } + + bool apply(popart::Op *op) const override { + if (!op->isConvertibleTo()) { + return false; + } + + auto dropoutGradOp = dynamic_cast(op); + auto identityOp = makeReplacementOpInIr(popart::Onnx::Operators::Identity_1, + dropoutGradOp, + ""); + + auto inputId = dropoutGradOp->inId(popart::DropoutGradOp::getGradInIndex()); + auto outputId = dropoutGradOp->outId(popart::DropoutGradOp::getOutIndex()); + dropoutGradOp->disconnectAllInputs(); + dropoutGradOp->disconnectAllOutputs(); + dropoutGradOp->getGraph().eraseOp(dropoutGradOp->id); + + identityOp->connectInTensor(popart::IdentityOp::getInIndex(), inputId); + identityOp->connectOutTensor(popart::IdentityOp::getOutIndex(), outputId); + identityOp->setup(); + + return true; + } +}; + + +static popart::PatternCreator disableAttnDropoutBwdPatternCreator("DisableAttnDropoutBwdPattern", false); diff --git a/examples/language_model/bert/static_ipu/custom_ops/tied_gather.cc b/examples/language_model/bert/static_ipu/custom_ops/tied_gather.cc new file mode 100644 index 000000000000..2350ffd243c4 --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/tied_gather.cc @@ -0,0 +1,181 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace CustomOperators { + const popart::OperatorIdentifier TiedGather = {"ai.graphcore", "TiedGather", 1}; +} // namespace CustomOperators + +class TiedGatherOp; +class TiedGatherGradOp; + +class TiedGatherGradOp : public popart::GatherGradOp { +public: + TiedGatherGradOp(const popart::GatherOp &op, int64_t axis_) + : popart::GatherGradOp(op, axis_), + fwd_op(&op) {} + const popart::GatherOp *fwd_op; +}; + +class TiedGatherOp : public popart::GatherOp { +public: + TiedGatherOp(int64_t axis_, const popart::Op::Settings &settings_) + : popart::GatherOp(CustomOperators::TiedGather, axis_, settings_) {} + bool check_indices = true; + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + std::vector> getGradOps() { + std::vector> result; + result.push_back(std::make_unique(*this, getAxis())); + result[0]->pruneable = false; + return result; + } +}; + +class TiedGatherOpx : public popart::popx::Opx { +public: + TiedGatherOpx(popart::Op *op, popart::popx::Devicex *devicex) : popart::popx::Opx(op, devicex) { + verifyOp(op, CustomOperators::TiedGather); + // We always want this to layout its inputs + inputCreatorPriority = std::numeric_limits::max(); + } + + bool createsEquiv(int, const popart::popx::Opx *, int) const final { return false; } + + std::set mustExistBeforeCreate(int) const final { return {}; } + + popart::popx::InputCreatorType getInputCreatorType(int index0) const final { + return index0 == TiedGatherOp::dataInIndex() ? popart::popx::InputCreatorType::CanCreate + : popart::popx::Opx::getInputCreatorType(index0); + } + + poplar::Tensor createInput(popart::InIndex index, + const poplar::DebugNameAndId &dnai) const final { + popart::logging::debug("TiedGather asked to create index {}: name {}", index, dnai); + if (index != TiedGatherOp::dataInIndex()) { + throw popart::error("CustomOps Error: GatherOpx::createInput Cannot create input {}", index); + } + + auto inputInfo = inInfo(TiedGatherOp::indicesInIndex()); + auto weightInfo = inInfo(TiedGatherOp::dataInIndex()); + + unsigned inputSize = inputInfo.nelms(); + unsigned inChannels = weightInfo.dim(getOp().getAxis()); + unsigned outChannels = weightInfo.nelms() / inChannels; + + std::vector lhsShape = {inputSize, inChannels}; + std::vector rhsShape = {inChannels, outChannels}; + + return poplin::createMatMulInputRHS(graph(), + popart::popx::popType(weightInfo), + lhsShape, + rhsShape, + dnai, + {}, + &dv_p->matmulCache); + } + + // Identical to popart::opx::GatherOpx::grow however: + // 1) uses popops::gather instead of popops::multislice + // 2) range checks the indices and masks those out of range + void grow(poplar::program::Sequence &prog) const final { + const auto indicesShape = inShape(TiedGatherOp::indicesInIndex()); + const auto outputShape = + popart::vXtoY(outShape(TiedGatherOp::outIndex())); + + auto op = getOp(); + unsigned axis = op.getAxis(); + auto indices = getInTensor(TiedGatherOp::indicesInIndex()); + auto data = getInTensor(TiedGatherOp::dataInIndex()); + + // If there are no indices, return an empty tensor of the appropriate + // shape + if (indices.numElements() == 0) { + auto result = graph().addVariable( + data.elementType(), outputShape, debugContext("result")); + + setOutTensor(TiedGatherOp::outIndex(), result); + } else { + // Flatten the scalar indices. + auto offsets = indices.flatten(); + // reinterpret the indices as unsigned int. This assumes negative indices. + // are impossible. + offsets = offsets.reinterpret(poplar::UNSIGNED_INT); + + // Place the gather axis at the front. + data = data.dimShufflePartial({0}, {axis}); + // Store the shape for later. + auto tmp_shape = data.shape(); + // Flatten the other dimensions. + data = data.flatten(1, data.rank()); + + // Change (2) + poplar::Tensor mask; + if (op.check_indices) { + auto gather_size = data.shape()[0]; + mask = popops::lt(graph(), offsets, static_cast(gather_size), prog, debugContext("mask + tiedGatherOpxCreator(CustomOperators::TiedGather); diff --git a/examples/language_model/bert/static_ipu/custom_ops/tied_gather_pattern.cc b/examples/language_model/bert/static_ipu/custom_ops/tied_gather_pattern.cc new file mode 100644 index 000000000000..ddbe4bd151a0 --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/tied_gather_pattern.cc @@ -0,0 +1,504 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "tied_gather.cc" +#include "utils.cc" + +using SerialiseSettings = popart::MatMulBaseOp::SerialiseSettings; + +// This pattern matches for graphs of the shape. +// +// Weight +// / \ +// Transpose MatMul +// | +// Indices --Gather +// +// And performs the following transformations: +// 1) Disable FullyConnectedPass on MatMul +// 2) Add Detach between the Gather and the Weight so no SGD ops are created (they will be added later by TiedGatherAccumulatePattern) +// 3) Replace Gather with TiedGather +// Resulting in: +// Weight +// / \ +// Transpose MatMul +// | +// Detach +// | +// Indices --TiedGather +// +// Conditionally, if MatMul is annotated with serialisation it will: +// 4) Replace Gather with N x TiedGather to match the serialisation on the MatMul +// Resulting in: +// For serialisation factor: 2 +// +// Weight +// / \ +// Transpose MatMul +// | +// Indices Detach +// | | | | +// | | | Slice--\ +// | Sub -|------TiedGather +// | | | +// | Slice--\ | +// Sub ---------TiedGather | +// \ | +// Add +// +namespace { +bool produced_by_transpose(popart::Tensor *t) { + return t->hasProducer() && t->getProducer()->isConvertibleTo(); +} +} + +class TiedGatherPattern : public popart::PreAliasPattern { + mutable std::map tied_op_map; +public: + bool matches(popart::Op *op) const override { + auto &ir = op->getIr(); + // Only run in the fwd pass + if (op->getIr().hasConstructedBackwards()) { + return false; + } + if (op->getIr().isTraining() && !op->getIr().getSessionOptions().enableGradientAccumulation) { + return false; + } + if (op->isConvertibleTo() && !op->isConvertibleTo()) { + if (produced_by_transpose(op->input->tensor(popart::GatherOp::dataInIndex()))) { + auto matmul = weight_consumed_by(op->input->tensor(popart::GatherOp::dataInIndex())); + if (matmul) { + tied_op_map.insert({op, matmul}); + return true; + } + } + } + return false; + } + + std::vector touches(popart::Op *) const override { return {}; } + + bool apply(popart::Op *op) const override { + auto &graph = op->getGraph(); + + auto gather = dynamic_cast(op); + auto matmul = tied_op_map[gather]; + + // (1) + matmul->setUseFullyConnectedPass(false); + + auto axis = gather->getAxis(); + auto serialisation = matmul->getSerialiseSettings(); + + auto data = gather->input->tensor(popart::GatherOp::dataInIndex()); + auto indices = gather->input->tensor(popart::GatherOp::indicesInIndex()); + auto out = gather->output->tensor(popart::GatherOp::outIndex()); + + // Disconnect "out" so it can be connected to the replacing ops. + gather->disconnectAllOutputs(); + + // (2) + auto detach_up = std::make_unique( + popart::Onnx::CustomOperators::Detach_1, + popart::Op::Settings(graph, "TiedGatherDetach") + ); + auto detach = detach_up.get(); + transferBaseProperties(gather, detach); + graph.moveIntoGraph(std::move(detach_up)); + detach->connectInTensor(0, data->id); + auto detached_data_id = data->id + "/detached"; + detach->createAndConnectOutTensor(0, detached_data_id); + detach->setup(); + data = graph.getTensors().get(detached_data_id); + + std::string name = gather->name(); + if (name.empty()) { + name = std::to_string(gather->id); + } + + auto replace_with_tied_gather = [&](popart::TensorId dict, popart::TensorId ind, int64_t i, const std::string &debugContext) { + auto tied_gather_up = std::make_unique( + axis, + popart::Op::Settings(graph, debugContext)); + auto tied_gather = tied_gather_up.get(); + transferBaseProperties(gather, tied_gather); + graph.moveIntoGraph(std::move(tied_gather_up)); + + tied_gather->connectInTensor(TiedGatherOp::dataInIndex(), dict); + tied_gather->connectInTensor(TiedGatherOp::indicesInIndex(), ind); + + auto out_id = out->id; + if (i >= 0) { + out_id = debugContext + ":0"; + tied_gather->createAndConnectOutTensor(TiedGatherOp::outIndex(), out_id); + } else { + tied_gather->connectOutTensor(TiedGatherOp::outIndex(), out_id); + } + + graph.topoCons->transfer(gather, tied_gather); + + tied_gather->setup(); + + return out_id; + }; + + if (serialisation.factor <= 1 || serialisation.mode == SerialiseSettings::Mode::None) { + // (3) + replace_with_tied_gather(data->id, indices->id, -1, name); + } else { + // (4) + if (serialisation.mode != SerialiseSettings::Mode::OutputChannels) { + throw popart::error("CustomOps Error: Tied Gather Pattern only supports Serialisation::Mode::OutputChannels"); + } + + auto slice_op = [&](int64_t starts, int64_t ends, const std::string &debugContext) { + auto slice_up = std::make_unique( + popart::Onnx::AiOnnx::OpSet9::Slice, + std::vector({starts}), + std::vector({ends}), + std::vector({axis}), + popart::Op::Settings(graph, debugContext + "/slice")); + auto slice = slice_up.get(); + transferBaseProperties(gather, slice); + graph.moveIntoGraph(std::move(slice_up)); + slice->connectInTensor(popart::SliceOp::getInIndex(), data->id); + auto data_slice = debugContext + "/slice:0"; + slice->createAndConnectOutTensor(popart::SliceOp::getOutIndex(), data_slice); + slice->setup(); + return data_slice; + }; + + auto subtract_with_constant = [&](popart::Tensor *a, int64_t c, const std::string &debugContext) { + auto sub_up = std::make_unique( + popart::Onnx::Operators::Sub_7, + popart::Op::Settings(graph, debugContext + "/sub")); + auto sub = sub_up.get(); + transferBaseProperties(gather, sub); + graph.moveIntoGraph(std::move(sub_up)); + sub->connectInTensor(popart::SubtractOp::getArg0InIndex(), a->id); + // Create constant to subtract from + static unsigned i = 0; + auto sub_const_id = a->id + "_sub_const_" + std::to_string(i++); + popart::TensorInfo subInfo(a->info.dataType(), {1}); + std::vector d(1, c); + graph.getTensors().addConstInit(sub_const_id, subInfo, d.data()); + sub->connectInTensor(popart::SubtractOp::getArg1InIndex(), sub_const_id); + auto indices_sub = debugContext + "/sub:0"; + sub->createAndConnectOutTensor(popart::SubtractOp::getOutIndex(), indices_sub); + sub->setup(); + return indices_sub; + }; + + auto add_op = [&](popart::TensorId a, popart::TensorId b, popart::TensorId out, const std::string &debugContext) { + auto add_up = std::make_unique( + popart::Onnx::Operators::Add_6, + popart::Op::Settings(graph, debugContext + "/add")); + auto add = add_up.get(); + transferBaseProperties(gather, add); + graph.moveIntoGraph(std::move(add_up)); + add->connectInTensor(popart::AddOp::getArg0InIndex(), a); + add->connectInTensor(popart::AddOp::getArg1InIndex(), b); + if (graph.getTensors().contains(out)) { + add->connectOutTensor(popart::AddOp::getOutIndex(), out); + } else { + add->createAndConnectOutTensor(popart::AddOp::getOutIndex(), out); + } + add->setup(); + return out; + }; + + popart::TensorId tmp_id; + for (int64_t i = 0; i < serialisation.factor; i++) { + int64_t slice_size = data->info.dim(axis) / serialisation.factor; + auto serial_name = name + "/" + std::to_string(i); + // Slice the Dictionary + auto data_slice = slice_op(i * slice_size, (i + 1) * slice_size, serial_name); + // Subtract the indicies + auto indices_sub = subtract_with_constant(indices, i * slice_size, serial_name); + // Add the tied gather to the graph + auto next_id = replace_with_tied_gather(data_slice, indices_sub, i, serial_name); + + // Add the results + if (i == 0) { + tmp_id = next_id; + } else { + auto out_id = out->id; + if (i < serialisation.factor - 1) { + out_id += "_tmp" + std::to_string(i); + } + tmp_id = add_op(tmp_id, next_id, out_id, serial_name); + + // Tie the add to happen directly after the gather + graph.topoCons->insert( + graph.getTensors().get(next_id)->getProducer(), + graph.getTensors().get(tmp_id)->getProducer(), + true); + } + } + } + + gather->disconnectAllInputs(); + graph.eraseOp(gather->id); + + return true; + } +}; + +// This pattern matches for graphs of the shape. +// +// Weight +// | \ +// TiedGatherGrad MatMul +// | +// Accl - Accumulate +// +// And will perform the following transformation +// 1) Replace TiedGatherGrad with SparseAccumulate +// +// Resulting in: +// +// Weight +// | \ +// | MatMul +// | | +// | Accl - Accumulate +// | | | +// SparseAccumulate - Optimizer +// +// (--> is a topocon) + +class TiedGatherAccumulatePattern : public popart::PreAliasPattern { +public: + bool matches(popart::Op *op) const override { + // Only works with gradient accumulation + if (!op->getIr().getSessionOptions().enableGradientAccumulation) { + return false; + } + // Only run after the optimizers have been created + if (!op->getIr().hasDecomposedOptimizers()) { + return false; + } + return op->isConvertibleTo(); + } + + std::vector touches(popart::Op *) const override { return {}; } + + bool apply(popart::Op *op) const override { + auto gather_grad = dynamic_cast(op); + auto gather = gather_grad->fwd_op; + auto root_weight = get_variable(gather->input->tensor(popart::GatherOp::dataInIndex())); + + auto gather_ops = find_all_consumers(root_weight); + + auto &ir = op->getIr(); + + // Get all the Accumulate ops in the normal context + std::vector accumulate_ops; + + auto update_ops = find_all_consumers(root_weight); + if (update_ops.size() < 1) { + // OptimizerDecomposePattern has not run. + throw popart::error("CustomOps Error: Could not find update ops for weight {}", root_weight->id); + } + + for (size_t i = 0; i < update_ops.size(); i++) { + auto var_update = update_ops[i]; + + auto accum = var_update->inTensor(popart::VarUpdateWithUpdaterOp::getUpdaterInIndex()); + // Accumulate Ops in the normal fragment are Gradient Accumulation. + auto accl_op = search_producers_for(accum, 10); + + if (accl_op) { + auto exists = std::find_if(accumulate_ops.begin(), accumulate_ops.end(), [&accl_op](popart::Op* op){ return op->id == accl_op->id; }); + if (exists == accumulate_ops.end()) { + accumulate_ops.push_back(accl_op); + } + } else { + popart::logging::info("CustomOps Warning: Could not find outer AccumulateOp gradient accumulation via accumulator {}.", accum->id); + } + } + + if (accumulate_ops.size() != gather_ops.size()) { + throw popart::error("CustomOps Error: The number of gather ops ({}) does not match the number of accumulate ops ({}).", gather_ops.size(), accumulate_ops.size()); + } + + // Match up gather serial index to Accumulator's matmul index. + // TODO: Find a more robust way than sorting input ids + std::sort(accumulate_ops.begin(), accumulate_ops.end(), + [](const popart::Op *l, const popart::Op *r) { + return l->input->tensor(popart::AccumulateOp::getVarToUpdateInIndex())->id.compare( + r->input->tensor(popart::AccumulateOp::getVarToUpdateInIndex())->id) < 0; + }); + std::sort(gather_ops.begin(), gather_ops.end(), + [](const popart::Op *l, const popart::Op *r) { + return l->name().compare(r->name()) < 0; + }); + + auto itr = std::find(gather_ops.begin(), gather_ops.end(), gather); + if (itr == gather_ops.end()) { + throw popart::error("CustomOps Error: Could not find {} in the consumers of {}.", gather->name(), root_weight->id); + } + + unsigned serial_index = std::distance(gather_ops.begin(), itr); + + auto dense_accl = accumulate_ops[serial_index]; + + auto accl_id = dense_accl->inId(popart::AccumulateOp::getVarToUpdateInIndex()); + auto weight_id = gather->inId(popart::GatherOp::dataInIndex()); + popart::logging::pattern::info("Using tied accumulator {} for {}", accl_id, gather->name()); + + // Transpose must be inplace so the accumulator is actually updated + accl_id = transpose_inplace(accl_id, gather_grad); + + auto &graph = op->getGraph(); + + auto accum_type = dense_accl->getAccumulationType(); + popart::Tensor *factor = dense_accl->getFactor().isConst() ? nullptr : dense_accl->inTensor(popart::SparseAccumulateOp::getFactorInIndex()); + + if (factor != nullptr && accum_type == popart::AccumulationType::Mean) { + auto inv_counter = factor->id + "_inverse"; + if (!graph.getTensors().contains(inv_counter)) { + popart::TensorInfo one_info(factor->info.dataType(), {}); + std::vector one_data(one_info.nelms(), 1); + const auto &one_id = graph.getIr().createIntermediateTensorId("one"); + graph.getTensors().addConstInit(one_id, one_info, one_data.data()); + auto inv_op = graph.createConnectedOp( + {{popart::DivOp::getArg0InIndex(), one_id}, + {popart::DivOp::getArg1InIndex(), factor->id}}, + {{popart::DivOp::getOutIndex(), inv_counter}}, + popart::Onnx::Operators::Div_7, + popart::Op::Settings(graph, "mean_accumulate_inverse")); + transferBaseProperties(gather_grad, inv_op); + + for (auto cons : factor->consumers.getOps()) { + if (cons->isConvertibleTo() && + cons->inId(popart::AccumulateOp::getVarToUpdateInIndex()) == factor->id) { + graph.topoCons->insert(cons, inv_op); + } + } + } + accum_type = popart::AccumulationType::DampenedAdd; + factor = graph.getTensor(inv_counter); + } + + // Add sparseAccumulateOp. + auto sparse_accl_up = std::make_unique( + accum_type, + dense_accl->getFactor(), + gather_grad->getAxis(), + popart::Op::Settings(graph, "_tiedAccumulate/" + std::to_string(serial_index))); + + auto sparse_accl = sparse_accl_up.get(); + transferBaseProperties(gather_grad, sparse_accl); + graph.moveIntoGraph(std::move(sparse_accl_up)); + + // Inputs + // Accumulator + sparse_accl->connectInTensor(popart::SparseAccumulateOp::getVarToUpdateInIndex(), + accl_id); + // Gradients + sparse_accl->connectInTensor( + popart::SparseAccumulateOp::getUpdaterInIndex(), + gather_grad->inId(popart::GatherGradOp::gradInIndex())); + // Scale + if (!dense_accl->getFactor().isConst()) { + sparse_accl->connectInTensor( + // the index at which the dampening scale factor is received, + popart::SparseAccumulateOp::getFactorInIndex(), + // the name of the dampening scale factor + factor->id); + } + // Indices + sparse_accl->connectInTensor( + popart::SparseAccumulateOp::getIndicesInIndex(), + gather_grad->inId(popart::GatherGradOp::indicesInIndex())); + + // Original weight to be cloned + sparse_accl->connectInTensor( + popart::SparseAccumulateOp::getOriginalVarToUpdateInIndex(), + weight_id); + + // Transfer TopoCons + graph.topoCons->transfer(gather_grad, sparse_accl); + + // gatherGrad output that will be isolated + auto grad_Id = gather_grad->outId(TiedGatherGradOp::gradOutIndex()); + + // Remove TiedGatherGrad + gather_grad->disconnectAllInputs(); + gather_grad->disconnectAllOutputs(); + graph.eraseOp(gather_grad->id); + + // Outputs + sparse_accl->createAndConnectOutTensor( + popart::SparseAccumulateOp::getUpdatedVarOutIndex(), + sparse_accl->name() + ":0"); + + // remove the gatherGrad output + graph.getTensors().remove(grad_Id); + + // Finalise sparse op + sparse_accl->setup(); + + return true; + } + + popart::TensorId transpose_inplace(popart::TensorId tid, popart::Op *op) const { + auto &graph = op->getGraph(); + + // TransposeInplaceOp's constructor requires a transposeOp + auto outplace_up = std::make_unique( + popart::Onnx::AiOnnx::OpSet9::Transpose, + std::vector{1, 0}, + popart::Op::Settings(graph, tid + "_Transpose")); + auto transpose_up = outplace_up->getInplaceVariant(popart::Onnx::CustomOperators::TransposeInplace); + + auto transpose = transpose_up.get(); + transferBaseProperties(op, transpose); + graph.moveIntoGraph(std::move(transpose_up)); + + transpose->connectInTensor(popart::TransposeOp::getInIndex(), tid); + popart::TensorId out_id = tid + "/transposed"; + transpose->createAndConnectOutTensor(popart::TransposeOp::getOutIndex(), out_id); + + transpose->setup(); + return out_id; + } +}; + +static popart::PatternCreator TiedGatherPatternCreator("TiedGatherPattern", true); +static popart::PatternCreator TiedGatherAccumulatePatternCreator("TiedGatherAccumulatePattern", true); diff --git a/examples/language_model/bert/static_ipu/custom_ops/utils.cc b/examples/language_model/bert/static_ipu/custom_ops/utils.cc new file mode 100644 index 000000000000..b6c6570f803c --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/utils.cc @@ -0,0 +1,173 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +static T *search_producers_for(popart::Tensor *t, int max_depth=-1) { + + // Searched as far as we can without success + if (t->tensorType() == popart::TensorType::Variable || !t->hasProducer()) { + return nullptr; + } + auto op = t->getProducer(); + if (op->isConvertibleTo() && op->settings.executionContext == Ctx) { + return dynamic_cast(op); + } + + if (op->input->n() < 1) { + return nullptr; + } + + unsigned producer_index = 0; + if (op->input->n() > 1) { + if (op->isConvertibleTo()) { + producer_index = popart::AdamUpdaterOp::getAccl1InIndex(); + } else if (op->isConvertibleTo()) { + producer_index = popart::AdamVarUpdateOp::getUpdaterInIndex(); + } else if (op->isConvertibleTo()) { + producer_index = popart::AccumulateBaseOp::getUpdaterInIndex(); + } else if (op->isConvertibleTo()) { + producer_index = popart::DropoutGradOp::getGradInIndex(); + } else if (op->isConvertibleTo()) { + // Grad Unscaling for Adam-based optimizers + producer_index = popart::MulOp::getArg0InIndex(); + } else if (op->isConvertibleTo()) { + // Replicated Tensor Sharding + producer_index = popart::ReplicatedReduceScatterOp::getInIndex(); + } else if (op->isConvertibleTo()) { + // Replicated Tensor Sharding + producer_index = popart::ReplicatedAllGatherOp::getInIndex(); + } else { + return nullptr; + } + } + + // Providing a max-search depth of -1 will remove the depth limit at the cost of potentially + // unnecessary checks. + if (max_depth > 0) { + max_depth -= 1; + if (max_depth == 0) { + return nullptr; + } + } + + return search_producers_for(op->input->tensor(producer_index), max_depth); +} + +// Finds the underlying variable by searching through producers. +static popart::Tensor *get_variable(popart::Tensor *t) { + if (t->tensorType() == popart::TensorType::Variable || t->tensorType() == popart::TensorType::Const) { + return t; + } else if (!t->hasProducer()) { + return nullptr; + } + auto op = t->getProducer(); + if (op->input->n() != 1) { + return nullptr; + } + return get_variable(op->input->tensors().front()); +} + +// Attempts to find T by searching through consumers. +template +static T *search_consumers_for(popart::Tensor *w, std::queue &q) { + for (auto consumer : w->consumers.getOps()) { + if (consumer->isConvertibleTo() && consumer->settings.executionContext == Ctx) { + return dynamic_cast(consumer); + } + + if (consumer->isConvertibleTo()) { + q.push(consumer->output->tensor(popart::DropoutGradOp::getGradInIndex())); + } + if (consumer->isConvertibleTo()) { + q.push(consumer->output->tensor( + popart::ReplicatedReduceScatterOp::getOutIndex())); + } + + // TODO: Improve this as it's too general. Most ops that have one input and one output are view changing. + if (consumer->input->n() == 1 && consumer->output->n() == 1) { + q.push(consumer->output->tensor(0)); + } + } + if (q.size() < 1) { + return nullptr; + } + w = q.front(); + q.pop(); + return search_consumers_for(w, q); +} +template +static T *search_consumers_for(popart::Tensor *w) { + std::queue q; + return search_consumers_for(w, q); +} + +template +static T *weight_consumed_by(popart::Tensor *w) { + w = get_variable(w); + if (w) { + return search_consumers_for(w); + } + return nullptr; +} + +template +static void find_all_consumers(popart::Tensor *w,std::queue &q, std::vector &result) { + for (auto consumer : w->consumers.getOps()) { + if (std::find(result.begin(), result.end(), consumer) == result.end()) { + if (consumer->isConvertibleTo() && consumer->settings.executionContext == Ctx) { + result.push_back(dynamic_cast(consumer)); + } + if (consumer->isConvertibleTo()) { + q.push(consumer->output->tensor(popart::MatMulOp::getOutIndex())); + } + if (consumer->isConvertibleTo()) { + q.push(consumer->output->tensor( + popart::ReplicatedReduceScatterOp::getOutIndex())); + } + // Most ops that have one input and one output are view changing. + if (consumer->input->n() == 1 && consumer->output->n() == 1) { + q.push(consumer->output->tensor(0)); + } + } + } + if (q.size() < 1) { + return; + } + w = q.front(); + q.pop(); + return find_all_consumers(w, q, result); +} +template +static std::vector find_all_consumers(popart::Tensor *w) { + std::queue q; + std::vector result; + find_all_consumers(w, q, result); + return result; +} diff --git a/examples/language_model/bert/static_ipu/custom_ops/workarounds/prevent_const_expr_folding_op.cc b/examples/language_model/bert/static_ipu/custom_ops/workarounds/prevent_const_expr_folding_op.cc new file mode 100644 index 000000000000..d6482ad4e98a --- /dev/null +++ b/examples/language_model/bert/static_ipu/custom_ops/workarounds/prevent_const_expr_folding_op.cc @@ -0,0 +1,137 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include + +namespace CustomOperators +{ + const popart::OperatorIdentifier PreventConstFolding = {"ai.graphcore", "PreventConstFolding", 1}; +} // namespace CustomOperators +namespace CustomGradOperators { + const popart::OperatorIdentifier PreventConstFoldingGrad = {"ai.graphcore", "PreventConstFoldingGrad", 1}; +} // namespace CustomGradOperators + +class PreventConstFoldingOp; +class PreventConstFoldingGradOp; +class PreventConstFoldingOpx; +class PreventConstFoldingGradOpx; + +// By default, const expressions ops get folded to optimise the graph and remove unnessary ops +// at the start. However, in this case, it causes the word embedding to exist in both its +// original and transposed form. By adding this op, the constant expression folding transform +// can't fold through it, so we prevent folding after this point. + +class PreventConstFoldingOp : public popart::Op +{ +public: + PreventConstFoldingOp(const popart::OperatorIdentifier &_opid, const Op::Settings &settings_) + : Op(_opid, settings_) {} + + void setup() final { outInfo(0) = inInfo(0); } + + std::unique_ptr clone() const { + return std::make_unique(*this); + } + + std::vector> getGradOps() { + std::vector> upops; + upops.emplace_back(std::make_unique(*this)); + return upops; + } + + float getSubgraphValue() const final { return getLowSubgraphValue(); } +}; + +static popart::OpDefinition PreventConstFoldingOpDef({}); + +static popart::OpCreator PreventConstFoldingOpCreator( + popart::OpDefinitions({{CustomOperators::PreventConstFolding, + PreventConstFoldingOpDef}}), + [](const popart::OpCreatorInfo &oci) -> std::unique_ptr { + return std::unique_ptr( + new PreventConstFoldingOp(oci.opid, oci.settings)); + }, + true); + +class PreventConstFoldingOpx : public popart::popx::Opx { +public: + PreventConstFoldingOpx(popart::Op *op, popart::popx::Devicex *devicex) : popart::popx::Opx(op, devicex) + { verifyOp(op, CustomOperators::PreventConstFolding); } + + popart::popx::InputCreatorType getInputCreatorType(popart::InIndex) const { + return popart::popx::InputCreatorType::CanUnwind; + } + + poplar::Tensor unwindTensorLayout(poplar::Tensor tensor, popart::InIndex, popart::OutIndex) const { + return tensor; + } + + popart::view::RegMap unwindRegion(popart::InIndex, popart::OutIndex) const { + return [this](const popart::view::Region &r) { + return popart::view::Regions(1, r); + }; + } + + void grow(poplar::program::Sequence &prog) const final { + insert(outId(0), getInTensor(0)); + } +}; + +class PreventConstFoldingGradOp : public PreventConstFoldingOp +{ +public: + PreventConstFoldingGradOp(const PreventConstFoldingOp &fwdOp) + : PreventConstFoldingOp(CustomGradOperators::PreventConstFoldingGrad, fwdOp.getSettings()) {} + + PreventConstFoldingGradOp(const popart::Op::Settings &settings) + : PreventConstFoldingOp(CustomGradOperators::PreventConstFoldingGrad, settings) {} + + std::unique_ptr clone() const final { + return std::make_unique(*this); + } + + const std::vector &gradInputInfo() const { + static const std::vector inInfo = { + {0, 0, popart::GradOpInType::GradOut}}; + + return inInfo; + } + const std::map &gradOutToNonGradIn() const { + static const std::map outInfo = {{0, 0}}; + return outInfo; + } +}; + +class PreventConstFoldingGradOpx : public popart::popx::Opx { +public: + PreventConstFoldingGradOpx(popart::Op *op, popart::popx::Devicex *devicex) + : popart::popx::Opx(op, devicex) { + verifyOp(op, CustomGradOperators::PreventConstFoldingGrad); + } + + void grow(poplar::program::Sequence &prog) const final { + setOutTensor(0, getInTensor(0)); + } +}; + +static popart::popx::OpxCreator + preventConstFoldingOpxCreator(CustomOperators::PreventConstFolding); +static popart::popx::OpxCreator + preventConstFoldingGradOpxCreator(CustomGradOperators::PreventConstFoldingGrad); diff --git a/examples/language_model/bert/static_ipu/dataset_ipu.py b/examples/language_model/bert/static_ipu/dataset_ipu.py new file mode 100644 index 000000000000..3703064b1f62 --- /dev/null +++ b/examples/language_model/bert/static_ipu/dataset_ipu.py @@ -0,0 +1,283 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import multiprocessing +import threading +from queue import Queue + +import h5py +import numpy as np +import paddle + +KEYS = ('input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', + 'masked_lm_ids', 'next_sentence_labels') + + +def shuffle_dict(dic, len): + idxs = np.arange(len) + np.random.shuffle(idxs) + for k, v in dic.items(): + dic[k] = v[idxs] + + +class PretrainingHDF5DataLoader: + def __init__(self, + input_files, + max_seq_length=128, + max_mask_tokens=20, + batch_size=1, + dtype=np.int32, + shuffle=False, + pad_position_value=511, + num_workers=3): + self.files = input_files + self.batch_size = batch_size + self.max_seq_length = max_seq_length + self.max_mask_tokens = max_mask_tokens + self.dtype = dtype + self.shuffle = shuffle + self.pad_position_value = pad_position_value + if shuffle: + np.random.shuffle(self.files) + + self.counter = 0 + + # get total number of samples + pool = multiprocessing.Pool(min(multiprocessing.cpu_count(), 32)) + num_samples = pool.map(self.samples_in_file, self.files) + pool.close() + pool.join() + self.total_samples = sum(num_samples) + self.len = self.total_samples // self.batch_size + assert self.len > 1, f"Batch size {self.batch_size} larger than number of samples {self.total_samples}" + + # notify feed and fetch processes/thread to stop + self.event_queue = multiprocessing.Manager().Queue(10) + + # buffer to store final data + self.feed_buffer = Queue(20) + + # number of processes to do remask + self.num_workers = num_workers + # each feed_worker has one process_buffer to use + self.process_buffers = [ + multiprocessing.Manager().Queue(10) for _ in range(num_workers) + ] + self.split_files = np.array_split(self.files, self.num_workers) + # feed_worker will load data from h5py files, and do remask process + self.feed_workers = [ + multiprocessing.Process( + target=self.fill_buffer_loop, + args=(self.split_files[idx], self.process_buffers[idx])) + for idx in range(self.num_workers) + ] + for p in self.feed_workers: + p.start() + + # index for which process_buffer is used each time + self.post_fetch_idx = 0 + # load final data from process_buffers + self.fetch_worker = threading.Thread(target=self.post_fetch) + self.fetch_worker.start() + + def samples_in_file(self, filename): + with h5py.File(filename, "r") as f: + data_len = f[KEYS[0]].shape[0] + return data_len + + def release(self): + self.event_queue.put('END') + while not self.feed_buffer.empty(): + self.feed_buffer.get() + for process_buffer in self.process_buffers: + while not process_buffer.empty(): + process_buffer.get() + self.fetch_worker.join() + for p in self.feed_workers: + p.join() + return + + def __len__(self): + return self.len + + def __iter__(self): + self.counter = 0 + return self + + def __next__(self): + result = self.feed_buffer.get() + self.counter += 1 + return result + + def post_fetch(self): + while True: + if not self.event_queue.empty(): + return + if not self.process_buffers[self.post_fetch_idx].empty(): + logging.debug(f"self.post_fetch_idx: {self.post_fetch_idx}") + np_feed_list = self.process_buffers[self.post_fetch_idx].get() + self.post_fetch_idx += 1 + if self.post_fetch_idx == self.num_workers: + self.post_fetch_idx = 0 + elif self.post_fetch_idx > self.num_workers: + raise Exception('post_fetch_idx must < num_workers') + + lod_feed_list = [] + for data in np_feed_list: + tensor = paddle.fluid.core.LoDTensor() + place = paddle.CPUPlace() + tensor.set(data, place) + lod_feed_list.append(tensor) + self.feed_buffer.put(lod_feed_list) + + def fill_buffer_loop(self, files, process_buffer): + data = None + data_index = 0 + file_index = 0 + + def multiprocess_fill_buffer(data, file_index, data_index): + if data is None: + data = self.load_one_file(files[file_index]) + file_index += 1 + data_index = 0 + + curr_batch = [] + still_required = self.batch_size + while still_required > 0: + data_batch = { + k: data[k][data_index:data_index + still_required] + for k in KEYS + } + data_batch_len = len(data_batch[KEYS[0]]) + data_index += data_batch_len + curr_batch.append(data_batch) + curr_batch_len = sum(len(x[KEYS[0]]) for x in curr_batch) + still_required = self.batch_size - curr_batch_len + if still_required > 0: + if file_index >= len(files): + np.random.shuffle(files) + file_index = 0 + + data = self.load_one_file(files[file_index]) + file_index += 1 + data_index = 0 + if not curr_batch_len == self.batch_size: + raise Exception("data length should equal to batch_size") + + result = {} + for k in KEYS: + result[k] = np.concatenate( + [item[k] for item in curr_batch], axis=0) + process_buffer.put(self.do_remask(result)) + + return data, file_index, data_index + + while True: + if self.event_queue.empty(): + data, file_index, data_index = multiprocess_fill_buffer( + data, file_index, data_index) + else: + return + + def do_remask(self, samples): + input_ids = samples['input_ids'] + segment_ids = samples['segment_ids'] + masked_lm_positions = samples['masked_lm_positions'] + masked_lm_ids = samples['masked_lm_ids'] + next_sentence_labels = samples['next_sentence_labels'] + masked_lm_weights = np.ones_like(masked_lm_ids, dtype=np.int32) + masked_lm_weights[masked_lm_ids == 0] = 0 + + # post process + batch_size, seq_len = input_ids.shape + formatted_pos = self.pad_position_value * np.ones_like(samples[ + 'input_ids']) + formatted_input = np.zeros_like(input_ids) + formatted_seg = np.zeros_like(segment_ids) + formatted_mask_labels = np.zeros( + (batch_size, self.max_mask_tokens), dtype=masked_lm_ids.dtype) + + valid_seq_positions = [] + valid_mask_positions = masked_lm_weights == 1 + valid_mask_len = np.sum(valid_mask_positions, axis=1).reshape(-1, 1) + for i, mask_pos in enumerate(masked_lm_positions): + pos = [True] * seq_len + for mask_index, m in enumerate(mask_pos): + if mask_index < valid_mask_len[i]: + pos[m] = False + valid_seq_positions.append(np.logical_and(pos, input_ids[i] != 0)) + valid_seq_len = np.minimum( + np.sum(valid_seq_positions, axis=1) + self.max_mask_tokens, + self.max_seq_length).reshape(-1, 1) + unmasked_len = np.minimum( + np.sum(valid_seq_positions, axis=1), + self.max_seq_length - self.max_mask_tokens) + for i in range(batch_size): + target_mask_indices = np.arange(valid_mask_len[i]) + target_seq_indices = self.max_mask_tokens + np.arange(unmasked_len[ + i]) + source_mask_indices = masked_lm_positions[i][valid_mask_positions[ + i]] + source_seq_indices = np.arange(seq_len)[valid_seq_positions[ + i]][:unmasked_len[i]] + + target_indices = np.hstack( + [target_mask_indices, target_seq_indices]) + source_indices = np.hstack( + [source_mask_indices, source_seq_indices]) + + formatted_pos[i, target_indices] = source_indices + formatted_input[i, target_indices] = input_ids[i, source_indices] + formatted_seg[i, target_indices] = segment_ids[i, source_indices] + formatted_mask_labels[i] = masked_lm_ids[i, :self.max_mask_tokens] + + return [ + formatted_input.astype(np.int32), formatted_seg.astype(np.int32), + formatted_pos.astype(np.int32), valid_mask_len.astype(np.int32), + valid_seq_len.astype(np.int32), + formatted_mask_labels.astype(np.int32), + next_sentence_labels.astype(np.int32) + ] + + def load_one_file(self, file_path): + data = self.load_hdf5(file_path) + + if self.shuffle: + shuffle_dict(data, len(data[KEYS[0]])) + + return data + + def load_hdf5(self, filename): + with h5py.File(filename, "r") as f: + data = {key: np.asarray(f[key][:]) for key in KEYS} + return data + + +if __name__ == "__main__": + import glob + base_dir = 'data_path/wikicorpus_en/' + input_files = glob.glob(f"{base_dir}/*training*.hdf5") + input_files.sort() + # print(input_files) + + seed = 1984 + np.random.seed(seed) + paddle.seed(seed) + + data_loader = PretrainingHDF5DataLoader( + input_files, batch_size=65536, shuffle=True) + + for idx, batch in enumerate(data_loader): + print(f"{idx}: {batch[0].shape()}") diff --git a/examples/language_model/bert/static_ipu/load_tf_ckpt.py b/examples/language_model/bert/static_ipu/load_tf_ckpt.py new file mode 100644 index 000000000000..4bad63fe2a9c --- /dev/null +++ b/examples/language_model/bert/static_ipu/load_tf_ckpt.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from logging import getLogger + +logger = getLogger(__name__) + + +def get_tf_mapping(args): + squad_mapping = { + "cls/squad/output_weights": "linear_72.w_0", + "cls/squad/output_bias": "linear_72.b_0" + } + + tf_to_pdmodel = { + "bert/embeddings/word_embeddings": "ipu_bert_embeddings_0.w_0", + "bert/embeddings/position_embeddings": "embedding_0.w_0", + "bert/embeddings/token_type_embeddings": "ipu_bert_embeddings_0.w_1", + "bert/embeddings/LayerNorm/gamma": "layer_norm_0.w_0", + "bert/embeddings/LayerNorm/beta": "layer_norm_0.b_0" + } + for i in range(args.num_hidden_layers): + layer = { + f"bert/encoder/layer_{i}/attention/self/query/bias": + f"bert_model_0.b_{i}", + f"bert/encoder/layer_{i}/attention/self/key/bias": + f"bert_model_0.b_{i}", + f"bert/encoder/layer_{i}/attention/self/value/bias": + f"bert_model_0.b_{i}", + f"bert/encoder/layer_{i}/attention/output/dense/kernel": + f"linear_{i*6}.w_0", + f"bert/encoder/layer_{i}/attention/output/dense/bias": + f"linear_{i*6}.b_0", + f"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma": + f"layer_norm_{i*4+2}.w_0", + f"bert/encoder/layer_{i}/attention/output/LayerNorm/beta": + f"layer_norm_{i*4+2}.b_0", + f"bert/encoder/layer_{i}/intermediate/dense/kernel": + f"linear_{i*6+2}.w_0", + f"bert/encoder/layer_{i}/intermediate/dense/bias": + f"linear_{i*6+2}.b_0", + f"bert/encoder/layer_{i}/output/dense/kernel": + f"linear_{i*6+3}.w_0", + f"bert/encoder/layer_{i}/output/dense/bias": f"linear_{i*6+3}.b_0", + f"bert/encoder/layer_{i}/output/LayerNorm/gamma": + f"layer_norm_{(i+1)*4}.w_0", + f"bert/encoder/layer_{i}/output/LayerNorm/beta": + f"layer_norm_{(i+1)*4}.b_0", + } + layer[ + f"bert/encoder/layer_{i}/attention/self/query/kernel"] = f"bert_model_0.w_{i*3+0}" + layer[ + f"bert/encoder/layer_{i}/attention/self/key/kernel"] = f"bert_model_0.w_{i*3+1}" + layer[ + f"bert/encoder/layer_{i}/attention/self/value/kernel"] = f"bert_model_0.w_{i*3+2}" + tf_to_pdmodel.update(**layer) + + if args.task == "PRETRAINING": + logger.error("Mapping ckpt weights is only supported in SQUAD task.") + elif args.task == "SQUAD": + tf_to_pdmodel.update(**squad_mapping) + + return tf_to_pdmodel + + +def generate_initializers(args, map_names, load_data, mapping, transform={}): + initializers = {} + initializers_param = {} + initializers_opt = {} + + qkv_tensor_range = { + "query": (0, args.hidden_size), + "key": (args.hidden_size, args.hidden_size * 2), + "value": (args.hidden_size * 2, args.hidden_size * 3), + } + + for name, array in zip(map_names, load_data): + logger.debug( + f"Initialising tensor from checkpoint {name} -> {mapping[name]}") + + # config["lamb_m_dtype"] is for setting the data type for accl1 of lamb + # BERT can use FP16 for accl1 without lossing accuracy + # accl2 is always in FP32 + lamb_m_dtype = np.float32 + dtype = np.float32 + + if "moment1" in mapping[name]: + if array.dtype != lamb_m_dtype: + array = array.astype(lamb_m_dtype) + elif "moment2" in mapping[name]: + if array.dtype != np.float32: + array = array.astype(np.float32) + elif array.dtype != dtype: + array = array.astype(dtype) + + # If it's part of QKV biases, we need to handle separately as those 3 + # tensors need concatenating into one + if "bert_model_0.b" in mapping[name]: + qkv_part = name.split("/")[5] + if mapping[name] not in initializers.keys(): + qkv_shape = (array.shape[0] * 3) + initializers[mapping[name]] = np.empty( + qkv_shape, dtype=array.dtype) + + start_idx = qkv_tensor_range[qkv_part][0] + end_idx = qkv_tensor_range[qkv_part][1] + initializers[mapping[name]][start_idx:end_idx] = array + logger.debug( + f"Initialising QKV_bias component {name}[{start_idx}:{end_idx}] from checkpoint" + ) + continue + + if name in transform: + array = transform[name](array) + + padded_vocab_length = args.vocab_size + if "bert_embeddings_0.w_0" in mapping[name]: + tf_vocab_length = array.shape[0] + diff = padded_vocab_length - tf_vocab_length + # Pad or Crop the vocab. + if diff > 0: + logger.info( + f"Padding the vocabulary. From {tf_vocab_length} to {padded_vocab_length}" + ) + pad = np.zeros((diff, args.hidden_size)).astype(array.dtype) + array = np.concatenate((array, pad), axis=0) + else: + logger.warning( + f"Cropping the vocabulary may negatively effect performance. From {tf_vocab_length} to {padded_vocab_length}" + ) + array = np.array(array[:padded_vocab_length, :]) + # if args.task == "PRETRAINING": + # We use transposed weight in both pretraining and squad + array = np.transpose(array, [1, 0]) + + if "embedding_0.w_0" in mapping[name]: + max_pos, hidden_len = array.shape + if max_pos > args.max_position_embeddings: + array = array[:args.max_position_embeddings, :] + + # Otherwise just copy the positional embeddings over and over again as is done in longformer + elif max_pos < args.max_position_embeddings: + logger.warning( + f"Not enough positional embeddings in checkpoint, copying to match length..." + ) + array = array[np.mod( + np.arange(args.max_position_embeddings), max_pos)] + + initializers[mapping[name]] = array.copy() + for k in initializers: + if "moment" in k: + initializers_opt[k] = initializers[k] + else: + initializers_param[k] = initializers[k] + return initializers_param, initializers_opt + + +# util function for load tf pretrained weight +def load_initializers_from_tf(file_path, args): + """ + Loads weights, etc. from Tensorflow files into a dictionary of Numpy Arrays. + + Can read either checkpoint files, or frozen graphs, according to the + `is_checkpoint` flag, passed in as the second argument. + """ + try: + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model requires TensorFlow to be installed. " + "Please see https://www.tensorflow.org/install/ for installation " + "instructions.") + raise + + tf_path = os.path.abspath(file_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + + mapping = get_tf_mapping(args) + map_names = [name for name, shape in init_vars if name in mapping.keys()] + for name in (n for n, _ in init_vars if n not in mapping.keys()): + logger.debug(f"Skipping load of {name} - Not in mapping") + + load_data = [tf.train.load_variable(tf_path, name) for name in map_names] + initializers, opt_params = generate_initializers(args, map_names, load_data, + mapping) + return initializers, opt_params diff --git a/examples/language_model/bert/static_ipu/modeling.py b/examples/language_model/bert/static_ipu/modeling.py new file mode 100755 index 000000000000..17e8d8900774 --- /dev/null +++ b/examples/language_model/bert/static_ipu/modeling.py @@ -0,0 +1,705 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.static +import paddle.fluid +from paddle.nn import Layer +from typing import List, NamedTuple, Optional +from contextlib import ExitStack + + +class DeviceScope(object): + def __init__(self, index, stage, name_scope=None): + self.index = index + self.stage = stage + self.name_scope = name_scope + + def __enter__(self): + self.stack = ExitStack() + self.stack.enter_context( + paddle.static.ipu_shard_guard( + index=self.index, stage=self.stage)) + if self.name_scope is not None: + self.stack.enter_context(paddle.static.name_scope(self.name_scope)) + return self + + def __exit__(self, *exp): + self.stack.close() + return False + + +class IpuBertConfig(NamedTuple): + """ + The configuration for BERT Model. + Args: + seq_len (int): + The sequence length. Default to `128`. + max_position_embeddings (int): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input + sequence. Defaults to `512`. + max_predictions_per_seq (int): + The max number of the masked token each sentence. Default to `20`. + hidden_size (int): + Dimensionality of the embedding layer, encoder layer and pooler layer. Defaults to `768`. + vocab_size (int): + Vocabulary size of `inputs_ids` in `BertModel`. Also is the vocab size of token embedding matrix. + Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `BertModel`. + num_hidden_layers (int): + Number of hidden layers in the Transformer encoder. Defaults to `12`. + available_mem_proportion (float): + The available proportion of memory used by conv or matmul. Default to `0.28`. + type_vocab_size (int): + The vocabulary size of `token_type_ids`. + Defaults to `2`. + hidden_dropout_prob (float): + The dropout probability for all fully connected layers in the embeddings and encoder. + Defaults to `0.1`. + attention_probs_dropout_prob (float): + The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. + Defaults to `0.1`. + task (str): + The type of the NLP model. + layers_per_ipu (list): + Number of attention layers executed on each IPU. + """ + micro_batch_size: int = 1 + seq_len: int = 128 + max_position_embeddings: int = 512 + max_predictions_per_seq: int = 20 + hidden_size: int = 768 + vocab_size: int = 30400 + num_hidden_layers: int = 12 + available_mem_proportion: float = 0.28 + type_vocab_size: int = 2 + + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + + # Choices: PRETRAINING (MLM + NSP), SQUAD + task: str = "PRETRAINING" + layers_per_ipu: List = None + + embeddings_scope: DeviceScope = None + attn_scopes: DeviceScope = None + ff_scopes: DeviceScope = None + mlm_scope: DeviceScope = None + nsp_scope: DeviceScope = None + + +class IpuBertEmbeddings(Layer): + """ + Include embeddings from word, position and token_type embeddings + """ + + def __init__(self, config, custom_ops=None): + super(IpuBertEmbeddings, self).__init__() + self.config = config + self.word_embeddings_weights = self.create_parameter( + shape=[config.hidden_size, config.vocab_size], dtype="float32") + self.token_embeddings_weights = self.create_parameter( + shape=[config.type_vocab_size, config.hidden_size], dtype="float32") + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=0.001) + self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + self.custom_ops = custom_ops + + def forward(self, indices, segments, positions): + # word embeddings + word_embeddings_weights = paddle.transpose(self.word_embeddings_weights, + [1, 0]) + input_embeddings = paddle.gather( + word_embeddings_weights, indices, axis=0) + + # position_embeddings + position_embeddings = self.position_embeddings(positions) + + # token_type_embeddings + token_type_embeddings = paddle.fluid.input.one_hot(segments, depth=2) + token_type_embeddings = paddle.matmul(token_type_embeddings, + self.token_embeddings_weights) + + embeddings = paddle.add(input_embeddings, position_embeddings) + embeddings = paddle.add(embeddings, token_type_embeddings) + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings, self.word_embeddings_weights + + +class BertModel(Layer): + """ + The bare BERT Model transformer outputting raw hidden-states. + + This model refers to :class:`~paddlenlp.transformers.bert.BertModel`. + + Args: + config (IpuBertConfig): + configuration of bert. + custom_ops: + custom defined operators which can be found in directory `custom_ops`. + """ + + def __init__(self, config, custom_ops=None): + super(BertModel, self).__init__() + self.config = config + self.custom_ops = custom_ops + + qk_scale = 1 / np.sqrt(self.config.hidden_size / + self.config.num_hidden_layers) + self.qk_scale_attrs = { + 'name': 'QK_scale', + 'shape': [1], + 'dtype': 'float32', + 'value': qk_scale, + } + self.qkv_shape = [-1, self.config.seq_len, 12, 64] + self.masks = {} + + self.embedding = IpuBertEmbeddings(self.config, custom_ops) + + def _encoder_layer_ipu_offset(self, layer_index): + encoder_index = 0 + if len(self.config.layers_per_ipu) == 1: + encoder_index = layer_index // self.config.layers_per_ipu[0] + else: + for ipu, num_layers in enumerate(self.config.layers_per_ipu): + layer_index -= num_layers + if layer_index < 0: + encoder_index = ipu + break + return encoder_index + + def should_checkpoint(self, layer_index): + encoder_index = self._encoder_layer_ipu_offset(layer_index) + if len(self.config.layers_per_ipu) == 1: + layers = self.config.layers_per_ipu[0] + layer_index -= encoder_index * layers + else: + layers = self.config.layers_per_ipu[encoder_index] + layer_index -= sum(self.config.layers_per_ipu[:encoder_index]) + return layer_index < (layers - 1) + + def forward(self, indices, segments, positions, input_mask): + r''' + The BertModel forward method, overrides the `__call__()` special method. + + Args: + indices (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + Its data type should be `int32` and it has a shape of [batch_size * sequence_length]. + segments (Tensor): + Segment token indices to indicate different portions of the inputs. + Selected in the range ``[0, type_vocab_size - 1]``. + Its data type should be `int32` and it has a shape of [batch_size * sequence_length]. + positions(Tensor): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + max_position_embeddings - 1]``. + Shape as `[batch_size * sequence_length]` and dtype as int32. + input_mask (Tensor, optional): + Mask used in multi-head attention to avoid performing attention on to some unwanted positions, + usually the paddings or the subsequent positions. + If the task is PRETRAINING: + input_mask[0] is the index that masking starts in the mask_tokens + input_mask[1] is the index that masking starts in the rest of the sequence + Otherwise + input_mask is the mask tensor that has -1000 in positions to be masked and 0 otherwise. + + Returns: + tuple: Returns tuple (`sequence_output`, `word_embeddings_weights`). + + With the fields: + + - `sequence_output` (Tensor): + Sequence of hidden-states at the last layer of the model. + It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size]. + ''' + + with self.config.embeddings_scope: + sequence_output, word_embeddings_weights = self.embedding( + indices, segments, positions) + + if self.config.task == "PRETRAINING": + with paddle.static.ipu_shard_guard(index=0, stage=0): + input_mask[0] = self.custom_ops.detach(input_mask[0]) + input_mask[1] = self.custom_ops.detach(input_mask[1]) + + for i in range(self.config.num_hidden_layers): + # Attention + attn_scope = self.config.attn_scopes[i] + with attn_scope: + with paddle.static.name_scope(f"Layer{i}/Attention"): + layer_input = sequence_output + q = self.create_parameter( + shape=[ + self.config.hidden_size, self.config.hidden_size + ], + dtype="float32") + k = self.create_parameter( + shape=[ + self.config.hidden_size, self.config.hidden_size + ], + dtype="float32") + v = self.create_parameter( + shape=[ + self.config.hidden_size, self.config.hidden_size + ], + dtype="float32") + qkv = paddle.concat([q, k, v], axis=1) + qkv = paddle.matmul(sequence_output, qkv) + qkv.block.ops[-1]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + q, k, v = paddle.split( + qkv, + num_or_sections=[ + self.config.hidden_size, self.config.hidden_size, + self.config.hidden_size + ], + axis=1) + q = paddle.reshape(q, self.qkv_shape) + q = paddle.transpose(q, [0, 2, 1, 3]) + k = paddle.reshape(k, self.qkv_shape) + k = paddle.transpose(k, [0, 2, 3, 1]) + v = paddle.reshape(v, self.qkv_shape) + v = paddle.transpose(v, [0, 2, 1, 3]) + + # Attention calculation + with paddle.static.name_scope(f"Z"): + if self.config.task == "PRETRAINING": + if attn_scope.index in self.masks: + final_mask = self.masks[attn_scope.index] + else: + with paddle.static.name_scope("Mask"): + base_value = np.arange( + self.config.seq_len).astype('int32') + base = paddle.fluid.layers.assign( + base_value) + mmask = paddle.less_than(base, + input_mask[0]) + mask_value = np.greater_equal( + base_value, + self.config.max_predictions_per_seq) + mask = paddle.fluid.layers.assign( + mask_value) + mmask = paddle.logical_or(mmask, mask) + smask = paddle.less_than(base, + input_mask[1]) + final_mask = paddle.logical_and(mmask, + smask) + final_mask = paddle.cast(final_mask, + "float16") + sub_attrs = { + 'name': 'constant_sub', + 'shape': [1], + 'dtype': 'float32', + 'value': 1, + } + mul_attrs = { + 'name': 'constant_mul', + 'shape': [1], + 'dtype': 'float32', + 'value': 1000, + } + final_mask = paddle.fluid.layers.elementwise_sub( + final_mask, + paddle.fluid.layers.fill_constant( + **sub_attrs)) + final_mask = paddle.fluid.layers.elementwise_mul( + final_mask, + paddle.fluid.layers.fill_constant( + **mul_attrs)) + final_mask = paddle.reshape( + final_mask, + [-1, 1, 1, self.config.seq_len]) + final_mask = self.custom_ops.detach( + final_mask) + self.masks[attn_scope.index] = final_mask + + qk = paddle.matmul(q, k) + qk.block.ops[-1]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + qk_scale = paddle.fluid.layers.fill_constant( + **self.qk_scale_attrs) + qk = paddle.fluid.layers.elementwise_mul(qk, qk_scale) + + if self.config.task == "PRETRAINING": + qk = paddle.fluid.layers.elementwise_add(qk, + final_mask) + else: + # for SQUAD task, input_mask is calculated in data preprocessing + qk = paddle.fluid.layers.elementwise_add(qk, + input_mask) + + qk = paddle.fluid.layers.softmax(qk) + if self.config.task == "SQUAD": + qk = paddle.fluid.layers.dropout( + qk, + self.config.attention_probs_dropout_prob, + dropout_implementation='upscale_in_train') + qkv = paddle.matmul(qk, v) + qkv.block.ops[-1]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + qkv = paddle.transpose(qkv, [0, 2, 1, 3]) + qkv = paddle.reshape(qkv, [-1, self.config.hidden_size]) + + qkv_linear = nn.Linear( + self.config.hidden_size, + self.config.hidden_size, + bias_attr=False) + qkv = qkv_linear(qkv) + qkv.block.ops[-1]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + qkv = paddle.fluid.layers.dropout( + qkv, + self.config.attention_probs_dropout_prob, + dropout_implementation='upscale_in_train') + attention = paddle.add(layer_input, qkv) + layer_norm1 = nn.LayerNorm( + self.config.hidden_size, epsilon=0.001) + attention = layer_norm1(attention) + + # FF + with self.config.ff_scopes[i]: + with paddle.static.name_scope(f"Layer{i}/FF"): + ff_linear1 = nn.Linear(self.config.hidden_size, + 4 * self.config.hidden_size) + ff_linear2 = nn.Linear(4 * self.config.hidden_size, + self.config.hidden_size) + with paddle.static.name_scope(f"1"): + ff = ff_linear1(attention) + ff.block.ops[-2]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + ff = paddle.fluid.layers.gelu(ff, approximate=True) + with paddle.static.name_scope(f"2"): + ff = ff_linear2(ff) + ff.block.ops[-2]._set_attr( + '__available_memory', + self.config.available_mem_proportion) + ff = paddle.fluid.layers.dropout( + ff, + self.config.attention_probs_dropout_prob, + dropout_implementation='upscale_in_train') + ff = paddle.add(attention, ff) + layer_norm2 = nn.LayerNorm( + self.config.hidden_size, epsilon=0.001) + sequence_output = layer_norm2(ff) + + if self.should_checkpoint(i): + with paddle.static.name_scope(f"Layer{i}"): + logging.info(f'add checkpointoutput for ff_{i}') + sequence_output = self.custom_ops.checkpointoutput( + sequence_output) + return sequence_output, word_embeddings_weights + + +class IpuBertForQuestionAnswering(Layer): + """ + Bert Model with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and + `span end logits`). + + Args: + hidden_size (int): + Dimensionality of the embedding layer, encoder layer and pooler layer. Defaults to `768`. + seq_len (int): + See :class:`IpuBertConfig`. + """ + + def __init__(self, hidden_size, seq_len): + super(IpuBertForQuestionAnswering, self).__init__() + self.hidden_size = hidden_size + self.seq_len = seq_len + self.classifier = nn.Linear(hidden_size, 2) + + def forward(self, sequence_output): + r""" + The IpuBertForQuestionAnswering forward method, overrides the __call__() special method. + + Args: + sequence_output (Tensor): + See :class:`BertModel`. + + Returns: + tuple: Returns tuple (`start_logits`, `end_logits`). + + With the fields: + + - `start_logits` (Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `end_logits` (Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + """ + logits = self.classifier(sequence_output) + + start_logits = paddle.slice( + input=logits, axes=[1], starts=[0], ends=[1]) + end_logits = paddle.slice(input=logits, axes=[1], starts=[1], ends=[2]) + + start_logits = paddle.reshape(start_logits, [-1, self.seq_len]) + end_logits = paddle.reshape(end_logits, [-1, self.seq_len]) + return start_logits, end_logits + + +class IpuBertQAAccAndLoss(paddle.nn.Layer): + """ + Criterion for Question and Answering. + """ + + def __init__(self, custom_ops=None): + super(IpuBertQAAccAndLoss, self).__init__() + self.custom_ops = custom_ops + + def forward(self, start_logits, end_logits, start_labels, end_labels): + r""" + The IpuBertQAAccAndLoss forward method, overrides the __call__() special method. + + Args: + start_logits (Tensor): + See :class:`IpuBertForQuestionAnswering`. + end_logits (Tensor): + See :class:`IpuBertForQuestionAnswering`. + start_labels (Tensor): + Labels for start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + end_labels (Tensor): + Labels for end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + """ + with paddle.static.name_scope("loss"): + start_loss = paddle.fluid.layers.softmax(start_logits) + start_loss = self.custom_ops.custom_nll_loss( + start_loss, start_labels, 1, -100, False) + end_loss = paddle.fluid.layers.softmax(end_logits) + end_loss = self.custom_ops.custom_nll_loss(end_loss, end_labels, 1, + -100, False) + loss = paddle.add(start_loss, end_loss) + + with paddle.static.name_scope("acc"): + start_logits = paddle.fluid.layers.argmax(start_logits, axis=1) + end_logits = paddle.fluid.layers.argmax(end_logits, axis=1) + start_equal = paddle.fluid.layers.equal(start_logits, start_labels) + end_equal = paddle.fluid.layers.equal(end_logits, end_labels) + start_equal = paddle.fluid.layers.cast(start_equal, 'float32') + end_equal = paddle.fluid.layers.cast(end_equal, 'float32') + start_acc = paddle.mean(start_equal) + end_acc = paddle.mean(end_equal) + + return start_acc, end_acc, loss + + +class IpuBertPretrainingMLMHeads(Layer): + """ + Perform language modeling task. + + Args: + hidden_size (int): + See :class:`IpuBertConfig`. + vocab_size (int): + See :class:`IpuBertConfig`. + max_position_embeddings (int): + See :class:`IpuBertConfig`. + max_predictions_per_seq (int): + See :class:`IpuBertConfig`. + seq_len (int): + See :class:`IpuBertConfig`. + """ + + def __init__(self, hidden_size, vocab_size, max_position_embeddings, + max_predictions_per_seq, seq_len): + super(IpuBertPretrainingMLMHeads, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.max_predictions_per_seq = max_predictions_per_seq + self.sequence_length = seq_len + self.transform = nn.Linear(hidden_size, hidden_size) + self.layer_norm = nn.LayerNorm(hidden_size, epsilon=0.001) + + def forward(self, encoders_output, word_embeddings_weights): + # cls + out = self.transform(encoders_output) + out = paddle.fluid.layers.gelu(out, approximate=True) + out = self.layer_norm(out) + + # mlm + out = paddle.reshape(out, [-1, self.sequence_length, self.hidden_size]) + out = paddle.slice(out, [1], [0], [self.max_predictions_per_seq]) + out = paddle.reshape(out, [-1, self.hidden_size]) + + # serialized matmul + out = paddle.matmul(out, word_embeddings_weights) + out.block.ops[-1]._set_attr('serialize_factor', 5) + mlm_out = paddle.reshape( + out, [-1, self.max_predictions_per_seq, self.vocab_size]) + + return mlm_out + + +class IpuBertPretrainingNSPHeads(Layer): + """ + Perform next sequence classification task. + + Args: + hidden_size (int): + See :class:`IpuBertConfig`. + max_predictions_per_seq (int): + See :class:`IpuBertConfig`. + seq_len (int): + See :class:`IpuBertConfig`. + """ + + def __init__(self, hidden_size, max_predictions_per_seq, seq_len): + super(IpuBertPretrainingNSPHeads, self).__init__() + self.hidden_size = hidden_size + self.max_predictions_per_seq = max_predictions_per_seq + self.seq_len = seq_len + self.seq_relationship = nn.Linear(hidden_size, 2) + self.pooler = IpuBertPooler(hidden_size, self.seq_len, + self.max_predictions_per_seq) + + def forward(self, encoders_output): + pooled_output = self.pooler(encoders_output) + nsp_out = self.seq_relationship(pooled_output) + return nsp_out + + +class IpuBertPooler(Layer): + """ + Pool the result of BertEncoder. + """ + + def __init__(self, + hidden_size, + sequence_length, + max_predictions_per_seq, + pool_act="tanh"): + super(IpuBertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + self.pool_act = pool_act + self.sequence_length = sequence_length + self.max_predictions_per_seq = max_predictions_per_seq + self.hidden_size = hidden_size + + def forward(self, hidden_states): + hidden_states = paddle.reshape( + hidden_states, [-1, self.sequence_length, self.hidden_size]) + first_token_tensor = paddle.slice( + input=hidden_states, + axes=[1], + starts=[self.max_predictions_per_seq], + ends=[self.max_predictions_per_seq + 1]) + first_token_tensor = paddle.reshape(first_token_tensor, + [-1, self.hidden_size]) + pooled_output = self.dense(first_token_tensor) + if self.pool_act == "tanh": + pooled_output = self.activation(pooled_output) + return pooled_output + + +class IpuBertPretrainingMLMAccAndLoss(Layer): + """ + Criterion for masked language modeling. + """ + + def __init__(self, micro_batch, ignore_index, custom_ops): + super(IpuBertPretrainingMLMAccAndLoss, self).__init__() + self.micro_batch = micro_batch + self.ignore_index = ignore_index + self.custom_ops = custom_ops + + def forward(self, mlm, masked_lm_ids): + mlm_pred = paddle.fluid.layers.argmax(mlm, axis=-1) + mlm_pred = paddle.cast(mlm_pred, "int32") + with paddle.static.name_scope("Accuracy"): + mlm_label = paddle.cast(masked_lm_ids, "int32") + mlm_correct = paddle.fluid.layers.equal(mlm_pred, mlm_label) + attrs = { + 'name': 'mlm_mask_val', + 'shape': [1], + 'dtype': 'int32', + 'value': self.ignore_index, + } + mlm_mask_val = paddle.fluid.layers.fill_constant(**attrs) + mlm_unmask = paddle.fluid.layers.equal(mlm_label, mlm_mask_val) + mlm_mask = paddle.logical_not(mlm_unmask) + mlm_mask = paddle.cast(mlm_mask, "float32") + mlm_correct = paddle.cast(mlm_correct, "float32") + masked_mlm_correct = paddle.fluid.layers.elementwise_mul( + mlm_correct, mlm_mask) + total_correct_tokens = paddle.fluid.layers.reduce_sum( + masked_mlm_correct) + total_tokens = paddle.fluid.layers.reduce_sum(mlm_mask) + total_correct_tokens = paddle.cast(total_correct_tokens, "float32") + total_tokens = paddle.cast(total_tokens, "float32") + mlm_acc = paddle.fluid.layers.elementwise_div(total_correct_tokens, + total_tokens) + + masked_lm_softmax = paddle.fluid.layers.softmax(mlm) + mlm_loss = self.custom_ops.custom_nll_loss( + masked_lm_softmax, masked_lm_ids, 1, self.ignore_index, False) + + return mlm_acc, mlm_loss + + +class IpuBertPretrainingNSPAccAndLoss(Layer): + """ + Criterion for next sequence classification. + """ + + def __init__(self, micro_batch, ignore_index, custom_ops): + super(IpuBertPretrainingNSPAccAndLoss, self).__init__() + self.micro_batch = micro_batch + self.ignore_index = ignore_index + self.custom_ops = custom_ops + + def forward(self, nsp, nsp_label): + nsp_pred = paddle.fluid.layers.argmax(nsp, axis=-1) + nsp_pred = paddle.cast(nsp_pred, "int32") + with paddle.static.name_scope("Accuracy"): + nsp_label = paddle.cast(nsp_label, "int32") + nsp_correct = paddle.fluid.layers.equal(nsp_pred, nsp_label) + nsp_correct = paddle.cast(nsp_correct, "int32") + nsp_correct = paddle.fluid.layers.reduce_sum(nsp_correct) + nsp_correct = paddle.cast(nsp_correct, "float32") + attrs = { + 'name': 'mlm_mask_val', + 'shape': [1], + 'dtype': 'int32', + 'value': self.micro_batch, + } + nsp_total = paddle.fluid.layers.fill_constant(**attrs) + nsp_total = paddle.cast(nsp_total, "float32") + nsp_acc = paddle.fluid.layers.elementwise_div(nsp_correct, + nsp_total) + + next_sentence_softmax = paddle.fluid.layers.softmax(nsp) + nsp_loss = self.custom_ops.custom_nll_loss(next_sentence_softmax, + nsp_label, 1, -100, False) + + return nsp_acc, nsp_loss diff --git a/examples/language_model/bert/static_ipu/requirements.txt b/examples/language_model/bert/static_ipu/requirements.txt new file mode 100644 index 000000000000..27244318193c --- /dev/null +++ b/examples/language_model/bert/static_ipu/requirements.txt @@ -0,0 +1,7 @@ +datasets +h5py +multiprocess +numpy +paddlenlp +scipy +wandb diff --git a/examples/language_model/bert/static_ipu/run_pretrain.py b/examples/language_model/bert/static_ipu/run_pretrain.py new file mode 100644 index 000000000000..eb9ca79df331 --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_pretrain.py @@ -0,0 +1,410 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import pickle +import random +import time + +import numpy as np +import paddle +import paddle.optimizer +import paddle.static +from paddlenlp.transformers import LinearDecayWithWarmup +from scipy.stats import truncnorm + +from dataset_ipu import PretrainingHDF5DataLoader +from modeling import ( + BertModel, DeviceScope, IpuBertConfig, IpuBertPretrainingMLMAccAndLoss, + IpuBertPretrainingMLMHeads, IpuBertPretrainingNSPAccAndLoss, + IpuBertPretrainingNSPHeads) +from utils import load_custom_ops, parse_args + + +def set_seed(seed): + """ + Use the same data seed(for data shuffle) for all procs to guarantee data + consistency after sharding. + """ + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + +def create_data_holder(args): + bs = args.micro_batch_size + indices = paddle.static.data( + name="indices", shape=[bs * args.seq_len], dtype="int32") + segments = paddle.static.data( + name="segments", shape=[bs * args.seq_len], dtype="int32") + positions = paddle.static.data( + name="positions", shape=[bs * args.seq_len], dtype="int32") + mask_tokens_mask_idx = paddle.static.data( + name="mask_tokens_mask_idx", shape=[bs, 1], dtype="int32") + sequence_mask_idx = paddle.static.data( + name="sequence_mask_idx", shape=[bs, 1], dtype="int32") + masked_lm_ids = paddle.static.data( + name="masked_lm_ids", + shape=[bs, args.max_predictions_per_seq], + dtype="int32") + next_sentence_labels = paddle.static.data( + name="next_sentence_labels", shape=[bs], dtype="int32") + return [ + indices, segments, positions, mask_tokens_mask_idx, sequence_mask_idx, + masked_lm_ids, next_sentence_labels + ] + + +def reset_program_state_dict(state_dict, mean=0, scale=0.02): + """ + Initialize the parameter from the bert config, and set the parameter by + reseting the state dict." + """ + new_state_dict = dict() + for n, p in state_dict.items(): + if n.endswith('_moment1_0') or n.endswith('_moment2_0') \ + or n.endswith('_beta2_pow_acc_0') or n.endswith('_beta1_pow_acc_0'): + continue + if 'learning_rate' in n: + continue + + dtype_str = "float32" + if p._dtype == paddle.float64: + dtype_str = "float64" + + if "layer_norm" in n and n.endswith('.w_0'): + new_state_dict[n] = np.ones(p.shape()).astype(dtype_str) + continue + + if n.endswith('.b_0'): + new_state_dict[n] = np.zeros(p.shape()).astype(dtype_str) + else: + new_state_dict[n] = truncnorm.rvs(-2, + 2, + loc=mean, + scale=scale, + size=p.shape()).astype(dtype_str) + return new_state_dict + + +def create_ipu_strategy(args): + ipu_strategy = paddle.static.IpuStrategy() + options = { + 'is_training': args.is_training, + 'enable_manual_shard': True, + 'enable_pipelining': True, + 'batches_per_step': args.batches_per_step, + 'micro_batch_size': args.micro_batch_size, + 'loss_scaling': args.scale_loss, + 'enable_replicated_graphs': True, + 'replicated_graph_count': args.num_replica, + 'num_ipus': args.num_ipus * args.num_replica, + 'enable_gradient_accumulation': args.enable_grad_acc, + 'accumulation_factor': args.grad_acc_factor, + 'auto_recomputation': 3, + 'enable_half_partial': True, + 'available_memory_proportion': args.available_mem_proportion, + 'enable_stochastic_rounding': True, + 'max_weight_norm': 65504.0, + 'default_prefetch_buffering_depth': 3, + 'rearrange_anchors_on_host': False, + 'enable_fp16': args.ipu_enable_fp16, + 'random_seed': args.seed, + 'use_no_bias_optimizer': True, + 'enable_prefetch_datastreams': True, + 'enable_outlining': True, + 'subgraph_copying_strategy': 1, # JustInTime + 'outline_threshold': 10.0, + 'disable_grad_accumulation_tensor_streams': True, + 'schedule_non_weight_update_gradient_consumers_early': True, + 'cache_path': 'paddle_cache', + 'enable_floating_point_checks': False, + 'accl1_type': args.accl1_type, + 'accl2_type': args.accl2_type, + 'weight_decay_mode': args.weight_decay_mode, + } + + if not args.optimizer_state_offchip: + options['location_optimizer'] = { + 'on_chip': 1, # popart::TensorStorage::OnChip + 'use_replicated_tensor_sharding': + 1, # popart::ReplicatedTensorSharding::On + } + + # use popart::AccumulateOuterFragmentSchedule::OverlapMemoryOptimized + # excludedVirtualGraphs = [0] + options['accumulate_outer_fragment'] = {3: [0]} + + options['convolution_options'] = {"partialsType": "half"} + options['engine_options'] = { + "opt.useAutoloader": "true", + "target.syncReplicasIndependently": "true", + "exchange.streamBufferOverlap": "hostRearrangeOnly", + } + + options['enable_engine_caching'] = args.enable_engine_caching + + ipu_strategy.set_options(options) + + # enable custom patterns + ipu_strategy.enable_pattern('DisableAttnDropoutBwdPattern') + + return ipu_strategy + + +def main(args): + paddle.enable_static() + place = paddle.set_device('ipu') + set_seed(args.seed) + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + + # The sharding of encoder layers + if args.num_hidden_layers == 12: + attn_index = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] + ff_index = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] + else: + raise Exception("Only support num_hidden_layers = 12") + + bert_config = { + k: getattr(args, k) + for k in IpuBertConfig._fields if hasattr(args, k) + } + bert_config['embeddings_scope'] = DeviceScope(0, 0, "Embedding") + bert_config['attn_scopes'] = [ + DeviceScope(attn_index[i], attn_index[i]) + for i in range(args.num_hidden_layers) + ] + bert_config['ff_scopes'] = [ + DeviceScope(ff_index[i], ff_index[i]) + for i in range(args.num_hidden_layers) + ] + bert_config['mlm_scope'] = DeviceScope(0, args.num_ipus, "MLM") + bert_config['nsp_scope'] = DeviceScope(0, args.num_ipus, "NSP") + bert_config['layers_per_ipu'] = [4, 4, 4] + + config = IpuBertConfig(**bert_config) + + # custom_ops + custom_ops = load_custom_ops() + + logging.info("Building Model") + + [ + indices, segments, positions, mask_tokens_mask_idx, sequence_mask_idx, + masked_lm_ids, next_sentence_labels + ] = create_data_holder(args) + + # Encoder Layers + bert_model = BertModel(config, custom_ops) + encoders, word_embedding = bert_model( + indices, segments, positions, + [mask_tokens_mask_idx, sequence_mask_idx]) + + # PretrainingHeads + mlm_heads = IpuBertPretrainingMLMHeads( + args.hidden_size, args.vocab_size, args.max_position_embeddings, + args.max_predictions_per_seq, args.seq_len) + nsp_heads = IpuBertPretrainingNSPHeads( + args.hidden_size, args.max_predictions_per_seq, args.seq_len) + + # AccAndLoss + nsp_criterion = IpuBertPretrainingNSPAccAndLoss( + args.micro_batch_size, args.ignore_index, custom_ops) + mlm_criterion = IpuBertPretrainingMLMAccAndLoss( + args.micro_batch_size, args.ignore_index, custom_ops) + + with config.nsp_scope: + nsp_out = nsp_heads(encoders) + nsp_acc, nsp_loss = nsp_criterion(nsp_out, next_sentence_labels) + + with config.mlm_scope: + mlm_out = mlm_heads(encoders, word_embedding) + mlm_acc, mlm_loss, = mlm_criterion(mlm_out, masked_lm_ids) + total_loss = mlm_loss + nsp_loss + + # lr_scheduler + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, args.max_steps, + args.warmup_steps) + # optimizer + optimizer = paddle.optimizer.Lamb( + learning_rate=lr_scheduler, + lamb_weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + epsilon=args.adam_epsilon) + optimizer.minimize(total_loss) + + # Static executor + exe = paddle.static.Executor(place) + exe.run(startup_program) + + # Set initial weights + state_dict = main_program.state_dict() + reset_state_dict = reset_program_state_dict(state_dict) + paddle.static.set_program_state(main_program, reset_state_dict) + + if args.enable_load_params: + logging.info(f'loading weights from: {args.load_params_path}') + if not args.load_params_path.endswith('pdparams'): + raise Exception('need pdparams file') + with open(args.load_params_path, 'rb') as file: + params = pickle.load(file) + paddle.static.set_program_state(main_program, params) + + # Create ipu_strategy + ipu_strategy = create_ipu_strategy(args) + + feed_list = [ + "indices", + "segments", + "positions", + "mask_tokens_mask_idx", + "sequence_mask_idx", + "masked_lm_ids", + "next_sentence_labels", + ] + fetch_list = [mlm_acc.name, mlm_loss.name, nsp_acc.name, nsp_loss.name] + + # Compile program for IPU + ipu_compiler = paddle.static.IpuCompiledProgram( + main_program, ipu_strategy=ipu_strategy) + logging.info(f'start compiling, please wait some minutes') + logging.info( + f'you can run `export POPART_LOG_LEVEL=INFO` before running program to see the compile progress' + ) + cur_time = time.time() + main_program = ipu_compiler.compile(feed_list, fetch_list) + time_cost = time.time() - cur_time + logging.info(f'finish compiling! time cost: {time_cost}') + + # Load the training dataset + input_files = [ + os.path.join(args.input_files, f) for f in os.listdir(args.input_files) + if os.path.isfile(os.path.join(args.input_files, f)) and "training" in f + ] + input_files.sort() + + dataset = PretrainingHDF5DataLoader( + input_files=input_files, + max_seq_length=args.seq_len, + max_mask_tokens=args.max_predictions_per_seq, + batch_size=args.batch_size, + shuffle=args.shuffle) + logging.info(f"dataset length: {len(dataset)}") + total_samples = dataset.total_samples + logging.info("total samples: %d, total batch_size: %d, max steps: %d" % + (total_samples, args.batch_size, args.max_steps)) + + batch_start = time.time() + global_step = 0 + for batch in dataset: + global_step += 1 + epoch = global_step * args.batch_size // total_samples + read_cost = time.time() - batch_start + + feed = { + "indices": batch[0], + "segments": batch[1], + "positions": batch[2], + "mask_tokens_mask_idx": batch[3], + "sequence_mask_idx": batch[4], + "masked_lm_ids": batch[5], + "next_sentence_labels": batch[6], + } + lr_scheduler.step() + + train_start = time.time() + loss_return = exe.run(main_program, + feed=feed, + fetch_list=fetch_list, + use_program_cache=True) + train_cost = time.time() - train_start + total_cost = time.time() - batch_start + tput = args.batch_size / total_cost + + if args.wandb: + wandb.log({ + "epoch": epoch, + "global_step": global_step, + "loss/MLM": np.mean(loss_return[1]), + "loss/NSP": np.mean(loss_return[3]), + "accuracy/MLM": np.mean(loss_return[0]), + "accuracy/NSP": np.mean(loss_return[2]), + "latency/read": read_cost, + "latency/train": train_cost, + "latency/e2e": total_cost, + "throughput": tput, + "learning_rate": lr_scheduler(), + }) + + if global_step % args.logging_steps == 0: + logging.info({ + "epoch": epoch, + "global_step": global_step, + "loss/MLM": np.mean(loss_return[1]), + "loss/NSP": np.mean(loss_return[3]), + "accuracy/MLM": np.mean(loss_return[0]), + "accuracy/NSP": np.mean(loss_return[2]), + "latency/read": read_cost, + "latency/train": train_cost, + "latency/e2e": total_cost, + "throughput": tput, + "learning_rate": lr_scheduler(), + }) + + if global_step % args.save_steps == 0: + ipu_compiler._backend.weights_to_host() + paddle.static.save(main_program.org_program, + os.path.join(args.output_dir, + 'step_{}'.format(global_step))) + + if global_step >= args.max_steps: + ipu_compiler._backend.weights_to_host() + paddle.static.save( + main_program.org_program, + os.path.join(args.output_dir, + 'final_step_{}'.format(global_step))) + dataset.release() + del dataset + return + + batch_start = time.time() + + +if __name__ == "__main__": + args = parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt='%Y-%m-%d %H:%M:%S %a') + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + + if args.wandb: + import wandb + wandb.init( + project="paddle-base-bert", + settings=wandb.Settings(console='off'), + name='paddle-base-bert') + wandb_config = vars(args) + wandb_config["global_batch_size"] = args.batch_size + wandb.config.update(args) + + logging.info(args) + main(args) + logging.info("program finished") diff --git a/examples/language_model/bert/static_ipu/run_pretrain.sh b/examples/language_model/bert/static_ipu/run_pretrain.sh new file mode 100755 index 000000000000..cd1c5bb00f40 --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_pretrain.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +export RDMAV_FORK_SAFE=1 +python3 run_pretrain.py \ + --input_files "path_to_phase1_hdf5_dataset" \ + --output_dir pretrain_128_model \ + --seq_len 128 \ + --hidden_size 768 \ + --vocab_size 30400 \ + --max_predictions_per_seq 20 \ + --max_position_embeddings 512 \ + --learning_rate 0.006 \ + --weight_decay 1e-2 \ + --max_steps 7038 \ + --warmup_steps 2000 \ + --logging_steps 10 \ + --seed 1984 \ + --beta1 0.9 \ + --beta2 0.999 \ + --num_hidden_layers 12 \ + --micro_batch_size 32 \ + --ipu_enable_fp16 True \ + --scale_loss 512 \ + --batches_per_step 1 \ + --num_replica 4 \ + --enable_grad_acc True \ + --grad_acc_factor 512 \ + --batch_size 65536 \ + --available_mem_proportion 0.28 \ + --ignore_index 0 \ + --enable_load_params False \ + --hidden_dropout_prob 0.1 \ + --attention_probs_dropout_prob 0.1 \ + --shuffle True \ + --wandb False \ + --save_steps 1000 diff --git a/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh b/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh new file mode 100755 index 000000000000..8458ed48b6b2 --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +export RDMAV_FORK_SAFE=1 +python3 run_pretrain.py \ + --input_files "path_to_phase2_hdf5_dataset" \ + --output_dir pretrain_384_model \ + --seq_len 384 \ + --hidden_size 768 \ + --vocab_size 30400 \ + --max_predictions_per_seq 56 \ + --max_position_embeddings 512 \ + --learning_rate 0.002828427125 \ + --weight_decay 1e-2 \ + --max_steps 2137 \ + --warmup_steps 274 \ + --logging_steps 10 \ + --seed 1984 \ + --beta1 0.9 \ + --beta2 0.999 \ + --num_hidden_layers 12 \ + --micro_batch_size 8 \ + --ipu_enable_fp16 True \ + --scale_loss 128 \ + --batches_per_step 1 \ + --num_replica 4 \ + --enable_grad_acc True \ + --grad_acc_factor 512 \ + --batch_size 16384 \ + --available_mem_proportion 0.28 \ + --ignore_index 0 \ + --enable_load_params True \ + --load_params_path "./pretrain_128_model/final_step_7038.pdparams" \ + --hidden_dropout_prob 0.1 \ + --attention_probs_dropout_prob 0.1 \ + --shuffle True \ + --wandb False \ + --enable_engine_caching False \ + --save_steps 500 diff --git a/examples/language_model/bert/static_ipu/run_squad.py b/examples/language_model/bert/static_ipu/run_squad.py new file mode 100644 index 000000000000..bf2f37044d1c --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_squad.py @@ -0,0 +1,519 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import pickle +import time +from functools import partial + +import numpy as np +import paddle +import paddle.optimizer +import paddle.static +from datasets import load_dataset +from paddle.io import BatchSampler, DataLoader +from paddlenlp.data import Dict, Stack +from paddlenlp.metrics.squad import compute_prediction, squad_evaluate +from paddlenlp.transformers import BertTokenizer, LinearDecayWithWarmup + +from modeling import (BertModel, DeviceScope, IpuBertConfig, + IpuBertForQuestionAnswering, IpuBertQAAccAndLoss) +from run_pretrain import (create_ipu_strategy, reset_program_state_dict, + set_seed) +from utils import load_custom_ops, parse_args + + +def create_data_holder(args): + bs = args.micro_batch_size + indices = paddle.static.data( + name="indices", shape=[bs * args.seq_len], dtype="int32") + segments = paddle.static.data( + name="segments", shape=[bs * args.seq_len], dtype="int32") + positions = paddle.static.data( + name="positions", shape=[bs * args.seq_len], dtype="int32") + input_mask = paddle.static.data( + name="input_mask", shape=[bs, 1, 1, args.seq_len], dtype="float32") + if not args.is_training: + return [indices, segments, positions, input_mask] + else: + start_labels = paddle.static.data( + name="start_labels", shape=[bs], dtype="int32") + end_labels = paddle.static.data( + name="end_labels", shape=[bs], dtype="int32") + return [ + indices, segments, positions, input_mask, start_labels, end_labels + ] + + +def prepare_train_features(examples, tokenizer, args): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + contexts = examples['context'] + questions = examples['question'] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + questions, + contexts, + stride=128, + max_seq_len=args.seq_len, + pad_to_max_seq_len=True, + return_position_ids=True, + return_token_type_ids=True, + return_attention_mask=True, + return_length=True) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + tokenized_examples["input_mask"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + sequence_ids = tokenized_examples['token_type_ids'][i] + + # attention_mask to input_mask + input_mask = ( + np.asarray(tokenized_examples["attention_mask"][i]) - 1) * 1e3 + input_mask = np.expand_dims(input_mask, axis=(0, 1)) + if args.ipu_enable_fp16: + input_mask = input_mask.astype(np.float16) + else: + input_mask = input_mask.astype(np.float32) + tokenized_examples["input_mask"].append(input_mask) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples['answers'][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != 1: + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != 1: + token_end_index -= 1 + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and + offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[ + token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - + 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + +def prepare_validation_features(examples, tokenizer, args): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + #NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is + # that HugggingFace uses ArrowTable as basic data structure, while we use list of dictionary instead. + contexts = examples['context'] + questions = examples['question'] + tokenized_examples = tokenizer( + questions, + contexts, + stride=128, + max_seq_len=args.seq_len, + pad_to_max_seq_len=True, + return_position_ids=True, + return_token_type_ids=True, + return_attention_mask=True, + return_length=True) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + tokenized_examples["input_mask"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + input_ids = tokenized_examples["input_ids"][i] + sequence_A_lengths = input_ids.index(tokenizer.sep_token_id) + 2 + sequence_B_lengths = len(input_ids) - sequence_A_lengths + sequence_ids = [0] * sequence_A_lengths + [1] * sequence_B_lengths + context_index = 1 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + # attention_mask to input_mask + input_mask = ( + np.asarray(tokenized_examples["attention_mask"][i]) - 1) * 1e3 + input_mask = np.expand_dims(input_mask, axis=(0, 1)) + if args.ipu_enable_fp16: + input_mask = input_mask.astype(np.float16) + else: + input_mask = input_mask.astype(np.float32) + tokenized_examples["input_mask"].append(input_mask) + + return tokenized_examples + + +def load_squad_dataset(args): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + features_fn = prepare_train_features if args.is_training else prepare_validation_features + if args.is_training: + raw_dataset = load_dataset('squad', split='train') + else: + raw_dataset = load_dataset('squad', split='validation') + column_names = raw_dataset.column_names + dataset = raw_dataset.map(partial( + features_fn, tokenizer=tokenizer, args=args), + batched=True, + remove_columns=column_names, + num_proc=4) + + bs = args.micro_batch_size * args.grad_acc_factor * args.batches_per_step * args.num_replica + args.batch_size = bs + if args.is_training: + train_batch_sampler = BatchSampler( + dataset, batch_size=bs, shuffle=args.shuffle, drop_last=True) + else: + train_batch_sampler = BatchSampler( + dataset, batch_size=bs, shuffle=args.shuffle, drop_last=False) + + if args.is_training: + collate_fn = lambda samples, fn=Dict({ + "input_ids": Stack(), + "token_type_ids": Stack(), + "position_ids": Stack(), + "input_mask": Stack(), + "start_positions": Stack(), + "end_positions": Stack() + }): fn(samples) + else: + collate_fn = lambda samples, fn=Dict({ + "input_ids": Stack(), + "token_type_ids": Stack(), + "position_ids": Stack(), + "input_mask": Stack()}): fn(samples) + + data_loader = DataLoader( + dataset=dataset, + batch_sampler=train_batch_sampler, + collate_fn=collate_fn, + return_list=True) + return raw_dataset, data_loader + + +def main(args): + paddle.enable_static() + place = paddle.set_device('ipu') + set_seed(args.seed) + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + + # The sharding of encoder layers + if args.num_hidden_layers == 12: + attn_ipu_index = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1] + ff_ipu_index = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1] + else: + raise Exception("Only support num_hidden_layers = 12") + + bert_config = { + k: getattr(args, k) + for k in IpuBertConfig._fields if hasattr(args, k) + } + bert_config['embeddings_scope'] = DeviceScope(0, 0, "Embedding") + bert_config['attn_scopes'] = [ + DeviceScope(attn_ipu_index[i], attn_ipu_index[i]) + for i in range(args.num_hidden_layers) + ] + bert_config['ff_scopes'] = [ + DeviceScope(ff_ipu_index[i], ff_ipu_index[i]) + for i in range(args.num_hidden_layers) + ] + bert_config['layers_per_ipu'] = [6, 6] + + config = IpuBertConfig(**bert_config) + + # custom_ops + custom_ops = load_custom_ops() + + logging.info("building model") + + if args.is_training: + [indices, segments, positions, input_mask, start_labels, + end_labels] = create_data_holder(args) + else: + [indices, segments, positions, input_mask] = create_data_holder(args) + + # Encoder Layers + bert_model = BertModel(config, custom_ops) + encoders, _ = bert_model(indices, segments, positions, input_mask) + + squad_scope = DeviceScope(args.num_ipus - 1, args.num_ipus - 1, "squad") + with squad_scope: + qa_cls = IpuBertForQuestionAnswering(args.hidden_size, args.seq_len) + start_logits, end_logits = qa_cls(encoders) + + if args.is_training: + acc_loss = IpuBertQAAccAndLoss(custom_ops) + acc0, acc1, loss = acc_loss(start_logits, end_logits, start_labels, + end_labels) + + # load squad dataset + raw_dataset, data_loader = load_squad_dataset(args) + + total_samples = len(data_loader.dataset) + max_steps = total_samples // args.batch_size * args.epochs + logging.info("total samples: %d, total batch_size: %d, max steps: %d" % + (total_samples, args.batch_size, max_steps)) + + if args.is_training: + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps, + args.warmup_steps) + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + epsilon=args.adam_epsilon) + optimizer.minimize(loss) + + # Static executor + exe = paddle.static.Executor(place) + exe.run(startup_program) + + # Set initial weights + state_dict = main_program.state_dict() + reset_state_dict = reset_program_state_dict(state_dict) + paddle.static.set_program_state(main_program, reset_state_dict) + + if args.enable_load_params: + logging.info(f'loading weights from: {args.load_params_path}') + if not args.load_params_path.endswith('pdparams'): + raise Exception('need pdparams file') + with open(args.load_params_path, 'rb') as file: + params = pickle.load(file) + # Delete mlm and nsp weights + if args.is_training and 'linear_72.w_0' in params: + params.pop("linear_72.w_0") + params.pop("linear_72.b_0") + paddle.static.set_program_state(main_program, params) + + if args.tf_checkpoint: + from load_tf_ckpt import load_initializers_from_tf + logging.info(f'loading weights from: {args.tf_checkpoint}') + initializers, _ = load_initializers_from_tf(args.tf_checkpoint, args) + paddle.static.set_program_state(main_program, initializers) + + # Create ipu_strategy + ipu_strategy = create_ipu_strategy(args) + + if args.is_training: + feed_list = [ + "indices", "segments", "positions", "input_mask", "start_labels", + "end_labels" + ] + fetch_list = [loss.name, acc0.name, acc1.name] + else: + feed_list = ["indices", "segments", "positions", "input_mask"] + fetch_list = [start_logits.name, end_logits.name] + + ipu_compiler = paddle.static.IpuCompiledProgram( + main_program, ipu_strategy=ipu_strategy) + logging.info(f'start compiling, please wait some minutes') + logging.info( + f'you can run `export POPART_LOG_LEVEL=INFO` before running program to see the compile progress' + ) + cur_time = time.time() + main_program = ipu_compiler.compile(feed_list, fetch_list) + time_cost = time.time() - cur_time + logging.info(f'finish compiling! time cost: {time_cost}') + + if args.is_training: + global_step = 0 + batch_start = time.time() + for epoch in range(args.epochs): + for batch in data_loader: + global_step += 1 + + feed = { + "indices": batch[0], + "segments": batch[1], + "positions": batch[2], + "input_mask": batch[3], + "start_labels": batch[4], + "end_labels": batch[5], + } + lr_scheduler.step() + + train_start = time.time() + outputs = exe.run(main_program, + feed=feed, + fetch_list=fetch_list, + use_program_cache=True) + train_cost = time.time() - train_start + total_cost = time.time() - batch_start + + tput = args.batch_size / total_cost + if args.wandb: + wandb.log({ + "epoch": epoch, + "global_step": global_step, + "loss": np.mean(outputs[0]), + "accuracy": np.mean(outputs[1:]), + "train_cost": train_cost, + "total_cost": total_cost, + "throughput": tput, + "learning_rate": lr_scheduler(), + }) + + if global_step % args.logging_steps == 0: + logging.info({ + "epoch": epoch, + "global_step": global_step, + "loss": np.mean(outputs[0]), + "accuracy": np.mean(outputs[1:]), + "train_cost": train_cost, + "total_cost": total_cost, + "throughput": tput, + "learning_rate": lr_scheduler(), + }) + + batch_start = time.time() + + # save final state + ipu_compiler._backend.weights_to_host() + paddle.static.save(main_program.org_program, + os.path.join(args.output_dir, 'Final_model')) + + if not args.is_training: + all_start_logits = [] + all_end_logits = [] + for step, batch in enumerate(data_loader): + if step % args.logging_steps == 0: + logging.info(f'running step: {step}') + + real_len = np.array(batch[0]).shape[0] + # padding zeros if needed + if real_len < args.batch_size: + batch = [np.asarray(x) for x in batch] + pad0 = np.zeros([args.batch_size - real_len, + args.seq_len]).astype(batch[0].dtype) + batch[0] = np.vstack((batch[0], pad0)) + batch[1] = np.vstack((batch[1], pad0)) + batch[2] = np.vstack((batch[2], pad0)) + pad1 = np.zeros( + [args.batch_size - real_len, 1, 1, args.seq_len]) - 1e3 + pad1 = pad1.astype(batch[3].dtype) + batch[3] = np.vstack((batch[3], pad1)) + + feed = { + "indices": batch[0], + "segments": batch[1], + "positions": batch[2], + "input_mask": batch[3], + } + start_logits, end_logits = exe.run(main_program, + feed=feed, + fetch_list=fetch_list) + + start_logits = start_logits.reshape([-1, args.seq_len]) + end_logits = end_logits.reshape([-1, args.seq_len]) + for idx in range(real_len): + all_start_logits.append(start_logits[idx]) + all_end_logits.append(end_logits[idx]) + + # evaluate results + all_predictions, all_nbest_json, scores_diff_json = compute_prediction( + raw_dataset, data_loader.dataset, + (all_start_logits, all_end_logits)) + squad_evaluate( + examples=[raw_data for raw_data in raw_dataset], + preds=all_predictions, + na_probs=scores_diff_json) + # write results to file + with open('squad_prediction.json', "w", encoding='utf-8') as writer: + writer.write( + json.dumps( + all_predictions, ensure_ascii=False, indent=4) + "\n") + + +if __name__ == "__main__": + args = parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt='%Y-%m-%d %H:%M:%S %a') + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + + if args.wandb: + import wandb + wandb.init( + project="paddle-squad", + settings=wandb.Settings(console='off'), + name='paddle-squad') + wandb_config = vars(args) + wandb_config["global_batch_size"] = args.batch_size + wandb.config.update(args) + + logging.info(args) + main(args) + logging.info("program finished") diff --git a/examples/language_model/bert/static_ipu/run_squad.sh b/examples/language_model/bert/static_ipu/run_squad.sh new file mode 100755 index 000000000000..4c36ef69d6b1 --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_squad.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +python3 run_squad.py \ + --output_dir squad_model \ + --task "SQUAD" \ + --is_training True \ + --seq_len 384 \ + --hidden_size 768 \ + --vocab_size 30400 \ + --max_predictions_per_seq 56 \ + --max_position_embeddings 512 \ + --learning_rate 5.6e-05 \ + --weight_decay 0 \ + --epochs 4 \ + --warmup_steps 52 \ + --logging_steps 10 \ + --seed 42 \ + --beta1 0.9 \ + --beta2 0.999 \ + --num_hidden_layers 12 \ + --micro_batch_size 2 \ + --ipu_enable_fp16 True \ + --accl1_type "FLOAT" \ + --accl2_type "FLOAT" \ + --weight_decay_mode "decay" \ + --scale_loss 256 \ + --optimizer_state_offchip False \ + --batches_per_step 4 \ + --num_replica 4 \ + --num_ipus 2 \ + --enable_grad_acc True \ + --grad_acc_factor 16 \ + --available_mem_proportion 0.40 \ + --ignore_index 0 \ + --hidden_dropout_prob 0.1 \ + --attention_probs_dropout_prob 0.1 \ + --shuffle True \ + --wandb False \ + --enable_engine_caching False \ + --enable_load_params True \ + --load_params_path "pretrain_384_model/final_step_2137.pdparams" diff --git a/examples/language_model/bert/static_ipu/run_squad_infer.sh b/examples/language_model/bert/static_ipu/run_squad_infer.sh new file mode 100755 index 000000000000..28ffa7285443 --- /dev/null +++ b/examples/language_model/bert/static_ipu/run_squad_infer.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +python3 run_squad.py \ + --output_dir squad_model \ + --task "SQUAD" \ + --is_training False \ + --seq_len 384 \ + --hidden_size 768 \ + --vocab_size 30400 \ + --max_predictions_per_seq 56 \ + --max_position_embeddings 512 \ + --learning_rate 5.6e-05 \ + --weight_decay 1e-2 \ + --epochs 4 \ + --warmup_steps 52 \ + --logging_steps 10 \ + --seed 1984 \ + --beta1 0.9 \ + --beta2 0.999 \ + --num_hidden_layers 12 \ + --micro_batch_size 2 \ + --ipu_enable_fp16 True \ + --scale_loss 256 \ + --optimizer_state_offchip False \ + --batches_per_step 4 \ + --num_replica 4 \ + --num_ipus 2 \ + --enable_grad_acc False \ + --grad_acc_factor 1 \ + --available_mem_proportion 0.40 \ + --ignore_index 0 \ + --hidden_dropout_prob 0.0 \ + --attention_probs_dropout_prob 0.0 \ + --shuffle False \ + --wandb False \ + --enable_engine_caching False \ + --enable_load_params True \ + --load_params_path "squad_model/Final_model.pdparams" diff --git a/examples/language_model/bert/static_ipu/utils.py b/examples/language_model/bert/static_ipu/utils.py new file mode 100644 index 000000000000..3e444a61b7df --- /dev/null +++ b/examples/language_model/bert/static_ipu/utils.py @@ -0,0 +1,251 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from distutils.util import strtobool +from paddle.utils.cpp_extension import load + + +def load_custom_ops(): + cur_dir = os.path.dirname(os.path.realpath(__file__)) + custom_dir = cur_dir + "/custom_ops" + sources = [ + f"{custom_dir}/custom_shape_infer.cc", + f"{custom_dir}/custom_checkpointoutput.cc", + f"{custom_dir}/custom_detach.cc", f"{custom_dir}/custom_identity.cc", + f"{custom_dir}/custom_nll_loss.cc", + f"{custom_dir}/tied_gather_pattern.cc", f"{custom_dir}/tied_gather.cc", + f"{custom_dir}/disable_attn_dropout_bwd_pattern.cc", + f"{custom_dir}/workarounds/prevent_const_expr_folding_op.cc", + f"{custom_dir}/utils.cc" + ] + custom_ops = load( + name="custom_ops", + sources=sources, + extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'], + build_directory=custom_dir, ) + return custom_ops + + +def str_to_bool(val): + return bool(strtobool(val)) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + type=str, + default="PRETRAINING", + help="task", ) + parser.add_argument( + "--input_files", + type=str, + default="", + help="Files to load data from. " + "For Pretraining: Path to tfrecord files" + "For SQuAD: Path to train-v1.1.json") + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=False, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--is_training", + type=str_to_bool, + default=True, + help="training or inference") + # graph + parser.add_argument( + "--seq_len", default=128, type=int, help="The sequence length") + parser.add_argument( + "--vocab_size", + default=30912, + type=int, + help="Set the size of the vocabulary") + parser.add_argument( + "--max_predictions_per_seq", + default=20, + type=int, + help="The maximum total of masked tokens in input sequence") + parser.add_argument( + "--max_position_embeddings", + default=512, + type=int, + help="the length of the input mask") + parser.add_argument( + "--num_hidden_layers", + type=int, + default=None, + help="Override config file if not None") + parser.add_argument( + "--hidden_size", + default=768, + type=int, + help="Set the size of the hidden state of the transformer layers size") + parser.add_argument( + "--ignore_index", type=int, default=-1, help="ignore mlm index") + parser.add_argument( + "--hidden_dropout_prob", + type=float, + default=0.1, + help="Set the layer dropout probability for fully connected layer in embedding and encoder", + ) + parser.add_argument( + "--attention_probs_dropout_prob", + type=float, + default=0.0, + help="Set the layer dropout probability for attention layer in encoder", + ) + # optimizer + parser.add_argument( + "--learning_rate", + default=5e-5, + type=float, + help="The initial learning rate.") + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="Set the Adam/Lamb beta1 value") + parser.add_argument( + "--beta2", + type=float, + default=0.999, + help="Set the Adam/Lamb beta2 value") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--warmup_steps", + default=10, + type=int, + help="Linear warmup over warmup_steps.") + parser.add_argument( + "--scale_loss", + type=float, + default=1.0, + help="The value of scale_loss for fp16.") + parser.add_argument( + "--accl1_type", type=str, default='FLOAT', help="FLOAT or FLOAT16") + parser.add_argument( + "--accl2_type", type=str, default='FLOAT', help="FLOAT or FLOAT16") + parser.add_argument( + "--weight_decay_mode", + type=str, + default='', + help="decay or l2_regularization") + parser.add_argument( + "--optimizer_state_offchip", + type=str_to_bool, + default=True, + help="Set the store location of the optimizer tensors") + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save checkpoint every X updates steps.") + # ipu + parser.add_argument( + "--epochs", + type=int, + default=1, + help="the iteration of the whole dataset", ) + parser.add_argument( + "--batch_size", + default=8, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--micro_batch_size", type=int, default=1, help="micro batch size") + parser.add_argument( + "--batches_per_step", type=int, default=1, help="batches per step") + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for initialization") + parser.add_argument( + "--num_ipus", type=int, default=4, help="Number of IPUs to use") + parser.add_argument( + "--ipu_enable_fp16", + type=str_to_bool, + default=False, + help="ipu enable fp16 or not.") + parser.add_argument( + "--num_replica", type=int, default=1, help="number of replica") + parser.add_argument( + "--enable_grad_acc", + type=str_to_bool, + default=False, + help="enable gradient accumulation") + parser.add_argument( + "--grad_acc_factor", + type=int, + default=1, + help="factor of gradient accumulation") + parser.add_argument( + "--available_mem_proportion", + type=float, + default=0.0, + help="set the available memory proportion for matmul/conv") + parser.add_argument( + "--shuffle", + type=str_to_bool, + nargs="?", + const=True, + default=False, + help="Shuffle Dataset") + parser.add_argument( + "--wandb", + type=str_to_bool, + nargs="?", + const=True, + default=False, + help="Enable logging to Weights and Biases.") + parser.add_argument( + "--enable_load_params", + type=str_to_bool, + default=False, + help="load params or not") + parser.add_argument("--load_params_path", type=str, help="load params path") + parser.add_argument( + "--tf_checkpoint", + type=str, + help="Path to Tensorflow Checkpoint to initialise the model.") + parser.add_argument( + "--enable_engine_caching", + type=str_to_bool, + default=True, + help="enable engine caching or not") + args = parser.parse_args() + return args