Commit ff80721
Add inference program for Transformer.
1 parent: 3b54986

Showing 4 changed files with 496 additions and 156 deletions.
@@ -0,0 +1,220 @@
import numpy as np

import paddle.v2 as paddle
import paddle.fluid as fluid

import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
    encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data


def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                    decoder, dec_in_names, dec_out_names, beam_size,
                    max_length, n_best, batch_size, n_head, src_pad_idx,
                    trg_pad_idx, bos_idx, eos_idx):
    """
    Run the encoder program once and run the decoder program multiple times
    to implement beam search externally.
    """
    # Prepare data for the encoder and run the encoder program.
    enc_in_data = pad_batch_data(
        src_words,
        src_pad_idx,
        n_head,
        is_target=False,
        return_pos=True,
        return_attn_bias=True,
        return_max_len=True)
    enc_output = exe.run(encoder,
                         feed=dict(zip(enc_in_names, enc_in_data)),
                         fetch_list=enc_out_names)[0]

    # Beam Search.
    # Containers to store the states of the beam search.
    scores = np.zeros((batch_size, beam_size), dtype="float32")
    prev_branches = [[] for _ in range(batch_size)]
    next_ids = [[] for _ in range(batch_size)]
    # Use beam_map to map the instance idx in the batch to the beam idx,
    # since the size of the fed batch changes as instances finish.
    beam_map = range(batch_size)
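    # scores[i, k] holds the accumulated log-probability of the k-th ranked
    # hypothesis of instance i. prev_branches[i][t] records, for each of the
    # beam_size hypotheses kept at step t, the index of its parent hypothesis
    # at step t - 1, and next_ids[i][t] records the word ids chosen at step t;
    # together they are the backpointers consumed by beam_backtrace below.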

    def beam_backtrace(prev_branches, next_ids, n_best=beam_size,
                       add_bos=True):
        """
        Decode and select n_best sequences for one instance by backtrace.
        """
        seqs = []
        for i in range(n_best):
            k = i
            seq = []
            for j in range(len(prev_branches) - 1, -1, -1):
                seq.append(next_ids[j][k])
                k = prev_branches[j][k]
            seq = seq[::-1]
            seq = [bos_idx] + seq if add_bos else seq
            seqs.append(seq)
        return seqs
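    # A tiny worked example of the backtrace: with beam_size = 2,
    # next_ids[i] == [[5, 9], [7, 3]] and prev_branches[i] == [[0, 0], [1, 0]],
    # the best (k = 0) sequence appends 7 at step 1, follows its parent
    # pointer 1 back to step 0 where it appends 9, and after reversal yields
    # [bos_idx, 9, 7].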

    def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
        """
        Initialize the input data for the decoder.
        """
        trg_words = np.array(
            [[bos_idx]] * batch_size * beam_size, dtype="int64")
        trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
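        # Slice a single position of the source self-attention bias, tile it
        # to the current target length, and replicate the whole batch
        # beam_size times, so that every beam hypothesis attends to its
        # source sentence as a separate decoder instance.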
        src_max_length, src_slf_attn_bias = enc_in_data[-1], enc_in_data[-2]
        trg_max_len = 1
        trg_src_attn_bias = np.tile(
            src_slf_attn_bias[:, :, ::src_max_length, :],
            [beam_size, 1, trg_max_len, 1])
        enc_output = np.tile(enc_output, [beam_size, 1, 1])
        # trg_slf_attn_bias is not needed, since the target prefixes fed to
        # the decoder contain no padding.
        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output

    def update_dec_in_data(dec_in_data, next_ids, active_beams):
        """
        Update the input data of the decoder, mainly by slicing the previous
        input data and dropping the beams of finished instances.
        """
        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            enc_output = dec_in_data
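        # The decoder program keeps no state between runs, so the full
        # target prefix of every surviving hypothesis is rebuilt from the
        # backpointers and fed in again from scratch at each step.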
        trg_words = np.array(
            [
                beam_backtrace(
                    prev_branches[beam_idx], next_ids[beam_idx], add_bos=True)
                for beam_idx in active_beams
            ],
            dtype="int64")
        trg_words = trg_words.reshape([-1, 1])
        trg_pos = np.array(
            [range(1, len(next_ids[active_beams[0]]) + 2)] *
            len(active_beams) * beam_size,
            dtype="int64").reshape([-1, 1])
        active_beam_indices = (
            (np.array(active_beams) * beam_size)[:, np.newaxis] +
            np.array(range(beam_size))[np.newaxis, :]).flatten()
        trg_src_attn_bias = np.tile(
            trg_src_attn_bias[active_beam_indices, :,
                              ::trg_src_attn_bias.shape[2], :],
            [1, 1, len(next_ids[active_beams[0]]) + 1, 1])
        enc_output = enc_output[active_beam_indices, :, :]
        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            enc_output

    dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                   enc_output)
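    # Each iteration feeds all surviving hypotheses through the decoder
    # program, adds the log-probabilities of the last position to the
    # accumulated scores, and keeps the beam_size best continuations per
    # instance. An instance leaves beam_map once its top-ranked hypothesis
    # ends with the end-of-sentence token.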
    for i in range(max_length):
        predict_all = exe.run(decoder,
                              feed=dict(
                                  filter(lambda item: item[1] is not None,
                                         zip(dec_in_names, dec_in_data))),
                              fetch_list=dec_out_names)[0]
        predict_all = np.log(predict_all)
        predict_all = (
            predict_all.reshape(
                [len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
            scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
                [len(beam_map), beam_size, -1])
        active_beams = []
        for inst_idx, beam_idx in enumerate(beam_map):
            # At the first step all beams of an instance hold the same bos
            # prefix, so a single row is used to avoid duplicate candidates.
            predict = (predict_all[inst_idx, :, :]
                       if i != 0 else predict_all[inst_idx, 0, :]).flatten()
            top_k_indices = np.argpartition(predict, -beam_size)[-beam_size:]
            top_scores_ids = top_k_indices[np.argsort(
                predict[top_k_indices])[::-1]]
            top_scores = predict[top_scores_ids]
            scores[beam_idx] = top_scores
            # Decompose the flat indices into (parent beam, word id) pairs.
            prev_branches[beam_idx].append(top_scores_ids //
                                           predict_all.shape[-1])
            next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
            if next_ids[beam_idx][-1][0] != eos_idx:
                active_beams.append(beam_idx)
        beam_map = active_beams
        if len(beam_map) == 0:
            break
        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)

    # Decode beams and select n_best sequences for each instance by backtrace.
    seqs = [
        beam_backtrace(prev_branches[beam_idx], next_ids[beam_idx], n_best)
        for beam_idx in range(batch_size)
    ]

    return seqs, scores[:, :n_best].tolist()


def main():
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # The program desc is coupled with the batch size, and batch size 1 is
    # the only one supported for now.
    encoder_program = fluid.Program()
    model.batch_size = InferTaskConfig.batch_size
    with fluid.program_guard(main_program=encoder_program):
        enc_output = encoder(
            ModelHyperParams.src_vocab_size + 1,
            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
            ModelHyperParams.n_head, ModelHyperParams.d_key,
            ModelHyperParams.d_value, ModelHyperParams.d_model,
            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
            ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

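    # The decoder program is built for batch_size * beam_size instances,
    # since beam search feeds every hypothesis of every instance to the
    # decoder as a separate row of the batch.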
    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
    decoder_program = fluid.Program()
    with fluid.program_guard(main_program=decoder_program):
        predict = decoder(
            ModelHyperParams.trg_vocab_size + 1,
            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
            ModelHyperParams.n_head, ModelHyperParams.d_key,
            ModelHyperParams.d_value, ModelHyperParams.d_model,
            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
            ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

    # Load model parameters of encoder and decoder separately from the saved
    # transformer model.
    encoder_var_names = []
    for op in encoder_program.block(0).ops:
        encoder_var_names += op.input_arg_names
    encoder_param_names = filter(
        lambda var_name: isinstance(encoder_program.block(0).var(var_name),
                                    fluid.framework.Parameter),
        encoder_var_names)
    encoder_params = map(encoder_program.block(0).var, encoder_param_names)
    decoder_var_names = []
    for op in decoder_program.block(0).ops:
        decoder_var_names += op.input_arg_names
    decoder_param_names = filter(
        lambda var_name: isinstance(decoder_program.block(0).var(var_name),
                                    fluid.framework.Parameter),
        decoder_var_names)
    decoder_params = map(decoder_program.block(0).var, decoder_param_names)
    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
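    # Both sets are loaded by variable name from the same model_path; this
    # relies on wrap_encoder and wrap_decoder presumably declaring the same
    # parameter names as the trained model, so each sub-program picks up
    # exactly the variables it needs.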

    # get_inference_program is used here to switch dropout to test mode.
    encoder_program = fluid.io.get_inference_program(
        target_vars=[enc_output], main_program=encoder_program)
    decoder_program = fluid.io.get_inference_program(
        target_vars=[predict], main_program=decoder_program)

    test_data = paddle.batch(
        paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
                                  ModelHyperParams.trg_vocab_size),
        batch_size=InferTaskConfig.batch_size)

    trg_idx2word = paddle.dataset.wmt16.get_dict(
        "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)

    for batch_id, data in enumerate(test_data()):
        batch_seqs, batch_scores = translate_batch(
            exe, [item[0] for item in data], encoder_program,
            encoder_input_data_names, [enc_output.name], decoder_program,
            decoder_input_data_names, [predict.name],
            InferTaskConfig.beam_size, InferTaskConfig.max_length,
            InferTaskConfig.n_best, InferTaskConfig.batch_size,
            ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
            ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
            ModelHyperParams.eos_idx)
        for i in range(len(batch_seqs)):
            seqs = batch_seqs[i]
            scores = batch_scores[i]
            for seq in seqs:
                print(" ".join([trg_idx2word[idx] for idx in seq]))


if __name__ == "__main__":
    main()