Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPT-3] fix auto_engine and eager_engine, add inference script #5563

Merged
merged 6 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from paddle.distributed.fleet import auto
from ppfleetx.core.engine import BasicEngine
from ppfleetx.core.module import BasicModule
from ppfleetx.optims import build_lr_scheduler, build_optimizer
try:
    from ppfleetx.optims import build_lr_scheduler, build_optimizer
except Exception:
    # Best-effort optional import: inference/export-only runs can work without
    # the optimizer stack, and importing ppfleetx.optims can fail on trimmed
    # Paddle builds for reasons other than ImportError (hence the broad catch).
    # NOTE(review): on failure both names stay *undefined*, so any later use
    # raises NameError — confirm every call site is train-only.
    pass
from ppfleetx.utils.log import logger
from ppfleetx.utils.version import version_check

Expand Down
12 changes: 9 additions & 3 deletions model_zoo/gpt-3/ppfleetx/core/engine/eager_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,21 @@
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
from paddle.distributed.parallel import sync_params_buffers
try:
    from paddle.distributed.parallel import sync_params_buffers
except Exception:
    # Best-effort optional import: `sync_params_buffers` is not exposed at
    # this path in every Paddle release, so a failed import is tolerated.
    # NOTE(review): the name stays undefined on failure — any later use raises
    # NameError; confirm call sites are guarded or version-gated.
    pass
from paddle.distributed.sharding import group_sharded_parallel
from paddle.incubate.distributed.utils.io import save_for_auto_inference
from paddle.profiler import SummaryView
from ppfleetx.core.engine import BasicEngine, InferenceEngine, TensorRTConfig
from ppfleetx.core.module import BasicModule
from ppfleetx.distributed.apis import amp, env
from ppfleetx.optims import build_lr_scheduler, build_optimizer
from ppfleetx.utils.compression_helper import prune_model, quant_model
try:
    from ppfleetx.optims import build_lr_scheduler, build_optimizer
    from ppfleetx.utils.compression_helper import prune_model, quant_model
except Exception:
    # Best-effort optional import: training/compression helpers are not needed
    # for pure inference, and their import can fail on minimal installs.
    # NOTE(review): all four names stay undefined on failure (NameError at use
    # time) — and a failure in the *first* import also skips the second.
    pass
from ppfleetx.utils.device import synchronize as device_synchronize
from ppfleetx.utils.export import export_inference_model
from ppfleetx.utils.log import convert_timestamp_to_data, get_timestamp, logger
Expand Down
4 changes: 4 additions & 0 deletions model_zoo/gpt-3/ppfleetx/core/engine/inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
try:
    # Optional custom op (built from external_ops/); only needed when the
    # generation path actually calls top-p sampling.
    from ppfleetx_ops import topp_sampling
except Exception:
    # Best-effort: tolerate a missing/unbuilt op.  The previous code bound the
    # exception (`as e`) without ever using it; drop the unused binding.
    # NOTE(review): `topp_sampling` stays undefined on failure — confirm every
    # use site is guarded, otherwise it raises NameError at predict time.
    pass

# TensorRT precisions
TRT_PRECISIONS = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@
from paddle.nn.functional.flash_attention import flash_attention
except:
flash_attention = None
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
try:
    from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
except Exception:
    # Fused kernel not available in this Paddle build; callers check
    # `if not FusedDropoutAdd:` and fall back to plain nn.Dropout + add.
    FusedDropoutAdd = None
# NOTE(review): removed the stray unconditional `FusedDropoutAdd = None` that
# followed the try/except — it clobbered a successful import and silently
# forced the non-fused path even on builds that ship the fused op.  If
# disabling the fused path was intentional, gate it behind an env flag instead.


def get_attr(layer, name):
Expand Down Expand Up @@ -575,8 +579,12 @@ def __init__(
mark_as_sequence_parallel_parameter(self.norm1.bias)
mark_as_sequence_parallel_parameter(self.norm2.weight)
mark_as_sequence_parallel_parameter(self.norm2.bias)
self.fused_dropout_add1 = FusedDropoutAdd(dropout, mode="upscale_in_train")
self.fused_dropout_add2 = FusedDropoutAdd(act_dropout, mode="upscale_in_train")
if not FusedDropoutAdd:
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
else:
self.fused_dropout_add1 = FusedDropoutAdd(dropout, mode="upscale_in_train")
self.fused_dropout_add2 = FusedDropoutAdd(act_dropout, mode="upscale_in_train")

self.activation = getattr(F, activation)

Expand All @@ -600,7 +608,10 @@ def forward(self, tgt, memory=None, tgt_mask=None, use_cache=False, cache=None):
else:
current_seed = "global_seed"
with get_rng_state_tracker().rng_state(current_seed):
tgt = self.fused_dropout_add1(tgt, residual)
if not FusedDropoutAdd:
tgt = residual + self.dropout1(tgt)
else:
tgt = self.fused_dropout_add1(tgt, residual)

if not self.normalize_before:
tgt = self.norm1(tgt)
Expand All @@ -610,7 +621,10 @@ def forward(self, tgt, memory=None, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm2(tgt)

with get_rng_state_tracker().rng_state(current_seed):
tgt = self.fused_dropout_add2(self.linear2(F.gelu(self.linear1(tgt), approximate=True)), residual)
if not FusedDropoutAdd:
tgt = residual + self.linear2(F.gelu(self.linear1(tgt), approximate=True))
else:
tgt = self.fused_dropout_add2(self.linear2(F.gelu(self.linear1(tgt), approximate=True)), residual)

if not self.normalize_before:
tgt = self.norm2(tgt)
Expand Down
2 changes: 1 addition & 1 deletion model_zoo/gpt-3/ppfleetx/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_device_and_mapping():
"gpu": paddle.is_compiled_with_cuda(),
"xpu": paddle.is_compiled_with_xpu(),
"rocm": paddle.is_compiled_with_rocm(),
"npu": paddle.is_compiled_with_custom_device("npu"),
"npu": paddle.is_compiled_with_npu() or paddle.is_compiled_with_custom_device("npu"),
"cpu": True,
}
for d, v in suppoted_device_map.items():
Expand Down
2 changes: 1 addition & 1 deletion model_zoo/gpt-3/projects/gpt/auto_export_gpt_345M_mp2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ rm -rf $log_dir

python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1" \
./tools/auto_export.py \
-c ./ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_mp2.yaml \
-c ./ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_mp2.yaml
2 changes: 1 addition & 1 deletion model_zoo/gpt-3/projects/gpt/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def predict(engine, data, args):
for _ in range(args.iter):
engine.predictor.run()
end = time.perf_counter()
print(f"batch {args.iter} run time: {1000 * (end - start) / args.iter}ms")
print(f"batch {data.shape} run time: {1000 * (end - start) / args.iter}ms")

return {name: engine.predictor.get_output_handle(name).copy_to_cpu() for name in engine.output_names()}

Expand Down
9 changes: 9 additions & 0 deletions model_zoo/gpt-3/projects/gpt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
from __future__ import absolute_import, division, print_function

import argparse
import os
import sys
import time

import numpy as np
import paddle
import paddle.distributed.fleet as fleet

# Make the parent directory (model_zoo/gpt-3) importable so `ppfleetx`
# resolves when this script is launched directly via paddle.distributed.launch.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, "../")))

# NOTE(review): `import paddle.distributed.fleet as fleet` appeared twice in
# the original (before and after the sys.path hack); the exact duplicate is
# removed.  The ppfleetx import below must stay *after* the sys.path.append.
from ppfleetx.core.engine.inference_engine import InferenceEngine
Expand Down
10 changes: 10 additions & 0 deletions model_zoo/gpt-3/run_mp8.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# End-to-end GPT-175B mp8 pipeline: export -> inference -> benchmark on 8 GPUs.
# Build the custom ops first if they are not already installed:
# cd external_ops && python setup.py install && cd -

# Presumably enable the fused fast-LayerNorm and fused linear+grad-add custom
# kernels built above — confirm against external_ops' setup.
export USE_FAST_LN=1
export USE_LINEAR_WITH_GRAD_ADD=1

# Export the 175B GPT model for inference with mp_degree=8 (config-driven).
python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./tools/auto_export.py -c ./ppfleetx/configs/nlp/gpt/auto/generation_gpt_175B_mp8.yaml

# Run distributed inference on the exported model (reads from ./output).
python -m paddle.distributed.launch projects/gpt/inference.py --mp_degree 8 --model_dir output

# Benchmark the exported model: sequence length 128, 10 timed iterations.
python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" projects/gpt/benchmark.py --seq_len 128 --iter 10 --mp_degree 8 --model_dir ./output