Skip to content

Commit

Permalink
Add inference program for Transformer.
Browse files Browse the repository at this point in the history
  • Loading branch information
guoshengCS committed Mar 13, 2018
1 parent 3b54986 commit ff80721
Show file tree
Hide file tree
Showing 4 changed files with 496 additions and 156 deletions.
35 changes: 32 additions & 3 deletions fluid/neural_machine_translation/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ class TrainTaskConfig(object):
# the params for learning rate scheduling
warmup_steps = 4000

# the directory for saving inference models
model_dir = "transformer_model"


class InferTaskConfig(object):
use_gpu = False
# number of sequences contained in a mini-batch
batch_size = 1

# the params for beam search
beam_size = 5
max_length = 30
n_best = 1

# the directory for loading inference model
model_path = "transformer_model/pass_1.infer.model"


class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses
Expand All @@ -33,6 +50,11 @@ class ModelHyperParams(object):
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size

# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1

# position value corresponding to the <pad> token.
pos_pad_idx = 0

Expand Down Expand Up @@ -64,14 +86,21 @@ class ModelHyperParams(object):
"src_pos_enc_table",
"trg_pos_enc_table", )

# Names of all data layers listed in order.
input_data_names = (
# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias", )

# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
"trg_word",
"trg_pos",
"src_slf_attn_bias",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )

# Names of label related data layers listed in order.
label_data_names = (
"lbl_word",
"lbl_weight", )
220 changes: 220 additions & 0 deletions fluid/neural_machine_translation/transformer/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import numpy as np

import paddle.v2 as paddle
import paddle.fluid as fluid

import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data


def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
return_pos=True,
return_attn_bias=True,
return_max_len=True)
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]

# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[]] * batch_size
next_ids = [[]] * batch_size
# Use beam_map to map the instance idx in batch to beam idx, since the
# size of feeded batch is changing.
beam_map = range(batch_size)

def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
seqs.append(seq)
return seqs

def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-1], enc_in_data[-2], 1
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
enc_output = np.tile(enc_output, [beam_size, 1, 1])
# No need for trg_slf_attn_bias because of no paddings.
return trg_words, trg_pos, None, trg_src_attn_bias, enc_output

def update_dec_in_data(dec_in_data, next_ids, active_beams):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, len(next_ids[0]) + 1, 1])
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output

dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(
filter(lambda item: item[1] is not None,
zip(dec_in_names, dec_in_data))),
fetch_list=dec_out_names)[0]
predict_all = np.log(predict_all)
predict_all = (
predict_all.reshape(
[len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
beam_map = active_beams
if len(beam_map) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)

# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]

return seqs, scores[:, :n_best].tolist()


def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# The current program desc is coupled with batch_size and the only
# supported batch size is 1 currently.
encoder_program = fluid.Program()
model.batch_size = InferTaskConfig.batch_size
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)

# This is used here to set dropout to the test mode.
encoder_program = fluid.io.get_inference_program(
target_vars=[enc_output], main_program=encoder_program)
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)

test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)

trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)

for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
InferTaskConfig.batch_size, ModelHyperParams.n_head,
ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))


if __name__ == "__main__":
main()
Loading

0 comments on commit ff80721

Please sign in to comment.