diff --git a/paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py b/paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
index 65d4a2d65bfc..502e464d0e0e 100644
--- a/paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
+++ b/paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
@@ -24,6 +24,7 @@ from pprint import pprint
 
 from paddlenlp.ops import FasterTransformer
+from paddlenlp.ops import enable_faster_encoder
 from paddlenlp.utils.log import logger
 from paddlenlp.data import Pad
 
 
@@ -33,7 +34,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config",
-        default="./sample/config/decoding.sample.yaml",
+        default="./faster_transformer/sample/config/decoding.sample.yaml",
         type=str,
         help="Path of the config file. ")
     parser.add_argument(
@@ -45,6 +46,15 @@
         "--use_fp16_decoding",
         action="store_true",
         help="Whether to use fp16 decoding to predict. ")
+    parser.add_argument(
+        "--enable_faster_encoder",
+        action="store_true",
+        help="Whether to use the faster version of encoder to predict. This is an experimental option for now. "
+    )
+    parser.add_argument(
+        "--use_fp16_encoder",
+        action="store_true",
+        help="Whether to use fp16 encoder to predict. ")
     args = parser.parse_args()
     return args
 
@@ -69,7 +79,7 @@ def generate_src_word(batch_size, vocab_size, max_length, eos_idx, pad_idx):
 
 def do_predict(args):
     place = "gpu"
-    paddle.set_device(place)
+    place = paddle.set_device(place)
 
     # Define model
     transformer = FasterTransformer(
@@ -91,11 +101,17 @@
         topp=args.topp,
         max_out_len=args.max_out_len,
         decoding_lib=args.decoding_lib,
-        use_fp16_decoding=args.use_fp16_decoding)
+        use_fp16_decoding=args.use_fp16_decoding,
+        enable_faster_encoder=args.enable_faster_encoder,
+        use_fp16_encoder=args.use_fp16_encoder)
 
     # Set evaluate mode
     transformer.eval()
 
+    if args.enable_faster_encoder:
+        transformer = enable_faster_encoder(
+            transformer, need_build=False, use_fp16=args.use_fp16_encoder)
+
     src_word = generate_src_word(
         batch_size=args.infer_batch_size,
         vocab_size=args.src_vocab_size,
@@ -107,10 +123,10 @@
         for i in range(100):
             # For warmup.
             if 50 == i:
-                paddle.device.cuda.synchronize()
+                paddle.device.cuda.synchronize(place)
                 start = time.time()
             transformer(src_word=src_word)
-        paddle.device.cuda.synchronize()
+        paddle.device.cuda.synchronize(place)
         logger.info("Average test time for encoder-decoding is %f ms" % (
             (time.time() - start) / 50 * 1000))
 
@@ -120,8 +136,10 @@
     yaml_file = ARGS.config
     with open(yaml_file, 'rt') as f:
         args = AttrDict(yaml.safe_load(f))
-        pprint(args)
     args.decoding_lib = ARGS.decoding_lib
     args.use_fp16_decoding = ARGS.use_fp16_decoding
+    args.enable_faster_encoder = ARGS.enable_faster_encoder
+    args.use_fp16_encoder = ARGS.use_fp16_encoder
+    pprint(args)
 
     do_predict(args)
diff --git a/paddlenlp/ops/faster_transformer/transformer/encoder.py b/paddlenlp/ops/faster_transformer/transformer/encoder.py
index d50e46e6b9c3..d75e420a0fcd 100644
--- a/paddlenlp/ops/faster_transformer/transformer/encoder.py
+++ b/paddlenlp/ops/faster_transformer/transformer/encoder.py
@@ -241,7 +241,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
 def enable_faster_encoder(self,
                           need_build=True,
                           use_fp16=False,
-                          decoding_lib=None):
+                          encoder_lib=None):
     """
     Compiles fusion encoder operator intergrated FasterTransformer using the
     method of JIT(Just-In-Time) and replaces the `forward` function of
@@ -285,13 +285,13 @@ def init_func(layer):
         try:
             # Pass decoding lib to prevent re-building encoder.
             # Todo: check weather decoding lib have contained encoder or not.
-            if decoding_lib is not None:
-                load_op_meta_info_and_register_op(decoding_lib)
+            if encoder_lib is not None:
+                load_op_meta_info_and_register_op(encoder_lib)
             else:
                 load("FasterTransformer", verbose=True)
         except Exception:
             logger.warning(
-                "Exception occurs when using FasterTransformer. " \
+                "Exception occurs when using FasterEncoder. " \
                 "The original forward will be involved. ")
             return self
     for layer in self.children():
diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
index 075f764d04a3..2fd61c3d949b 100644
--- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
+++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -100,7 +100,14 @@ class FasterTransformer(TransformerModel):
             for details. Bigger `diversity_rate` would lead to more diversity.
             if `diversity_rate == 0` is equivalent to naive BeamSearch. Default
             to 0 if not set.
-        use_fp16_decoding(bool, optional): Whether to use fp16 for decoding.
+        use_fp16_decoding(bool, optional):
+            Whether to use fp16 for decoding.
+        enable_faster_encoder(bool, optional):
+            Whether to use the faster version of encoder. This is an experimental option for now.
+            Defaults to False.
+        use_fp16_encoder(bool, optional):
+            Whether to use fp16 for encoder. Only works when enable_faster_encoder is True.
+            Defaults to False.
         rel_len(bool, optional):
             Indicating whether `max_out_len` in is the length relative to that
             of source text. Only works in `v2` temporarily. It is suggest to set
@@ -135,6 +142,8 @@ def __init__(self,
                  diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
+                 enable_faster_encoder=False,
+                 use_fp16_encoder=False,
                  rel_len=False,
                  alpha=0.6):
         # if decoding_lib is None:
@@ -154,6 +163,8 @@ def __init__(self,
         self.diversity_rate = args.pop("diversity_rate")
         self.decoding_lib = args.pop("decoding_lib")
         self.use_fp16_decoding = args.pop("use_fp16_decoding")
+        self.enable_faster_encoder = args.pop("enable_faster_encoder")
+        self.use_fp16_encoder = args.pop("use_fp16_encoder")
         self.rel_len = args.pop("rel_len")
         self.alpha = args.pop("alpha")
         self.dropout = dropout
@@ -164,6 +175,13 @@ def __init__(self,
         self.max_length = max_length
         super(FasterTransformer, self).__init__(**args)
 
+        if self.enable_faster_encoder:
+            logger.warning(
+                "enable_faster_encoder is an experimental option and subject to change."
+            )
+        elif self.use_fp16_encoder:
+            self.use_fp16_encoder = False
+
         self.decoding_linear = nn.Linear(
             in_features=d_model, out_features=trg_vocab_size)
 
@@ -210,10 +228,16 @@ def forward(self, src_word, trg_word=None):
         enc_input = F.dropout(
             src_emb, p=self.dropout,
             training=False) if self.dropout else src_emb
+
+        if self.enable_faster_encoder and self.use_fp16_encoder:
+            enc_input = paddle.cast(enc_input, dtype="float16")
+
         enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
 
-        if self.use_fp16_decoding:
+        if self.use_fp16_decoding and enc_output.dtype != paddle.float16:
             enc_output = paddle.cast(enc_output, dtype="float16")
+        elif not self.use_fp16_decoding and enc_output.dtype != paddle.float32:
+            enc_output = paddle.cast(enc_output, dtype="float32")
 
         mem_seq_lens = paddle.sum(paddle.cast(
             src_word != self.bos_id, dtype="int32"),
@@ -1104,12 +1128,12 @@ def forward(self,
                 forced_eos_token_id=None,
                 **model_kwargs):
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
             seq_len = paddle.sum(paddle.cast(
@@ -1207,12 +1231,13 @@ def forward(self,
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr(
             self._model, 'decoder_start_token_id', None)
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
+        #(gongenlei) Not enable_faster_encoder temporarily
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
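
Reviewer note (not part of the patch): a minimal sketch of how the options introduced here fit together, mirroring the updated encoder_decoding_sample.py above. The keyword names inside `model_args` are placeholders standing in for the values the sample reads from decoding.sample.yaml and are assumed from the TransformerModel base class, not taken from this diff.

```python
import paddle
from paddlenlp.ops import FasterTransformer, enable_faster_encoder

place = paddle.set_device("gpu")

# Placeholder model-size settings; the sample loads the real values from
# ./faster_transformer/sample/config/decoding.sample.yaml. These keyword
# names are assumptions for illustration, not part of this diff.
model_args = dict(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    bos_id=0,
    eos_id=1,
    max_out_len=256)

transformer = FasterTransformer(
    **model_args,
    use_fp16_decoding=True,
    enable_faster_encoder=True,  # new option added by this patch
    use_fp16_encoder=True)  # new option; only takes effect with enable_faster_encoder

# As in the updated sample: switch to evaluate mode first, then swap in the
# fused encoder forward.
transformer.eval()
transformer = enable_faster_encoder(
    transformer, need_build=False, use_fp16=True)
```

With this wiring, the dtype guards added to FasterTransformer.forward cast enc_output only when its dtype does not already match what decoding expects, so a fp16 encoder output can feed fp16 decoding without a redundant cast.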