Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference program for Transformer. #727

Merged
merged 3 commits into from
Mar 21, 2018

Conversation

guoshengCS
Copy link
Collaborator

Add inference program for Transformer.

Copy link
Collaborator

@lcy-seso lcy-seso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this work.

@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
# the params for learning rate scheduling
warmup_steps = 4000

# the directory for saving inference models
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for saving inference models --> for saving trained models

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


class InferTaskConfig(object):
use_gpu = False
# number of sequences contained in a mini-batch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • number of sequences contained in a mini-batch --> the number of examples in one run for sequence generation.
  • Please add a comment here to warn users currently the batch size can only be set to 1.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


class InferTaskConfig(object):
use_gpu = False
# number of sequences contained in a mini-batch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

number of sequences contained in a mini-batch --> the number of examples in one run for sequence generation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

# the params for beam search
beam_size = 5
max_length = 30
n_best = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please comment n_best. It is confusing to me about what is the difference between beam_size and n_best.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
# the params for learning rate scheduling
warmup_steps = 4000

# the directory for saving inference models
model_dir = "transformer_model"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the name to "trained_models"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

max_length,
slf_attn_bias_flag,
src_attn_bias_flag,
pos_flag=1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change "pos_flag" into "is_pos=True"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

dtype="float32",
append_batch_size=False)
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
batch_size, max_length, 1, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 --> False

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
batch_size, max_length, 1, 1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the last two 1 --> True

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
max_length, 0, 0, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make the last three parameters of make_inputs boolean parameters.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

name=input_data_names[2]
if slf_attn_bias_flag == 1 else input_data_names[-1],
shape=[batch_size, n_head, max_length, max_length]
if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make src_attn_bias_flag a boolean parameter.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

@lcy-seso lcy-seso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@lcy-seso lcy-seso merged commit ae792ec into PaddlePaddle:develop Mar 21, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants