Transformer supports ft faster encoder (#1430)
* transformer support ft encoder


Co-authored-by: Guo Sheng <whucsgs@163.com>
FrostML and guoshengCS authored Dec 10, 2021
1 parent 3a24b61 commit 0720afe
Showing 3 changed files with 59 additions and 16 deletions.
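In short, this commit adds two options to `FasterTransformer` (`enable_faster_encoder` and `use_fp16_encoder`), wires them into the encoder-decoding sample, and renames the prebuilt-library argument of `enable_faster_encoder` from `decoding_lib` to `encoder_lib`. Below is a minimal usage sketch assembled from the diffs that follow; the hyperparameter values are placeholders, not values taken from the sample config:

from paddlenlp.ops import FasterTransformer, enable_faster_encoder

# Placeholder sizes; the sample reads the real values from its YAML config.
transformer = FasterTransformer(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    beam_size=4,
    max_out_len=255,
    use_fp16_decoding=False,
    enable_faster_encoder=True,  # new in this commit
    use_fp16_encoder=False)      # new in this commit; only effective with enable_faster_encoder
transformer.eval()

# As in the sample, swap in the fused encoder after switching to eval mode.
transformer = enable_faster_encoder(
    transformer, need_build=False, use_fp16=False)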
30 changes: 24 additions & 6 deletions paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
@@ -24,6 +24,7 @@
from pprint import pprint

from paddlenlp.ops import FasterTransformer
from paddlenlp.ops import enable_faster_encoder

from paddlenlp.utils.log import logger
from paddlenlp.data import Pad
@@ -33,7 +34,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
default="./sample/config/decoding.sample.yaml",
default="./faster_transformer/sample/config/decoding.sample.yaml",
type=str,
help="Path of the config file. ")
parser.add_argument(
@@ -45,6 +46,15 @@
"--use_fp16_decoding",
action="store_true",
help="Whether to use fp16 decoding to predict. ")
parser.add_argument(
"--enable_faster_encoder",
action="store_true",
help="Whether to use faster version encoder to predict. This is experimental option for now. "
)
parser.add_argument(
"--use_fp16_encoder",
action="store_true",
help="Whether to use fp16 encoder to predict. ")
args = parser.parse_args()
return args

@@ -69,7 +79,7 @@ def generate_src_word(batch_size, vocab_size, max_length, eos_idx, pad_idx):

def do_predict(args):
place = "gpu"
paddle.set_device(place)
place = paddle.set_device(place)

# Define model
transformer = FasterTransformer(
@@ -91,11 +101,17 @@ def do_predict(args):
topp=args.topp,
max_out_len=args.max_out_len,
decoding_lib=args.decoding_lib,
use_fp16_decoding=args.use_fp16_decoding)
use_fp16_decoding=args.use_fp16_decoding,
enable_faster_encoder=args.enable_faster_encoder,
use_fp16_encoder=args.use_fp16_encoder)

# Set evaluate mode
transformer.eval()

if args.enable_faster_encoder:
transformer = enable_faster_encoder(
transformer, need_build=False, use_fp16=args.use_fp16_encoder)

src_word = generate_src_word(
batch_size=args.infer_batch_size,
vocab_size=args.src_vocab_size,
@@ -107,10 +123,10 @@
for i in range(100):
# For warmup.
if 50 == i:
paddle.device.cuda.synchronize()
paddle.device.cuda.synchronize(place)
start = time.time()
transformer(src_word=src_word)
paddle.device.cuda.synchronize()
paddle.device.cuda.synchronize(place)
logger.info("Average test time for encoder-decoding is %f ms" % (
(time.time() - start) / 50 * 1000))

@@ -120,8 +136,10 @@
yaml_file = ARGS.config
with open(yaml_file, 'rt') as f:
args = AttrDict(yaml.safe_load(f))
pprint(args)
args.decoding_lib = ARGS.decoding_lib
args.use_fp16_decoding = ARGS.use_fp16_decoding
args.enable_faster_encoder = ARGS.enable_faster_encoder
args.use_fp16_encoder = ARGS.use_fp16_encoder
pprint(args)

do_predict(args)
8 changes: 4 additions & 4 deletions paddlenlp/ops/faster_transformer/transformer/encoder.py
@@ -241,7 +241,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
def enable_faster_encoder(self,
need_build=True,
use_fp16=False,
decoding_lib=None):
encoder_lib=None):
"""
Compiles the fused encoder operator integrated with FasterTransformer using
JIT (Just-In-Time) compilation and replaces the `forward` function of
@@ -285,13 +285,13 @@ def init_func(layer):
try:
# Pass the prebuilt lib to prevent re-building the encoder.
# TODO: check whether the decoding lib already contains the encoder or not.
if decoding_lib is not None:
load_op_meta_info_and_register_op(decoding_lib)
if encoder_lib is not None:
load_op_meta_info_and_register_op(encoder_lib)
else:
load("FasterTransformer", verbose=True)
except Exception:
logger.warning(
"Exception occurs when using FasterTransformer. " \
"Exception occurs when using FasterEncoder. " \
"The original forward will be involved. ")
return self
for layer in self.children():
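Since the keyword is renamed, callers that previously passed a prebuilt library via `decoding_lib` should now pass `encoder_lib`. A minimal standalone sketch of enabling and disabling the faster encoder on a plain Paddle encoder (layer sizes are placeholders; leaving `encoder_lib` unset falls back to JIT-compiling the FasterTransformer op):

import paddle.nn as nn
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder.eval()

# encoder_lib would point to a prebuilt op library; None triggers JIT compilation.
encoder = enable_faster_encoder(encoder, use_fp16=False, encoder_lib=None)
# ... run inference with the fused encoder ...
encoder = disable_faster_encoder(encoder)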
37 changes: 31 additions & 6 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -100,7 +100,14 @@ class FasterTransformer(TransformerModel):
for details. Bigger `diversity_rate` would lead to more diversity.
if `diversity_rate == 0` is equivalent to naive BeamSearch. Default
to 0 if not set.
use_fp16_decoding(bool, optional): Whether to use fp16 for decoding.
use_fp16_decoding(bool, optional):
Whether to use fp16 for decoding.
enable_faster_encoder(bool, optional):
Whether to use the faster version of the encoder. This is an experimental option for now.
Defaults to False.
use_fp16_encoder(bool, optional):
Whether to use fp16 for the encoder. Only works when enable_faster_encoder is True.
Defaults to False.
rel_len(bool, optional):
Indicating whether `max_out_len` is the length relative to that
of the source text. Only works in `v2` temporarily. It is suggested to set
@@ -135,6 +142,8 @@ def __init__(self,
diversity_rate=0.0,
decoding_lib=None,
use_fp16_decoding=False,
enable_faster_encoder=False,
use_fp16_encoder=False,
rel_len=False,
alpha=0.6):
# if decoding_lib is None:
@@ -154,6 +163,8 @@
self.diversity_rate = args.pop("diversity_rate")
self.decoding_lib = args.pop("decoding_lib")
self.use_fp16_decoding = args.pop("use_fp16_decoding")
self.enable_faster_encoder = args.pop("enable_faster_encoder")
self.use_fp16_encoder = args.pop("use_fp16_encoder")
self.rel_len = args.pop("rel_len")
self.alpha = args.pop("alpha")
self.dropout = dropout
@@ -164,6 +175,13 @@
self.max_length = max_length
super(FasterTransformer, self).__init__(**args)

if self.enable_faster_encoder:
logger.warning(
"enable_faster_encoder is an experimental option and subject to change."
)
elif self.use_fp16_encoder:
self.use_fp16_encoder = False

self.decoding_linear = nn.Linear(
in_features=d_model, out_features=trg_vocab_size)

@@ -210,10 +228,16 @@ def forward(self, src_word, trg_word=None):
enc_input = F.dropout(
src_emb, p=self.dropout,
training=False) if self.dropout else src_emb

if self.enable_faster_encoder and self.use_fp16_encoder:
enc_input = paddle.cast(enc_input, dtype="float16")

enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)

if self.use_fp16_decoding:
if self.use_fp16_decoding and enc_output.dtype != paddle.float16:
enc_output = paddle.cast(enc_output, dtype="float16")
elif not self.use_fp16_decoding and enc_output.dtype != paddle.float32:
enc_output = paddle.cast(enc_output, dtype="float32")

mem_seq_lens = paddle.sum(paddle.cast(
src_word != self.bos_id, dtype="int32"),
Expand Down Expand Up @@ -1104,12 +1128,12 @@ def forward(self,
forced_eos_token_id=None,
**model_kwargs):

self.encoder = enable_faster_encoder(self.encoder, need_build=False)
if encoder_output is None:
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
self.encoder = disable_faster_encoder(self.encoder)
self.encoder = disable_faster_encoder(self.encoder)
if seq_len is None:
assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
seq_len = paddle.sum(paddle.cast(
@@ -1207,12 +1231,13 @@ def forward(self,
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr(
self._model, 'decoder_start_token_id', None)

self.encoder = enable_faster_encoder(self.encoder, need_build=False)
# (gongenlei) Not enabling faster_encoder temporarily
if encoder_output is None:
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
self.encoder = disable_faster_encoder(self.encoder)
self.encoder = disable_faster_encoder(self.encoder)
batch_size = paddle.shape(encoder_output)[0]
if seq_len is None:
assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
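The forward changes in faster_transformer.py also normalize dtypes around the encoder: the encoder input is cast to fp16 only when the faster fp16 encoder is active, and the encoder output is cast back to whichever dtype decoding expects. A standalone sketch of that logic, mirroring the diff above (the function name is illustrative, not part of the module):

import paddle

def run_encoder_with_dtype_handling(encoder, enc_input, src_slf_attn_bias,
                                    enable_faster_encoder, use_fp16_encoder,
                                    use_fp16_decoding):
    # Feed fp16 activations to the encoder only when the faster fp16 encoder is enabled.
    if enable_faster_encoder and use_fp16_encoder:
        enc_input = paddle.cast(enc_input, dtype="float16")
    enc_output = encoder(enc_input, src_slf_attn_bias)
    # Ensure decoding always receives a tensor of the dtype it was built for.
    if use_fp16_decoding and enc_output.dtype != paddle.float16:
        enc_output = paddle.cast(enc_output, dtype="float16")
    elif not use_fp16_decoding and enc_output.dtype != paddle.float32:
        enc_output = paddle.cast(enc_output, dtype="float32")
    return enc_output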
