From f8ae4ba0a0e8a48c5ee4b4d62ed5063f7a18e38f Mon Sep 17 00:00:00 2001 From: "Anastasia P. Tsitiridou" Date: Thu, 16 Feb 2023 18:29:32 +0100 Subject: [PATCH] gptj config changes to enable finetuning of gpt-j-6B and gpt-j-xl (#3785) * sync * clear notebook outputs and linting --- .../gpt-j/01_train_gptj_smp_notebook.ipynb | 29 + ...in_gptj_smp_tensor_parallel_notebook.ipynb | 499 +++++++++++++++++- .../pytorch/model_parallel/gpt-j/args.py | 28 +- .../model_parallel/gpt-j/data_pipeline.py | 28 +- .../model_parallel/gpt-j/learning_rates.py | 18 +- .../model_parallel/gpt-j/memory_tracker.py | 50 +- .../model_parallel/gpt-j/preprocess.py | 12 +- .../gpt-j/sharded_data_parallel_checkpoint.py | 98 ++-- .../model_parallel/gpt-j/smp_trainer.py | 4 +- .../gpt-j/train_gptj_smp_script.py | 30 +- .../train_gptj_smp_tensor_parallel_script.py | 229 +++++--- 11 files changed, 868 insertions(+), 157 deletions(-) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb b/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb index 4b4a0a61bc..d28e660859 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/01_train_gptj_smp_notebook.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "cf87f233", "metadata": {}, "source": [ "# Train EleutherAI GPT-J with PyTorch 1.8.1 and Pipeline Parallelism Using the SageMaker Model Parallelism Library\n", @@ -31,6 +32,7 @@ }, { "cell_type": "markdown", + "id": "77eb83a9", "metadata": {}, "source": [ "## SageMaker Distributed Training \n", @@ -67,6 +69,7 @@ }, { "cell_type": "markdown", + "id": "7aa9251a", "metadata": {}, "source": [ "### SageMaker Model Parallel configuration\n", @@ -90,6 +93,7 @@ }, { "cell_type": "markdown", + "id": "2b1f0327", "metadata": {}, "source": [ "#### Additional Resources\n", @@ -104,6 +108,7 @@ }, { "cell_type": "markdown", + "id": "f571615e", "metadata": {}, "source": [ "#### Amazon SageMaker Initialization\n", @@ -117,6 +122,7 @@ { "cell_type": "code", "execution_count": null, + "id": "95aaf0c2", "metadata": {}, "outputs": [], "source": [ @@ -126,6 +132,7 @@ { "cell_type": "code", "execution_count": null, + "id": "4408ceae", "metadata": { "scrolled": true }, @@ -137,6 +144,7 @@ { "cell_type": "code", "execution_count": null, + "id": "ebe376a8", "metadata": {}, "outputs": [], "source": [ @@ -146,6 +154,7 @@ { "cell_type": "code", "execution_count": null, + "id": "02a7d9e3", "metadata": { "scrolled": true }, @@ -157,6 +166,7 @@ { "cell_type": "code", "execution_count": null, + "id": "64d2c112", "metadata": {}, "outputs": [], "source": [ @@ -177,6 +187,7 @@ { "cell_type": "code", "execution_count": null, + "id": "452456a3", "metadata": {}, "outputs": [], "source": [ @@ -216,6 +227,7 @@ { "cell_type": "code", "execution_count": null, + "id": "129a6da2", "metadata": {}, "outputs": [], "source": [ @@ -224,6 +236,7 @@ }, { "cell_type": "markdown", + "id": "bebdc6e9", "metadata": {}, "source": [ "## Training Dataset" @@ -231,6 +244,7 @@ }, { "cell_type": "markdown", + "id": "02d2ac3a", "metadata": {}, "source": [ "The training script fine-tunes GPT-J on the `sst2` dataset. \n", @@ -242,6 +256,7 @@ }, { "cell_type": "markdown", + "id": "d676e00b", "metadata": {}, "source": [ "## Setup Hyperparameters\n", @@ -254,6 +269,7 @@ { "cell_type": "code", "execution_count": null, + "id": "175f94c2", "metadata": {}, "outputs": [], "source": [ @@ -264,6 +280,7 @@ { "cell_type": "code", "execution_count": null, + "id": "695b8ab3", "metadata": {}, "outputs": [], "source": [ @@ -295,6 +312,7 @@ { "cell_type": "code", "execution_count": null, + "id": "1fcd2760", "metadata": {}, "outputs": [], "source": [ @@ -314,6 +332,7 @@ { "cell_type": "code", "execution_count": null, + "id": "efd36f09", "metadata": {}, "outputs": [], "source": [ @@ -323,6 +342,7 @@ { "cell_type": "code", "execution_count": null, + "id": "e2f03456", "metadata": {}, "outputs": [], "source": [ @@ -364,6 +384,7 @@ }, { "cell_type": "markdown", + "id": "7ef20b10", "metadata": {}, "source": [ "## Setup SageMaker Training Job" @@ -372,6 +393,7 @@ { "cell_type": "code", "execution_count": null, + "id": "c43285cb", "metadata": {}, "outputs": [], "source": [ @@ -387,6 +409,7 @@ { "cell_type": "code", "execution_count": null, + "id": "bb9616e3", "metadata": {}, "outputs": [], "source": [ @@ -419,6 +442,7 @@ { "cell_type": "code", "execution_count": null, + "id": "f4cabfd4", "metadata": {}, "outputs": [], "source": [ @@ -431,6 +455,7 @@ { "cell_type": "code", "execution_count": null, + "id": "b9f05c07", "metadata": {}, "outputs": [], "source": [ @@ -459,6 +484,7 @@ }, { "cell_type": "markdown", + "id": "45eb0cde", "metadata": {}, "source": [ "If you receive a `ResourceLimitExceeded` error message when running the following cell, you can request an increase on the default quota by contacting [AWS support](https://console.aws.amazon.com/support). Open the [AWS Support Center](https://console.aws.amazon.com/support), and then choose Create case. Choose Service limit increase. For Limit Type choose SageMaker Training Jobs. Complete the rest of the form and submit." @@ -467,6 +493,7 @@ { "cell_type": "code", "execution_count": null, + "id": "2601cc8a", "metadata": { "scrolled": true }, @@ -484,6 +511,7 @@ }, { "cell_type": "markdown", + "id": "31b51fd8", "metadata": {}, "source": [ "## Accessing the Training Logs\n", @@ -511,6 +539,7 @@ { "cell_type": "code", "execution_count": null, + "id": "3de8f1d2", "metadata": {}, "outputs": [], "source": [] diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb b/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb index ebed139ea1..a0b51c54dd 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb @@ -29,7 +29,7 @@ "\n", "This notebook depends on the following files and folders:\n", "\n", - "1. `train_gptj_smp_tensor_parallel_script.py`: This is an entrypoint script that is passed to the PyTorch estimator in the notebook instructions. This script is responsible for end to end training of the GPT-J model with SMP. The script has additional comments at places where the SMP API is used.\n", + "1. `train_gptj_smp_tensor_parallel_script.py`: This is an entry-point script that is passed to the PyTorch estimator in the notebook instructions. This script is responsible for end to end training of the GPT-J model with SMP. The script has additional comments at places where the SMP API is used.\n", "\n", "3. `learning_rates.py`: This contains the functions for learning rate schedule.\n", "4. `requirements.txt`: This will install the dependencies, like the right version of huggingface transformers.\n", @@ -629,7 +629,7 @@ "Amazon FSx for Lustre is a high performance file system optimized for workloads, such as machine learning, analytics and high performance computing. With Amazon FSx for Lustre, you can accelerate your File mode training jobs by avoiding the initial Amazon S3 download time.\n", "\n", "\n", - "Please see the instructions at [Distributed Training of Mask-RCNN in Amazon SageMaker using FSx](https://github.com/aws/amazon-sagemaker-examples/blob/master/advanced_functionality/distributed_tensorflow_mask_rcnn/mask-rcnn-scriptmode-fsx.ipynb), to create the an Amazon FSx Lustre file-system and import data from the S3 bucket to your FSx file system. Note that the FSx must be created in a private subnet with internet gateway to ensure that training job has access to the internet. " + "Please see the instructions at [Distributed Training of Mask-RCNN in Amazon SageMaker using FSx](https://github.com/aws/amazon-sagemaker-examples/blob/master/advanced_functionality/distributed_tensorflow_mask_rcnn/mask-rcnn-scriptmode-fsx.ipynb), to create an Amazon FSx Lustre file-system and import data from the S3 bucket to your FSx file system. Note that the FSx must be created in a private subnet with internet gateway to ensure that training job has access to the internet. " ] }, { @@ -888,7 +888,12 @@ "metadata": {}, "outputs": [], "source": [ - "if instance_type in [\"ml.p3.16xlarge\", \"ml.p3dn.24xlarge\", \"ml.g5.48xlarge\", \"ml.p4d.24xlarge\"]:\n", + "if instance_type in [\n", + " \"ml.p3.16xlarge\",\n", + " \"ml.p3dn.24xlarge\",\n", + " \"ml.g5.48xlarge\",\n", + " \"ml.p4d.24xlarge\",\n", + "]:\n", " processes_per_host = 8\n", "elif instance_type == \"ml.p2.16xlarge\":\n", " processes_per_host = 16\n", @@ -1231,6 +1236,494 @@ } ], "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + } + ], "hide_input": false, "instance_type": "ml.m5.large", "kernelspec": { diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/args.py b/training/distributed_training/pytorch/model_parallel/gpt-j/args.py index 223e7ac850..e0c4a912c4 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/args.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/args.py @@ -106,7 +106,9 @@ class CustomTrainingArguments(TrainingArguments): ) plateau: float = field( default=0.4, - metadata={"help": "Percentage of total iterations to keep at max if using plateau lr"}, + metadata={ + "help": "Percentage of total iterations to keep at max if using plateau lr" + }, ) @@ -139,11 +141,15 @@ class ModelArguments: ) config_name: Optional[str] = field( default=None, - metadata={"help": "Pretrained config name or path if not the same as model_name"}, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, ) tokenizer_name: Optional[str] = field( default=None, - metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, ) cache_dir: Optional[str] = field( default=None, @@ -159,7 +165,9 @@ class ModelArguments: ) load_from_s3: bool = field( default=False, - metadata={"help": "Whether to load the model from a S3 location or from_pretrained."}, + metadata={ + "help": "Whether to load the model from a S3 location or from_pretrained." + }, ) model_revision: str = field( default="main", @@ -250,7 +258,9 @@ class DataTrainingArguments: default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}, ) - train_file: Optional[str] = field(default=None, metadata={"help": "Local path to train file."}) + train_file: Optional[str] = field( + default=None, metadata={"help": "Local path to train file."} + ) validation_file: Optional[str] = field( default=None, metadata={"help": "Local path to validation file."} ) @@ -268,7 +278,9 @@ class SMPArguments: microbatches: Optional[int] = field(default=1, metadata={"help": "Microbatches"}) - active_microbatches: Optional[int] = field(default=None, metadata={"help": "Microbatches"}) + active_microbatches: Optional[int] = field( + default=None, metadata={"help": "Microbatches"} + ) optimize: Optional[str] = field( default="speed", @@ -299,7 +311,9 @@ class SMPArguments: ) trace_device: Optional[str] = field( default="cpu", - metadata={"help": "The device ('cpu' or 'gpu') that you want load model to for tracing."}, + metadata={ + "help": "The device ('cpu' or 'gpu') that you want load model to for tracing." + }, ) match_weights: Optional[int] = field( default=0, diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/data_pipeline.py b/training/distributed_training/pytorch/model_parallel/gpt-j/data_pipeline.py index ce1858ae15..87da8b9d47 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/data_pipeline.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/data_pipeline.py @@ -53,7 +53,13 @@ def __getitem__(self, index): index = padded_mask_indices[0].item() masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] + return [ + input_ids, + segment_ids, + input_mask, + masked_lm_labels, + next_sentence_labels, + ] ###### Load GPT pretraining data ###### @@ -104,7 +110,9 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: self.actual_sequence_length = len(obj["input_ids"]) if self.actual_sequence_length > self.max_sequence_length: - s_idx = np.random.randint(0, self.actual_sequence_length - self.max_sequence_length) + s_idx = np.random.randint( + 0, self.actual_sequence_length - self.max_sequence_length + ) e_idx = s_idx + self.max_sequence_length iids = iids[s_idx:e_idx] attns = attns[s_idx:e_idx] @@ -154,7 +162,13 @@ def __getitem__(self, index): index = padded_mask_indices[0].item() masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] + return [ + input_ids, + segment_ids, + input_mask, + masked_lm_labels, + next_sentence_labels, + ] ###### Load Openwebtext pretraining data ###### @@ -205,7 +219,9 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: self.actual_sequence_length = len(obj["input_ids"]) if self.actual_sequence_length > self.max_sequence_length: - s_idx = np.random.randint(0, self.actual_sequence_length - self.max_sequence_length) + s_idx = np.random.randint( + 0, self.actual_sequence_length - self.max_sequence_length + ) e_idx = s_idx + self.max_sequence_length iids = iids[s_idx:e_idx] attns = attns[s_idx:e_idx] @@ -299,6 +315,8 @@ def create_pretraining_dataloader( else: data_len = smp.recv_from(0, smp.RankType.PP_RANK) dataset = DummyDataset(data_len * batch_size, data_type=data_type) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, drop_last=True + ) return dataloader diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/learning_rates.py b/training/distributed_training/pytorch/model_parallel/gpt-j/learning_rates.py index e1dabbca63..acb7b0f5fa 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/learning_rates.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/learning_rates.py @@ -81,7 +81,11 @@ def get_lr(self): / (self.end_iter - self.plateau_iter) ) elif self.decay_style == "cosine": - lr = self.start_lr / 2.0 * (math.cos(math.pi * num_iters_ / self.end_iter) + 1) + lr = ( + self.start_lr + / 2.0 + * (math.cos(math.pi * num_iters_ / self.end_iter) + 1) + ) elif self.decay_style == "exponential": # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) @@ -128,15 +132,21 @@ def _check_and_set(self, cls_value, sd_value, name): def load_state_dict(self, sd): - self.start_lr = self._check_and_set(self.start_lr, sd["start_lr"], "learning rate") - self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate") + self.start_lr = self._check_and_set( + self.start_lr, sd["start_lr"], "learning rate" + ) + self.min_lr = self._check_and_set( + self.min_lr, sd["min_lr"], "minimum learning rate" + ) self.warmup_iter = self._check_and_set( self.warmup_iter, sd["warmup_iter"], "warmup iterations" ) self.end_iter = self._check_and_set( self.end_iter, sd["end_iter"], "total number of iterations" ) - self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style") + self.decay_style = self._check_and_set( + self.decay_style, sd["decay_style"], "decay style" + ) self.num_iters = sd["num_iters"] self.step(self.num_iters) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py b/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py index 329926a26e..d0d84ae594 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/memory_tracker.py @@ -3,28 +3,30 @@ import smdistributed.modelparallel.torch as smp import torch + try: from py3nvml import py3nvml except ImportError: py3nvml = None dtype_to_bit = { -torch.float32 : 32, -torch.float64 : 64, -torch.float16: 16, -torch.bfloat16: 16, -torch.uint8: 8, -torch.int8: 8, -torch.int16: 16, -torch.int32: 32, -torch.int64: 64, -torch.bool: 1 + torch.float32: 32, + torch.float64: 64, + torch.float16: 16, + torch.bfloat16: 16, + torch.uint8: 8, + torch.int8: 8, + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, + torch.bool: 1, } process = psutil.Process(os.getpid()) base_mem_usage = process.memory_info().data last_mem_usage = base_mem_usage + def memory_status(msg="", reset_max=True, sync=True): rank = smp.rank() @@ -60,11 +62,11 @@ def memory_status(msg="", reset_max=True, sync=True): max_cached /= 1024**3 print( - f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', - f'device={local_rank} ' - f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} ' - f'cache {cached:0.4f} max_cached {max_cached:0.4f} ' - f'{total_used_str}' + f"[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}", + f"device={local_rank} " + f"alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} " + f"cache {cached:0.4f} max_cached {max_cached:0.4f} " + f"{total_used_str}", ) if reset_max: torch.cuda.reset_max_memory_cached() @@ -72,8 +74,10 @@ def memory_status(msg="", reset_max=True, sync=True): if py3nvml != None: py3nvml.nvmlShutdown() + def memory_status_cpu(msg=""): import gc + global last_mem_usage global base_mem_usage rdp_rank = smp.rdp_rank() @@ -81,12 +85,14 @@ def memory_status_cpu(msg=""): gc.collect() gc.collect() objects = gc.get_objects() - tensors = [obj for obj in objects if isinstance(obj, torch.Tensor) and not obj.is_cuda] + tensors = [ + obj for obj in objects if isinstance(obj, torch.Tensor) and not obj.is_cuda + ] torch_usage = 0 for t in tensors: torch_usage += t.numel() * dtype_to_bit[t.dtype] - #total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes - current_usage = process.memory_info().data + # total_usage = psutil.virtual_memory()[3] # This will get the total usage for all processes + current_usage = process.memory_info().data total_usage = current_usage - base_mem_usage usage_change = current_usage - last_mem_usage last_mem_usage = current_usage @@ -105,7 +111,7 @@ def memory_status_cpu(msg=""): return print( - f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}', - f'device={local_rank} ' - f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}' - ) \ No newline at end of file + f"[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}", + f"device={local_rank} " + f"torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}", + ) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/preprocess.py b/training/distributed_training/pytorch/model_parallel/gpt-j/preprocess.py index c84e3e3106..aca42ff1a0 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/preprocess.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/preprocess.py @@ -77,7 +77,9 @@ def datasets(model_args, data_args, training_args): } if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, **tokenizer_kwargs + ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, **tokenizer_kwargs @@ -94,7 +96,9 @@ def datasets(model_args, data_args, training_args): text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function - tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") + tok_logger = transformers.utils.logging.get_logger( + "transformers.tokenization_utils_base" + ) def tokenize_function(examples): @@ -178,7 +182,9 @@ def group_texts(examples): def main(): - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py b/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py index e9e7ebd79e..71d18a3b88 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/sharded_data_parallel_checkpoint.py @@ -7,20 +7,22 @@ from collections import OrderedDict # load to cpu -device = torch.device('cpu') +device = torch.device("cpu") smp_prefix = "module." + def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): - ''' + """ alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) - ''' - return [ atoi(c) for c in re.split(r'(\d+)', text) ] + """ + return [atoi(c) for c in re.split(r"(\d+)", text)] + def get_model_state_file(checkpoint_dir): if not os.path.isdir(checkpoint_dir): @@ -32,41 +34,50 @@ def get_model_state_file(checkpoint_dir): return file + def get_optim_files(checkpoint_dir): - optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "optimizer_*.pt")), key=natural_keys) + optim_files = sorted( + glob.glob(os.path.join(checkpoint_dir, "optimizer_*.pt")), key=natural_keys + ) if len(optim_files) == 0: raise FileNotFoundError( - f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'" + ) return optim_files + def get_user_content_file(checkpoint_dir): file = os.path.join(checkpoint_dir, "user_content.pt") if not os.path.exists(file): raise FileNotFoundError(f"can't find user content file at '{file}'") return file + def parse_model_state(model_file, user_content_file, dtype): state_dict = torch.load(model_file, map_location=device) user_content = torch.load(user_content_file, map_location=device) if "buffer_names" not in user_content: - raise ValueError(f"{user_content_file} miss buffer_names to reconstruct the full state") + raise ValueError( + f"{user_content_file} miss buffer_names to reconstruct the full state" + ) if "param_shapes" not in user_content: - raise ValueError(f"{user_content_file} miss param_shapes to reconstruct the full state") + raise ValueError( + f"{user_content_file} miss param_shapes to reconstruct the full state" + ) buffer_names = user_content["buffer_names"] param_shapes = user_content["param_shapes"] # recover just the buffers while restoring them to the specified dtype buffers = { - k: v.to(dtype) - for k, - v in state_dict["module"].items() if k in buffer_names + k: v.to(dtype) for k, v in state_dict["module"].items() if k in buffer_names } return buffers, param_shapes + def parse_optim_states(files, checkpoint_dir, dtype): total_files = len(files) state_dicts = [] @@ -77,7 +88,9 @@ def parse_optim_states(files, checkpoint_dir, dtype): states = torch.load(f, map_location=device) if i == 0: sharded_data_parallel_size = states["partition_count"] - states["fp32_flat_groups"] = [group.to(dtype) for group in states["fp32_flat_groups"]] + states["fp32_flat_groups"] = [ + group.to(dtype) for group in states["fp32_flat_groups"] + ] state_dicts.append(states["fp32_flat_groups"]) if type(sharded_data_parallel_size) is list: @@ -89,20 +102,21 @@ def parse_optim_states(files, checkpoint_dir, dtype): "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." ) - flat_groups = [ - torch.cat(state_dicts[i], - 0) for i in range(len(state_dicts)) - ] + flat_groups = [torch.cat(state_dicts[i], 0) for i in range(len(state_dicts))] return sharded_data_parallel_size, flat_groups + def partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size): remainder = unpartitioned_numel % sharded_data_parallel_size padding_numel = (sharded_data_parallel_size - remainder) if remainder else 0 partitioned_numel = math.ceil(unpartitioned_numel / sharded_data_parallel_size) return partitioned_numel, padding_numel - -def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dtype=torch.float32, tag=None, remove_smp_prefix=True): + + +def get_full_state_dict_from_sharded_data_parallel_checkpoint( + checkpoint_dir, dtype=torch.float32, tag=None, remove_smp_prefix=True +): """ Returns full state_dict reconstructed from sharded data parallel checkpoint @@ -114,9 +128,9 @@ def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dt """ if tag is None: - latest_path = os.path.join(checkpoint_dir, 'newest') + latest_path = os.path.join(checkpoint_dir, "newest") if os.path.isfile(latest_path): - with open(latest_path, 'r') as fd: + with open(latest_path, "r") as fd: tag = fd.read().strip() else: raise ValueError(f"Unable to find 'newest' file at {latest_path}") @@ -129,17 +143,19 @@ def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dt print(f"Processing checkpoint '{checkpoint_dir}'") optim_files = get_optim_files(checkpoint_dir) - sharded_data_parallel_size, flat_groups = parse_optim_states(optim_files, checkpoint_dir, dtype) + sharded_data_parallel_size, flat_groups = parse_optim_states( + optim_files, checkpoint_dir, dtype + ) model_file = get_model_state_file(checkpoint_dir) user_content_file = get_user_content_file(checkpoint_dir) buffers, param_shapes = parse_model_state(model_file, user_content_file, dtype) - + gc.collect() avail_numel = flat_groups[0].numel() * sharded_data_parallel_size # merge list of dicts, preserving order param_shapes = {k: v for d in param_shapes for k, v in d.items()} - + # params offset = 0 total_numel = 0 @@ -150,27 +166,32 @@ def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dt for name, shape in param_shapes.items(): if remove_smp_prefix and name.startswith(smp_prefix): - name = name[len(smp_prefix):] + name = name[len(smp_prefix) :] unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel total_params += 1 - partitioned_numel, partitioned_padding_numel = partitioned_param_info(unpartitioned_numel, sharded_data_parallel_size) + partitioned_numel, partitioned_padding_numel = partitioned_param_info( + unpartitioned_numel, sharded_data_parallel_size + ) print( f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" ) # memory usage doubles here - state_dict[name] = torch.cat( - tuple(flat_groups[i].narrow(0, - offset, - partitioned_numel) - for i in range(sharded_data_parallel_size)), - 0).narrow(0, - 0, - unpartitioned_numel).view(shape) + state_dict[name] = ( + torch.cat( + tuple( + flat_groups[i].narrow(0, offset, partitioned_numel) + for i in range(sharded_data_parallel_size) + ), + 0, + ) + .narrow(0, 0, unpartitioned_numel) + .view(shape) + ) offset += partitioned_numel offset *= sharded_data_parallel_size @@ -178,14 +199,14 @@ def get_full_state_dict_from_sharded_data_parallel_checkpoint(checkpoint_dir, dt # Sanity check if offset != avail_numel: raise ValueError( - f"consumed {offset} numels out of {avail_numel} - something is wrong") + f"consumed {offset} numels out of {avail_numel} - something is wrong" + ) - print( - f"Reconstructed state dict with {total_params} params {total_numel} elements" - ) + print(f"Reconstructed state dict with {total_params} params {total_numel} elements") return state_dict + def get_param_shapes(model, optimizer): """Returns a dict of name to shape mapping, only for the flattened weights saved by the optimizer. the names are exactly as in state_dict. The order is absolutely important, since @@ -218,6 +239,7 @@ def get_param_shapes(model, optimizer): return param_group_shapes + def get_buffer_names(model): buffer_names = [] @@ -237,4 +259,4 @@ def get_layer_named_buffers(module, prefix=""): get_layer_named_buffers(model.module, prefix="") - return buffer_names \ No newline at end of file + return buffer_names diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/smp_trainer.py b/training/distributed_training/pytorch/model_parallel/gpt-j/smp_trainer.py index f27170a13c..13d4ffdfc5 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/smp_trainer.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/smp_trainer.py @@ -85,7 +85,9 @@ def __init__( @smp.step def train_step(model, optimizer, input_ids, attention_mask, args): - loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] + loss = model( + input_ids=input_ids, attention_mask=attention_mask, labels=input_ids + )["loss"] model.backward(loss) diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py index 6fdd64a892..a3cf93da2a 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_script.py @@ -124,19 +124,23 @@ def save( ) if partial: - save_dict["optimizer"] = optimizer.local_state_dict(gather_if_shard=args.gather_if_shard) + save_dict["optimizer"] = optimizer.local_state_dict( + gather_if_shard=args.gather_if_shard + ) else: if args.skip_full_optimizer: print("Skipping saving the final optimizer state") elif args.shard_optimizer_state > 0: print( - "Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping..." + "Saving the full optimizer state does not work with shard_optimizer_state > 0! Skipping..." ) else: save_dict["optimizer"] = optimizer.state_dict() if not args.gather_if_shard or (smp.rdp_rank() == 0 and partial) or smp.rank() == 0: - smp.save(save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard) + smp.save( + save_dict, output_save_file, partial=partial, v3=not args.gather_if_shard + ) print(f"Finished checkpointing after {total_steps} steps: {output_save_file}") @@ -233,10 +237,14 @@ def initialize_model_and_tokenizer(model_args): } if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, **tokenizer_kwargs + ) elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, **tokenizer_kwargs + ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script. " @@ -270,7 +278,9 @@ def initialize_smp(smp_args, training_args): if smp.rank() == 0: print("Arguments:", smp_args.__dict__) print(f"Transformers version: {transformers.__version__}") - print(f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}") + print( + f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}" + ) print(f"smdistributed config: {smp_config}") set_seed(training_args.seed) @@ -282,7 +292,9 @@ def main(): model, tokenizer = initialize_model_and_tokenizer(model_args) # Get datasets - train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args, training_args) + train_dataset, eval_dataset = Preprocess.datasets( + model_args, data_args, training_args + ) if is_sagemaker_mp_enabled(): initialize_smp(smp_args, training_args) @@ -363,7 +375,9 @@ def main(): if training_args.save_final_full_model: # saves full model at the end - base_path = f"trained_gpt_nparams-{num_params}_steps-{training_args.max_steps}.pt" + base_path = ( + f"trained_gpt_nparams-{num_params}_steps-{training_args.max_steps}.pt" + ) out_path = os.path.join(training_args.model_dir, base_path) # if args.save_or_verify_ckptsum: # # Save optimizer and model tensor sums and scalars before saving diff --git a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py index a346f6ad9e..60e6159a58 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py +++ b/training/distributed_training/pytorch/model_parallel/gpt-j/train_gptj_smp_tensor_parallel_script.py @@ -95,10 +95,14 @@ def get_param_groups_by_weight_decay(module): @smp.step def train_step(model, optimizer, input_ids, attention_mask, args): if args.logits_output: - output = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) + output = model( + input_ids=input_ids, attention_mask=attention_mask, labels=input_ids + ) loss = output["loss"] else: - loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] + loss = model( + input_ids=input_ids, attention_mask=attention_mask, labels=input_ids + )["loss"] model.backward(loss) @@ -111,7 +115,9 @@ def train_step(model, optimizer, input_ids, attention_mask, args): # smdistributed: Define smp.step. Return any tensors needed outside. @smp.step def test_step(model, input_ids, attention_mask): - loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"] + loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)[ + "loss" + ] return loss @@ -171,7 +177,8 @@ def train( [ os.path.join(args.training_dir, p) for p in os.listdir(args.training_dir) - if os.path.isfile(os.path.join(args.training_dir, p)) and "training" in p + if os.path.isfile(os.path.join(args.training_dir, p)) + and "training" in p ] ) else: @@ -283,9 +290,13 @@ def should_record(): if smp.rank() == 0: if args.use_bert_data: - print(f"Reading data from training path {train_dataloader.dataset.input_file}") + print( + f"Reading data from training path {train_dataloader.dataset.input_file}" + ) else: - print(f"Reading data from training path {train_dataloader.dataset.input_paths}") + print( + f"Reading data from training path {train_dataloader.dataset.input_paths}" + ) for batch_idx, input_data in enumerate(train_dataloader): if batch_idx < start_batch_index: @@ -315,7 +326,9 @@ def should_record(): optimizer.zero_grad(set_to_none=True) if args.logits_output: - train_output = train_step(model, optimizer, input_ids, attention_mask, args) + train_output = train_step( + model, optimizer, input_ids, attention_mask, args + ) loss_mb = train_output["loss"] logits_mb = train_output["logits"] if smp.tp_size() > 1: @@ -330,7 +343,7 @@ def should_record(): loss = loss_mb.reduce_mean() if not args.validation_freq: loss_metric = loss.item() - + if args.enable_memory_profiling > 0: memory_status_cpu("After_train_step_cpu") memory_status(msg="After_train_step") @@ -339,11 +352,10 @@ def should_record(): # empty the cache to avoid OOM torch.cuda.empty_cache() - if grad_accumulation_boundary(batch_idx): if args.fp16: optimizer.clip_master_grads(args.grad_clip) - + optimizer.step() if not (args.fp16 and optimizer.overflow): lr_scheduler.step() @@ -394,20 +406,22 @@ def should_record(): "total_steps": total_steps, "start_train_path_index": curr_train_path_index, "model_config": model_config, - "start_batch_index": batch_idx+1, + "start_batch_index": batch_idx + 1, } # to reconstruct the full model if args.sharded_data_parallel_degree > 1: user_content["buffer_names"] = get_buffer_names(model) user_content["param_shapes"] = get_param_shapes(model, optimizer) user_content["lr_scheduler"] = lr_scheduler.state_dict() - smp.save_checkpoint(args.checkpoint_dir, + smp.save_checkpoint( + args.checkpoint_dir, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content, - num_kept_partial_checkpoints=args.num_kept_checkpoints) + num_kept_partial_checkpoints=args.num_kept_checkpoints, + ) if args.logits_output: to_save["loss"].append(loss.item()) @@ -417,7 +431,9 @@ def should_record(): to_save["logits"] = logits.detach().cpu() output_file = f"rank_{smp.rank()}_" + args.logits_output torch.save(to_save, os.path.join(args.model_dir, output_file)) - print(f"logits and loss saved at {os.path.join(args.model_dir, output_file)}") + print( + f"logits and loss saved at {os.path.join(args.model_dir, output_file)}" + ) break del train_dataloader @@ -467,11 +483,19 @@ def parse_args(): opt_grp.add_argument("--seed", type=int, default=12345) opt_grp.add_argument("--same_seed", type=int, default=0) opt_grp.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"]) - opt_grp.add_argument("--fp16", default=0, type=int, help="automatic mixed precision training") - opt_grp.add_argument("--bf16", default=0, type=int, help="automatic mixed precision training") + opt_grp.add_argument( + "--fp16", default=0, type=int, help="automatic mixed precision training" + ) + opt_grp.add_argument( + "--bf16", default=0, type=int, help="automatic mixed precision training" + ) opt_grp.add_argument("--sharded_data_parallel_degree", default=1, type=int) - opt_grp.add_argument("--grad_clip", default=1.0, type=float, help="gradient clipping") - opt_grp.add_argument("--weight_decay", default=0.01, type=float, help="weight decay") + opt_grp.add_argument( + "--grad_clip", default=1.0, type=float, help="gradient clipping" + ) + opt_grp.add_argument( + "--weight_decay", default=0.01, type=float, help="weight decay" + ) opt_grp.add_argument( "--beta1", default=0.9, type=float, help="beta1 parameter for Adam optimizer" ) @@ -485,17 +509,31 @@ def parse_args(): help="enable gradient checkpointing to reduce memory consumption", ) parser.add_argument( - "--logging_freq", type=int, default=1, help="number of iterations between logging" + "--logging_freq", + type=int, + default=1, + help="number of iterations between logging", ) # I/O - io_grp = parser.add_argument_group(title="io", description="location for input and output") - io_grp.add_argument("--use_bert_data", type=int, default=0, help="use wiki corpus data for training") - io_grp.add_argument("--zipped_data", type=int, default=0, help="input data is zipped files") + io_grp = parser.add_argument_group( + title="io", description="location for input and output" + ) + io_grp.add_argument( + "--use_bert_data", type=int, default=0, help="use wiki corpus data for training" + ) + io_grp.add_argument( + "--zipped_data", type=int, default=0, help="input data is zipped files" + ) + io_grp.add_argument( + "--epochs", + type=int, + default=1, + help="times of iterating over the training dataset", + ) io_grp.add_argument( - "--epochs", type=int, default=1, help="times of iterating over the training dataset" + "--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"] ) - io_grp.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) io_grp.add_argument( "--checkpoint-dir", type=str, @@ -508,7 +546,9 @@ def parse_args(): default=os.environ["SM_MODEL_DIR"], help="Saves full model for inference to this dir. Also used if load_full is given to load the model. Note the lack of optimizer state here.", ) - io_grp.add_argument("--training-dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + io_grp.add_argument( + "--training-dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"] + ) io_grp.add_argument("--test-dir", type=str, default=os.environ["SM_CHANNEL_TEST"]) io_grp.add_argument( "--parallel_proc_data_processing", @@ -522,12 +562,18 @@ def parse_args(): default=0, help="Enabling this will save a combined model only at the end", ) - io_grp.add_argument("--load_partial", type=int, default=0, help="Load from partial checkpoints") - io_grp.add_argument("--load_full", type=int, default=0, help="Load from full checkpoints") + io_grp.add_argument( + "--load_partial", type=int, default=0, help="Load from partial checkpoints" + ) + io_grp.add_argument( + "--load_full", type=int, default=0, help="Load from full checkpoints" + ) io_grp.add_argument( "--logits_output", type=str, default="", help="Path to save logits and loss" ) - io_grp.add_argument("--prescaled_batch", type=int, default=1, help="use prescaled batch") + io_grp.add_argument( + "--prescaled_batch", type=int, default=1, help="use prescaled batch" + ) # configure model size model_grp = parser.add_argument_group( @@ -542,10 +588,27 @@ def parse_args(): model_grp.add_argument("--embd_pdrop", type=float, default=0.1) model_grp.add_argument("--attn_pdrop", type=float, default=0.1) model_grp.add_argument("--summary_first_pdrop", type=float, default=0.1) - model_grp.add_argument("--use_adamw", type=int, default=0, help="Use adamw optimizer") - model_grp.add_argument("--finetune_6b", type=int, default=0, help="Flag to enable finetune 6B GPTJ model") - model_grp.add_argument("--use_distributed_transformer", type=int, default=1, help="Use distributed transformer") - model_grp.add_argument("--checkpoint_sublayers", type=int, default=0, help="Apply activation checkpointing to submodules of each transformer layer") + model_grp.add_argument( + "--use_adamw", type=int, default=0, help="Use adamw optimizer" + ) + model_grp.add_argument( + "--finetune_6b", + type=int, + default=0, + help="Flag to enable finetune 6B GPTJ model", + ) + model_grp.add_argument( + "--use_distributed_transformer", + type=int, + default=1, + help="Use distributed transformer", + ) + model_grp.add_argument( + "--checkpoint_sublayers", + type=int, + default=0, + help="Apply activation checkpointing to submodules of each transformer layer", + ) smp_grp = parser.add_argument_group(title="smp", description="smp") smp_grp.add_argument("--tensor_parallel_degree", type=int, default=1) @@ -630,7 +693,9 @@ def parse_args(): default=0, help="Clean torch reserved memory at he end of every step", ) - parser.add_argument("--use_fsx", type=int, default=0, help="Using FSx for checkpointing") + parser.add_argument( + "--use_fsx", type=int, default=0, help="Using FSx for checkpointing" + ) parser.add_argument( "--enable_memory_profiling", type=int, default=0, help="Enable memory profile" ) @@ -651,13 +716,15 @@ def parse_args(): "--lr_decay_iters", type=int, default=None, - help="number of iterations to decay learning rate over," " If None defaults to train iters", + help="number of iterations to decay learning rate over," + " If None defaults to train iters", ) lr_grp.add_argument( "--min_lr", type=float, default=0.0, - help="Minumum value for learning rate. The scheduler" "clip values below this threshold.", + help="Minumum value for learning rate. The scheduler" + "clip values below this threshold.", ) lr_grp.add_argument( "--warmup", @@ -674,13 +741,16 @@ def parse_args(): ) ci_grp = parser.add_argument_group(title="ci", description="ci related settings") - ci_grp.add_argument("--ci", default=False, action="store_true", help="Whether enable ci") + ci_grp.add_argument( + "--ci", default=False, action="store_true", help="Whether enable ci" + ) ci_grp.add_argument("--time_to_train", type=int, help="time to train threshold") ci_grp.add_argument("--throughput", type=float, help="throughput threshold") ci_grp.add_argument("--loss", type=float, help="loss threshold") args, _ = parser.parse_known_args() return args + def compute_num_params(model): num_params = 0 seen = set() @@ -688,11 +758,12 @@ def compute_num_params(model): if p not in seen: seen.add(p) if hasattr(p, "ds_shape"): - num_params += np.prod(p.ds_shape) + num_params += np.prod(p.ds_shape) else: num_params += np.prod(p.size()) - - return num_params + + return num_params + def main(): args = parse_args() @@ -731,7 +802,9 @@ def main(): if smp.rank() == 0: print("Arguments:", args.__dict__) print(f"Transformers version: {transformers.__version__}") - print(f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}") + print( + f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}" + ) print(f"smdistributed config: {smp_config}") if args.save_final_full_model and smp.rank() == 0: @@ -772,7 +845,7 @@ def main(): # the following improves start-up time by skipping proper initialization # of weights in the original model. this is not a problem because DistributedModel - # will override those weights anyway when we use distributed transformer. + # will override those weights anyway when we use distributed transformer. if args.use_distributed_transformer > 0: from transformers.modeling_utils import PreTrainedModel @@ -794,30 +867,34 @@ def main(): if args.finetune_6b: with smp.model_creation( - tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, + tensor_parallelism=smp.tp_size() > 1 + or args.use_distributed_transformer > 0, dtype=dtype, attention_in_fp32=args.attention_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, - ): - model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16) - model_config = model.config - # translated_state_dict = translate_hf_gptj_state_dict_to_smdistributed(model.state_dict(), max_seq_len=args.max_context_width) + ): + model = AutoModelForCausalLM.from_pretrained( + "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16 + ) + model_config = model.config + # translated_state_dict = translate_hf_gptj_state_dict_to_smdistributed(model.state_dict(), max_seq_len=args.max_context_width) else: - + with smp.model_creation( - tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, + tensor_parallelism=smp.tp_size() > 1 + or args.use_distributed_transformer > 0, dtype=dtype, attention_in_fp32=args.attention_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, - ): - model = AutoModelForCausalLM.from_config(model_config) - + ): + model = AutoModelForCausalLM.from_config(model_config) + if args.enable_memory_profiling > 0: memory_status_cpu(msg="after model creation") @@ -840,19 +917,21 @@ def main(): # the model provided for DistributedModel class instantiation. if args.enable_memory_profiling > 0: memory_status_cpu(msg="before dist model creation") - model = smp.DistributedModel(model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation) - + model = smp.DistributedModel( + model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation + ) + # if args.finetune_6b: # model.load_state_dict(translated_state_dict) - + if args.enable_memory_profiling > 0: memory_status_cpu(msg="after dist model creation") - + if args.fp16: m = model.module else: m = model - + if args.use_distributed_transformer > 0: transformer_layers = m.module.module.transformer.seq_layers else: @@ -862,7 +941,9 @@ def main(): print(f"Manual partition enabled") if args.partition_assignment != "": get_num_layers = lambda x: int(partition_assignment[x]) - total_layers = sum([get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())]) + total_layers = sum( + [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())] + ) assert ( total_layers == args.num_layers ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}" @@ -883,11 +964,17 @@ def main(): if args.use_adamw > 0: optimizer = optim.AdamW( - param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay + param_groups, + betas=(args.beta1, args.beta2), + lr=args.lr, + weight_decay=args.weight_decay, ) else: optimizer = optim.Adam( - param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay + param_groups, + betas=(args.beta1, args.beta2), + lr=args.lr, + weight_decay=args.weight_decay, ) if args.activation_checkpointing: @@ -897,7 +984,9 @@ def main(): smp.set_activation_checkpointing(c.attention) smp.set_activation_checkpointing(c.output) else: - smp.set_activation_checkpointing(transformer_layers, strategy=args.activation_strategy) + smp.set_activation_checkpointing( + transformer_layers, strategy=args.activation_strategy + ) else: for c in transformer_layers.children(): if args.checkpoint_sublayers: @@ -907,11 +996,11 @@ def main(): smp.set_activation_checkpointing(c) optimizer = smp.DistributedOptimizer( - optimizer, - static_loss_scale=None, + optimizer, + static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, - ) + ) lr_scheduler = get_learning_rate_scheduler(optimizer, args) if args.enable_memory_profiling > 0: @@ -981,14 +1070,22 @@ def main(): # Note: the shared parameter will not be reflected so during loading you might need to load with strict=False user_content["buffer_names"] = get_buffer_names(model) user_content["param_shapes"] = get_param_shapes(model, optimizer) - smp.save_checkpoint(args.model_dir, + smp.save_checkpoint( + args.model_dir, tag=f"sharded_data_parallel_final_full_{num_params}", partial=True, model=model, optimizer=optimizer, - user_content=user_content) + user_content=user_content, + ) else: - smp.save_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False, model=model, user_content=user_content) + smp.save_checkpoint( + args.model_dir, + tag="fullmodel.pt", + partial=False, + model=model, + user_content=user_content, + ) smp.barrier() if smp.rank() == 0: