From d674e7ebfbb1fbc630bdb91702b5be3a4783c9a9 Mon Sep 17 00:00:00 2001
From: Allen Guo
Date: Wed, 23 Mar 2022 16:14:12 +0800
Subject: [PATCH] use hdf5 dataset

---
 .../language_model/bert/static_ipu/README.md  |  41 ++-
 .../bert/static_ipu/dataset_ipu.py            | 331 +++++++++---------
 .../bert/static_ipu/requirements.txt          |   3 +-
 .../bert/static_ipu/run_pretrain.py           |  22 +-
 .../bert/static_ipu/run_pretrain.sh           |   2 +-
 .../bert/static_ipu/run_pretrain_phase2.sh    |   4 +-
 .../bert/static_ipu/run_squad.py              |  35 +-
 .../bert/static_ipu/run_squad.sh              |   1 -
 .../bert/static_ipu/run_squad_infer.sh        |   1 -
 .../language_model/bert/static_ipu/utils.py   |  31 +-
 10 files changed, 232 insertions(+), 239 deletions(-)

diff --git a/examples/language_model/bert/static_ipu/README.md b/examples/language_model/bert/static_ipu/README.md
index 0ac0449245c6..3afa9c58163f 100644
--- a/examples/language_model/bert/static_ipu/README.md
+++ b/examples/language_model/bert/static_ipu/README.md
@@ -13,11 +13,11 @@ This project enabled BERT-Base pre-training and SQuAD fine-tuning task using [Pa
 | `run_squad.py` | The algorithm script to run SQuAD finetune and validation task. |
 | `modeling.py` | The algorithm script to build the Bert-Base model. |
 | `dataset_ipu.py` | The algorithm script to load input data in pretraining. |
+| `custom_ops/` | The folder containing the custom ops used by the model. |
 | `run_pretrain.sh` | Test script to run pretrain phase 1. |
 | `run_pretrain_phase2.sh` | Test script to run pretrain phase 2. |
 | `run_squad.sh` | Test script to run SQuAD finetune. |
 | `run_squad_infer.sh` | Test script to run SQuAD validation. |
-| `LICENSE` | The license of Apache. |
 
 ## Dataset
 
@@ -27,37 +27,36 @@ This project enabled BERT-Base pre-training and SQuAD fine-tuning task using [Pa
 
    The sequence length used in pretraining phase1 and phase2 are: 128 and 384. Following steps are provided for dataset generation.
 
-   ```
-   # Code base:https://github.com/NVIDIA/DeepLearningExamples/tree/88eb3cff2f03dad85035621d041e23a14345999e/TensorFlow/LanguageModeling/BERT
+   ```bash
+   # Here we use a specific commit; the latest commit should also be fine.
    git clone https://github.com/NVIDIA/DeepLearningExamples.git
    git checkout 88eb3cff2f03dad85035621d041e23a14345999e
-   cd DeepLearningExamples/TensorFlow/LanguageModeling/BERT
-
-   bash scripts/docker/build.sh
+   cd DeepLearningExamples/PyTorch/LanguageModeling/BERT
 
-   cd data/
+   # Modify the parameters `--max_seq_length 512` to `--max_seq_length 384` at line 50 and
+   # `--max_predictions_per_seq 80` to `--max_predictions_per_seq 56` at line 51.
+   vim data/create_datasets_from_start.sh
 
-   # Modified the parameters `--max_seq_length 512` to `--max_seq_length 384` at line 68, `--max_predictions_per_seq 80` to `--max_predictions_per_seq 56` at line 69.
-   vim create_datasets_from_start.sh
+   # Build the docker image
+   bash scripts/docker/build.sh
 
-   cd ../
+   # Use NV's docker to download the data and generate the hdf5 files. This may require a GPU.
+   # You can remove `--gpus $NV_VISIBLE_DEVICES` to avoid the GPU requirement.
+   bash scripts/docker/launch.sh
 
-   # Use NV's docker to download and generate tfrecord. This may requires GPU available. Removing `--gpus $NV_VISIBLE_DEVICES` in data_download.sh to avoid GPU requirements.
-   bash scripts/data_download.sh wiki_only
+   # Generate the dataset with wiki_only
+   bash data/create_datasets_from_start.sh wiki_only
    ```
 
-2. SQuAD 1.1 dataset
+2. SQuAD v1.1 dataset
 
-   ```
-   curl --create-dirs -L https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -o data/squad/train-v1.1.json
+   PaddleNLP will download the SQuAD v1.1 dataset automatically. You don't have to download it manually.
 
-   curl --create-dirs -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -o data/squad/dev-v1.1.json
-   ```
 
 ## Quick Start Guide
 
-### 1)Prepare Project Environment
+### Prepare Project Environment
 
 PaddlePaddle with IPU implementation, which is provided by Graphcore, is required by this application. User can either download the released package or build it from source.
 
@@ -75,7 +74,7 @@ git clone -b bert_base_sdk_2.3.0 https://github.com/graphcore/Paddle.git
 cd Paddle
 
 # build docker image
-docker build -t paddlepaddle/paddle:dev-ipu-2.3.0 -f tools/dockerfile/Dockerfile.ipu .
+docker build -t paddlepaddle/paddle:ipu-dev-2.3.0 -f tools/dockerfile/Dockerfile.ipu .
 
 # create container
 # The ipuof.conf is required here.
@@ -83,7 +82,7 @@ docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
 --device=/dev/infiniband/ --ipc=host --name paddle-ipu-dev \
 -v ${HOST_IPUOF_PATH}:/ipuof \
 -e IPUOF_CONFIG_PATH=/ipuof/ipu.conf \
--it paddlepaddle/paddle:dev-ipu-2.3.0 bash
+-it paddlepaddle/paddle:ipu-dev-2.3.0 bash
 ```
 
 All of later processes are required to be executed in the container.
@@ -104,7 +103,7 @@ cmake --build `pwd`/build --config Release --target paddle_python -j$(nproc)
 pip3.7 install -U build/python/dist/paddlepaddle-0.0.0-cp37-cp37m-linux_x86_64.whl
 ```
 
-### 2) Execution
+### Execution
 
 - Run pretraining phase1 (sequence_length = 128)
 
diff --git a/examples/language_model/bert/static_ipu/dataset_ipu.py b/examples/language_model/bert/static_ipu/dataset_ipu.py
index 62f5f790c079..1205aff67389 100644
--- a/examples/language_model/bert/static_ipu/dataset_ipu.py
+++ b/examples/language_model/bert/static_ipu/dataset_ipu.py
@@ -12,138 +12,100 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import multiprocessing
-import random
 import threading
-from collections import deque
 from queue import Queue
 
+import h5py
 import numpy as np
 import paddle
 
-try:
-    from torch_xla.utils.tf_record_reader import TfRecordReader
-except ImportError:
-    raise ImportError("""Torch-xla required for TFRecord dataset.
- Please install torch 1.7.0 & torch-xla using - `pip install torch==1.7.0 torch-xla@https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl`""" - ) +KEYS = ('input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', + 'masked_lm_ids', 'next_sentence_labels') -KEYS = ('masked_lm_ids', 'masked_lm_weights', 'segment_ids', 'input_ids', - 'input_mask', 'next_sentence_labels', 'masked_lm_positions') +def shuffle_dict(dic, len): + idxs = np.arange(len) + np.random.shuffle(idxs) + for k, v in dic.items(): + dic[k] = v[idxs] -class PretrainingTfRecordDataLoader: + +class PretrainingHDF5DataLoader: def __init__(self, input_files, max_seq_length=128, max_mask_tokens=20, batch_size=1, - micro_batch_size=1, dtype=np.int32, shuffle=False, pad_position_value=511, - prefetch=1, - drop_remainder=False, - enable_fp16=False, - enable_ipu=False, - enable_check_data=False, - ignore_index=-1): + num_workers=3): self.files = input_files self.batch_size = batch_size - self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.max_mask_tokens = max_mask_tokens self.dtype = dtype - self.file_index = 0 - self.data_index = 0 self.shuffle = shuffle - self.len = None self.pad_position_value = pad_position_value - self.drop_remainder = drop_remainder - self.enable_fp16 = enable_fp16 - self.enable_ipu = enable_ipu - self.enable_check_data = enable_check_data - self.ignore_index = ignore_index + if shuffle: + np.random.shuffle(self.files) + + self.counter = 0 + + # get total number of samples pool = multiprocessing.Pool(min(multiprocessing.cpu_count(), 32)) num_samples = pool.map(self.samples_in_file, self.files) pool.close() pool.join() self.total_samples = sum(num_samples) - self.len = self.total_samples // (self.batch_size) - self.num_prefetch_batches = prefetch - self.prefetch_buffer = deque() - self.process_buffer = multiprocessing.Manager().Queue(10) + self.len = self.total_samples // self.batch_size + assert self.len > 1, f"Batch size {self.batch_size} larger than number of samples {self.total_samples}" + + # notify feed and fetch processes/thread to stop self.event_queue = multiprocessing.Manager().Queue(10) - self.feed_buffer = Queue(20) - if self.len < 1: - raise ValueError(f"""Batch size {self.batch_size} larger than - number of samples in the TFRecord files {self.total_samples}.""" - ) - - if self.len < self.num_prefetch_batches: - raise ValueError( - f"""Not enough samples to prefetch: (length = {self.len}, - num_to_prefech = {self.num_prefetch_batches}), - lower the number of prefetch batches.""") - self.samples_per_file = { - f: n - for (f, n) in zip(self.files, num_samples) - } - self.data = None - self.counter = 0 - # multi-process - # workers = multiprocessing.Pool() - self.thread_stop = False - self.thread_process = threading.Thread(target=self.post_fetch) + # buffer to store final data + self.feed_buffer = Queue(20) - self.process_number = 3 - self.files_per_process = len(self.files) // self.process_number - self.split_files = np.array_split(self.files, self.process_number) - self.processor = [ + # number of processes to do remask + self.num_workers = num_workers + # each feed_worker has one process_buffer to use + self.process_buffers = [ + multiprocessing.Manager().Queue(10) for _ in range(num_workers) + ] + self.split_files = np.array_split(self.files, self.num_workers) + # feed_worker will load data frm h5py files, and do remask process + self.feed_workers = [ multiprocessing.Process( - target=self.fill_buffer_loop, args=(i, 
self.split_files[i])) - for i in range(self.process_number) + target=self.fill_buffer_loop, + args=(self.split_files[idx], self.process_buffers[idx])) + for idx in range(self.num_workers) ] - for p in self.processor: + for p in self.feed_workers: p.start() - self.thread_process.start() - def post_fetch(self): - while True: - if not self.event_queue.empty(): - return - if not self.process_buffer.empty(): - np_feed_list = self.process_buffer.get() - lod_feed_list = [] - for data in np_feed_list: - tensor = paddle.fluid.core.LoDTensor() - place = paddle.CPUPlace() - tensor.set(data, place) - lod_feed_list.append(tensor) - self.feed_buffer.put(lod_feed_list) + # index for which process_buffer is used each time + self.post_fetch_idx = 0 + # load final data from process_buffers + self.fetch_worker = threading.Thread(target=self.post_fetch) + self.fetch_worker.start() def samples_in_file(self, filename): - reader = TfRecordReader( - filename, - transforms={ - k: lambda x: x.numpy().astype(self.dtype) - for k in KEYS - }) - count = 0 - while reader.read_example(): - count += 1 - return count + with h5py.File(filename, "r") as f: + data_len = f[KEYS[0]].shape[0] + return data_len def release(self): self.event_queue.put('END') while not self.feed_buffer.empty(): self.feed_buffer.get() - while not self.process_buffer.empty(): - self.process_buffer.get() - self.thread_process.join() - for p in self.processor: + for process_buffer in self.process_buffers: + while not process_buffer.empty(): + process_buffer.get() + self.fetch_worker.join() + for p in self.feed_workers: p.join() return @@ -151,77 +113,111 @@ def __len__(self): return self.len def __iter__(self): - self.file_index = 0 - self.data_index = 0 self.counter = 0 - self.data = None - # if self.shuffle: - # random.shuffle(self.files) - # self.fill_buffer(self.num_prefetch_batches) return self - def fill_buffer_loop(self, i, files): + def __next__(self): + result = self.feed_buffer.get() + self.counter += 1 + return result + + def post_fetch(self): + while True: + if not self.event_queue.empty(): + return + if not self.process_buffers[self.post_fetch_idx].empty(): + logging.debug(f"self.post_fetch_idx: {self.post_fetch_idx}") + np_feed_list = self.process_buffers[self.post_fetch_idx].get() + self.post_fetch_idx += 1 + if self.post_fetch_idx == self.num_workers: + self.post_fetch_idx = 0 + elif self.post_fetch_idx > self.num_workers: + raise Exception('post_fetch_idx must < num_workers') + + lod_feed_list = [] + for data in np_feed_list: + tensor = paddle.fluid.core.LoDTensor() + place = paddle.CPUPlace() + tensor.set(data, place) + lod_feed_list.append(tensor) + self.feed_buffer.put(lod_feed_list) + + def fill_buffer_loop(self, files, process_buffer): data = None data_index = 0 file_index = 0 - def multiprocess_fill_buffer(num_batches, data, file_index, data_index): + def multiprocess_fill_buffer(data, file_index, data_index): if data is None: - data = self.return_load_data(files[file_index]) + data = self.load_one_file(files[file_index]) file_index += 1 data_index = 0 - for _ in range(num_batches): - curr_batch = [] - still_required = self.batch_size - while still_required > 0: - data_batch = data[data_index:data_index + still_required] - data_index += len(data_batch) - curr_batch += data_batch - still_required = self.batch_size - len(curr_batch) - if still_required > 0: - if file_index >= len(files): - random.shuffle(files) - file_index = 0 - - data = self.return_load_data(files[file_index]) - file_index += 1 - data_index = 0 - if 
len(curr_batch) == self.batch_size: - result = {} - for k in KEYS: - result[k] = np.vstack([item[k] for item in curr_batch]) - self.process_buffer.put(self.post_process(result)) + + curr_batch = [] + still_required = self.batch_size + while still_required > 0: + data_batch = { + k: data[k][data_index:data_index + still_required] + for k in KEYS + } + data_batch_len = len(data_batch[KEYS[0]]) + data_index += data_batch_len + curr_batch.append(data_batch) + curr_batch_len = sum(len(x[KEYS[0]]) for x in curr_batch) + still_required = self.batch_size - curr_batch_len + if still_required > 0: + if file_index >= len(files): + np.random.shuffle(files) + file_index = 0 + + data = self.load_one_file(files[file_index]) + file_index += 1 + data_index = 0 + if not curr_batch_len == self.batch_size: + raise Exception("data length should equal to batch_size") + + result = {} + for k in KEYS: + result[k] = np.concatenate( + [item[k] for item in curr_batch], axis=0) + process_buffer.put(self.do_remask(result)) return data, file_index, data_index while True: if self.event_queue.empty(): data, file_index, data_index = multiprocess_fill_buffer( - 1, data, file_index, data_index) + data, file_index, data_index) else: return - def post_process(self, samples): - # process_start = time.time() - batch_size, seq_len = samples['input_ids'].shape + def do_remask(self, samples): + input_ids = samples['input_ids'] + segment_ids = samples['segment_ids'] + masked_lm_positions = samples['masked_lm_positions'] + masked_lm_ids = samples['masked_lm_ids'] + next_sentence_labels = samples['next_sentence_labels'] + masked_lm_weights = np.ones_like(masked_lm_ids, dtype=np.int32) + masked_lm_weights[masked_lm_ids == 0] = 0 + + # post process + batch_size, seq_len = input_ids.shape formatted_pos = self.pad_position_value * np.ones_like(samples[ 'input_ids']) - formatted_input = np.zeros_like(samples['input_ids']) - formatted_seg = np.zeros_like(samples['segment_ids']) + formatted_input = np.zeros_like(input_ids) + formatted_seg = np.zeros_like(segment_ids) formatted_mask_labels = np.zeros( - (batch_size, self.max_mask_tokens), - dtype=samples['masked_lm_ids'].dtype) + (batch_size, self.max_mask_tokens), dtype=masked_lm_ids.dtype) valid_seq_positions = [] - valid_mask_positions = samples['masked_lm_weights'] == 1 + valid_mask_positions = masked_lm_weights == 1 valid_mask_len = np.sum(valid_mask_positions, axis=1).reshape(-1, 1) - for i, mask_pos in enumerate(samples['masked_lm_positions']): + for i, mask_pos in enumerate(masked_lm_positions): pos = [True] * seq_len for mask_index, m in enumerate(mask_pos): if mask_index < valid_mask_len[i]: pos[m] = False - valid_seq_positions.append( - np.logical_and(pos, samples['input_ids'][i] != 0)) + valid_seq_positions.append(np.logical_and(pos, input_ids[i] != 0)) valid_seq_len = np.minimum( np.sum(valid_seq_positions, axis=1) + self.max_mask_tokens, self.max_seq_length).reshape(-1, 1) @@ -232,8 +228,8 @@ def post_process(self, samples): target_mask_indices = np.arange(valid_mask_len[i]) target_seq_indices = self.max_mask_tokens + np.arange(unmasked_len[ i]) - source_mask_indices = samples['masked_lm_positions'][i][ - valid_mask_positions[i]] + source_mask_indices = masked_lm_positions[i][valid_mask_positions[ + i]] source_seq_indices = np.arange(seq_len)[valid_seq_positions[ i]][:unmasked_len[i]] @@ -243,60 +239,45 @@ def post_process(self, samples): [source_mask_indices, source_seq_indices]) formatted_pos[i, target_indices] = source_indices - formatted_input[i, target_indices] = 
samples['input_ids'][ - i, source_indices] - formatted_seg[i, target_indices] = samples['segment_ids'][ - i, source_indices] - formatted_mask_labels[i] = samples['masked_lm_ids'][ - i, :self.max_mask_tokens] - - # process_cost = time.time() - process_start - # print("DEBUG: process cost: {}".format(process_cost)) + formatted_input[i, target_indices] = input_ids[i, source_indices] + formatted_seg[i, target_indices] = segment_ids[i, source_indices] + formatted_mask_labels[i] = masked_lm_ids[i, :self.max_mask_tokens] return [ formatted_input.astype(np.int32), formatted_seg.astype(np.int32), formatted_pos.astype(np.int32), valid_mask_len.astype(np.int32), valid_seq_len.astype(np.int32), formatted_mask_labels.astype(np.int32), - samples['next_sentence_labels'].astype(np.int32) + next_sentence_labels.astype(np.int32) ] - def __next__(self): - if self.drop_remainder: - if self.counter == self.len: - raise StopIteration - - result = self.feed_buffer.get() - self.counter += 1 - return result + def load_one_file(self, file_path): + data = self.load_hdf5(file_path) - def load_data(self): - if self.file_index >= len(self.files): - raise ValueError('No more files to load.') - self.data = self.load_file(self.files[self.file_index]) - self.file_index += 1 - self.data_index = 0 if self.shuffle: - np.random.shuffle(self.data) + shuffle_dict(data, len(data[KEYS[0]])) - def return_load_data(self, file_path): - data = self.load_file(file_path) - # self.file_index += 1 - # self.data_index = 0 - if self.shuffle: - np.random.shuffle(data) return data - def load_file(self, filename): - reader = TfRecordReader( - filename, - transforms={ - k: lambda x: x.numpy().astype(self.dtype) - for k in KEYS - }) - data = [] - ex = reader.read_example() - while ex: - data.append(ex) - ex = reader.read_example() + def load_hdf5(self, filename): + with h5py.File(filename, "r") as f: + data = {key: np.asarray(f[key][:]) for key in KEYS} return data + + +if __name__ == "__main__": + import glob + base_dir = 'data_path/wikicorpus_en/' + input_files = glob.glob(f"{base_dir}/*training*.hdf5") + input_files.sort() + # print(input_files) + + seed = 1984 + np.random.seed(seed) + paddle.seed(seed) + + data_loader = PretrainingHDF5DataLoader( + input_files, batch_size=65536, shuffle=True) + + for idx, batch in enumerate(data_loader): + print(f"{idx}: {batch[0].shape()}") diff --git a/examples/language_model/bert/static_ipu/requirements.txt b/examples/language_model/bert/static_ipu/requirements.txt index 2742b37e7bf0..b81914ad7863 100644 --- a/examples/language_model/bert/static_ipu/requirements.txt +++ b/examples/language_model/bert/static_ipu/requirements.txt @@ -3,5 +3,4 @@ multiprocess numpy scipy wandb -torch==1.7.0 -torch-xla@https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl +h5py diff --git a/examples/language_model/bert/static_ipu/run_pretrain.py b/examples/language_model/bert/static_ipu/run_pretrain.py index 281403b0cdef..e09f6adbe708 100644 --- a/examples/language_model/bert/static_ipu/run_pretrain.py +++ b/examples/language_model/bert/static_ipu/run_pretrain.py @@ -25,7 +25,7 @@ from paddlenlp.transformers import LinearDecayWithWarmup from scipy.stats import truncnorm -from dataset_ipu import PretrainingTfRecordDataLoader +from dataset_ipu import PretrainingHDF5DataLoader from modeling import ( BertModel, DeviceScope, IpuBertConfig, IpuBertPretrainingMLMAccAndLoss, IpuBertPretrainingMLMHeads, IpuBertPretrainingNSPAccAndLoss, @@ -282,6 +282,9 @@ def main(args): ipu_compiler = 
paddle.static.IpuCompiledProgram( main_program, ipu_strategy=ipu_strategy) logging.info(f'start compiling, please wait some minutes') + logging.info( + f'you can run `export POPART_LOG_LEVEL=INFO` before running program to see the compile progress' + ) cur_time = time.time() main_program = ipu_compiler.compile(feed_list, fetch_list) time_cost = time.time() - cur_time @@ -290,30 +293,26 @@ def main(args): # Load the training dataset input_files = [ os.path.join(args.input_files, f) for f in os.listdir(args.input_files) - if os.path.isfile(os.path.join(args.input_files, f)) and "tfrecord" in f + if os.path.isfile(os.path.join(args.input_files, f)) and "training" in f ] input_files.sort() - dataset = PretrainingTfRecordDataLoader( + dataset = PretrainingHDF5DataLoader( input_files=input_files, max_seq_length=args.seq_len, max_mask_tokens=args.max_predictions_per_seq, batch_size=args.batch_size, - micro_batch_size=args.micro_batch_size, - enable_fp16=True, - enable_ipu=True, - ignore_index=args.ignore_index, shuffle=args.shuffle) - logging.info(f"Dataset length: {len(dataset)}") + logging.info(f"dataset length: {len(dataset)}") total_samples = dataset.total_samples - logging.info("Total samples: %d, Total batch_size: %d, Max steps: %d" % + logging.info("total samples: %d, total batch_size: %d, max steps: %d" % (total_samples, args.batch_size, args.max_steps)) batch_start = time.time() global_step = 0 for batch in dataset: global_step += 1 - epoch = global_step * args.batch_size / total_samples + epoch = global_step * args.batch_size // total_samples read_cost = time.time() - batch_start feed = { @@ -338,6 +337,7 @@ def main(args): if args.wandb: wandb.log({ + "epoch": epoch, "global_step": global_step, "loss/MLM": np.mean(loss_return[1]), "loss/NSP": np.mean(loss_return[3]), @@ -407,4 +407,4 @@ def main(args): logging.info(args) main(args) - logging.info("Program Finished") + logging.info("program finished") diff --git a/examples/language_model/bert/static_ipu/run_pretrain.sh b/examples/language_model/bert/static_ipu/run_pretrain.sh index 83ed12b30e88..cd1c5bb00f40 100755 --- a/examples/language_model/bert/static_ipu/run_pretrain.sh +++ b/examples/language_model/bert/static_ipu/run_pretrain.sh @@ -2,7 +2,7 @@ export RDMAV_FORK_SAFE=1 python3 run_pretrain.py \ - --input_files /alleng/dataset/train128wiki \ + --input_files "path_to_phase1_hdf5_dataset" \ --output_dir pretrain_128_model \ --seq_len 128 \ --hidden_size 768 \ diff --git a/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh b/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh index 6e763076263a..a3ca4b293b45 100755 --- a/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh +++ b/examples/language_model/bert/static_ipu/run_pretrain_phase2.sh @@ -2,7 +2,7 @@ export RDMAV_FORK_SAFE=1 python3 run_pretrain.py \ - --input_files /alleng/dataset/train384wiki \ + --input_files "path_to_phase2_hdf5_dataset" \ --output_dir pretrain_384_model \ --seq_len 384 \ --hidden_size 768 \ @@ -13,7 +13,7 @@ python3 run_pretrain.py \ --weight_decay 1e-2 \ --max_steps 2137 \ --warmup_steps 274 \ - --logging_steps 20 \ + --logging_steps 10 \ --seed 1984 \ --beta1 0.9 \ --beta2 0.999 \ diff --git a/examples/language_model/bert/static_ipu/run_squad.py b/examples/language_model/bert/static_ipu/run_squad.py index 66e511e38ff6..41d895ad7ad7 100644 --- a/examples/language_model/bert/static_ipu/run_squad.py +++ b/examples/language_model/bert/static_ipu/run_squad.py @@ -198,24 +198,26 @@ def load_squad_dataset(args): 
args.max_seq_length = args.seq_len args.doc_stride = 128 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - cache_file = f'{args.input_files}.{args.max_seq_length}.cache' + if args.is_training: + cache_file = f'squad_train_v1.{args.max_seq_length}.cache' + else: + cache_file = f'squad_dev_v1.{args.max_seq_length}.cache' features_fn = prepare_train_features if args.is_training else prepare_validation_features if os.path.exists(cache_file): - logging.info(f"Loading Cache {cache_file}") + logging.info(f"loading cache {cache_file}") with open(cache_file, "rb") as f: dataset = pickle.load(f) else: + logging.info(f"loading squad dataset, it will take a few minutes") if args.is_training: - dataset = load_dataset( - 'squad', splits='train_v1', data_files=args.input_files) + dataset = load_dataset('squad', splits='train_v1') else: - dataset = load_dataset( - 'squad', splits='dev_v1', data_files=args.input_files) + dataset = load_dataset('squad', splits='dev_v1') dataset.map(partial( features_fn, tokenizer=tokenizer, args=args), batched=True, num_workers=20) - logging.info(f"saving cache {cache_file}") + logging.info(f"saving cache to {cache_file}") with open(cache_file, "wb") as f: pickle.dump(dataset, f) @@ -286,7 +288,7 @@ def main(args): # custom_ops custom_ops = load_custom_ops() - logging.info("Building Model") + logging.info("building model") if args.is_training: [indices, segments, positions, input_mask, start_labels, @@ -313,7 +315,7 @@ def main(args): total_samples = len(data_loader.dataset) max_steps = total_samples // args.batch_size * args.epochs - logging.info("Total samples: %d, Total batch_size: %d, Max steps: %d" % + logging.info("total samples: %d, total batch_size: %d, max steps: %d" % (total_samples, args.batch_size, max_steps)) if args.is_training: @@ -370,6 +372,9 @@ def main(args): ipu_compiler = paddle.static.IpuCompiledProgram( main_program, ipu_strategy=ipu_strategy) logging.info(f'start compiling, please wait some minutes') + logging.info( + f'you can run `export POPART_LOG_LEVEL=INFO` before running program to see the compile progress' + ) cur_time = time.time() main_program = ipu_compiler.compile(feed_list, fetch_list) time_cost = time.time() - cur_time @@ -403,11 +408,14 @@ def main(args): tput = args.batch_size / total_cost if args.wandb: wandb.log({ + "epoch": epoch, + "global_step": global_step, "loss": np.mean(outputs[0]), "accuracy": np.mean(outputs[1:]), + "train_cost": train_cost, + "total_cost": total_cost, "throughput": tput, "learning_rate": lr_scheduler(), - "global_step": global_step, }) if global_step % args.logging_steps == 0: @@ -483,7 +491,10 @@ def main(args): if __name__ == "__main__": args = parse_args() - logging.basicConfig(level=logging.INFO) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt='%Y-%m-%d %H:%M:%S %a') if not os.path.exists(args.output_dir): os.makedirs(args.output_dir, exist_ok=True) @@ -500,4 +511,4 @@ def main(args): logging.info(args) main(args) - logging.info("Program Finished") + logging.info("program finished") diff --git a/examples/language_model/bert/static_ipu/run_squad.sh b/examples/language_model/bert/static_ipu/run_squad.sh index bba592a1c8a0..a1e9e383c0f7 100755 --- a/examples/language_model/bert/static_ipu/run_squad.sh +++ b/examples/language_model/bert/static_ipu/run_squad.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash python3 run_squad.py \ - --input_files data/squad/train-v1.1.json \ --output_dir squad_model \ --task "SQUAD" \ --is_training True \ 
diff --git a/examples/language_model/bert/static_ipu/run_squad_infer.sh b/examples/language_model/bert/static_ipu/run_squad_infer.sh index 0fef66dd9f15..8b1e52ef8af2 100755 --- a/examples/language_model/bert/static_ipu/run_squad_infer.sh +++ b/examples/language_model/bert/static_ipu/run_squad_infer.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash python3 run_squad.py \ - --input_files data/squad/dev-v1.1.json \ --output_dir squad_model \ --task "SQUAD" \ --is_training False \ diff --git a/examples/language_model/bert/static_ipu/utils.py b/examples/language_model/bert/static_ipu/utils.py index 465015d35fc0..f41f5bc995d7 100644 --- a/examples/language_model/bert/static_ipu/utils.py +++ b/examples/language_model/bert/static_ipu/utils.py @@ -21,22 +21,27 @@ def load_custom_ops(): cur_dir = os.path.dirname(os.path.realpath(__file__)) custom_dir = cur_dir + "/custom_ops" + sources = [ + f"{custom_dir}/custom_checkpointoutput.cc", + f"{custom_dir}/custom_detach.cc", f"{custom_dir}/custom_identity.cc", + f"{custom_dir}/custom_nll_loss.cc", + f"{custom_dir}/tied_gather_pattern.cc", f"{custom_dir}/tied_gather.cc", + f"{custom_dir}/disable_attn_dropout_bwd_pattern.cc", + f"{custom_dir}/workarounds/prevent_const_expr_folding_op.cc", + f"{custom_dir}/utils.cc" + ] + + if '2.5.0' in os.environ['POPLAR_SDK_ENABLED']: + build_dir = cur_dir + "/custom_ops_2.5" + else: + sources.append(f"{custom_dir}/custom_shape_infer.cc") + build_dir = custom_dir + custom_ops = load( name="custom_ops", - sources=[ - f"{custom_dir}/custom_shape_infer.cc", - f"{custom_dir}/custom_checkpointoutput.cc", - f"{custom_dir}/custom_detach.cc", - f"{custom_dir}/custom_identity.cc", - f"{custom_dir}/custom_nll_loss.cc", - f"{custom_dir}/tied_gather_pattern.cc", - f"{custom_dir}/tied_gather.cc", - f"{custom_dir}/disable_attn_dropout_bwd_pattern.cc", - f"{custom_dir}/workarounds/prevent_const_expr_folding_op.cc", - f"{custom_dir}/utils.cc", - ], + sources=sources, extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'], - build_directory=custom_dir, ) + build_directory=build_dir, ) return custom_ops
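After applying this patch and generating the dataset, a minimal sanity-check sketch can confirm that the HDF5 shards expose the six keys that `PretrainingHDF5DataLoader` reads. The `data_path/wikicorpus_en/` location below is only an assumed example path (taken from the `__main__` block of `dataset_ipu.py`); adjust it to wherever `create_datasets_from_start.sh` wrote the shards.

```python
# Sanity-check sketch for the generated pretraining shards (assumed path below).
import glob

import h5py

# Same keys that PretrainingHDF5DataLoader reads from each shard.
KEYS = ('input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions',
        'masked_lm_ids', 'next_sentence_labels')

# Hypothetical location of the phase-1 (seq_len=128) shards.
files = sorted(glob.glob("data_path/wikicorpus_en/*training*.hdf5"))
assert files, "no HDF5 shards found, check the dataset path"

with h5py.File(files[0], "r") as f:
    missing = [k for k in KEYS if k not in f]
    assert not missing, f"missing keys: {missing}"
    for k in KEYS:
        # Each key stores one row per sample, so the first dimension should match
        # across all keys in a shard.
        print(k, f[k].shape, f[k].dtype)
```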