Merge pull request #701 from guoshengCS/add-transformer-initializer
Add initializer for Transformer.
lcy-seso authored Mar 12, 2018
2 parents 131f0ba + a9159a8 commit 3b54986
Showing 2 changed files with 50 additions and 24 deletions.
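This change threads explicit parameter initializers through the Transformer model: Xavier (Glorot) normal for the attention and output projections, Normal(0., 1.) for the source word embedding, a scaled Uniform for the position-wise feed-forward layers, and Constant(1.)/Constant(0.) for the layer-normalization scale and bias. As a rough standalone sketch, not part of the diff and assuming the conventional Glorot-normal formula std = sqrt(2 / (fan_in + fan_out)) with illustrative hyperparameter values, the snippet below computes the scale implied by the fan_in/fan_out arguments the Q/K/V projections pass:

    # Sketch only: conventional Glorot/Xavier normal scale, using the
    # fan_in/fan_out values the Q/K/V projections in this diff pass explicitly.
    import math

    def xavier_normal_std(fan_in, fan_out):
        # Glorot normal draws weights from N(0, std) with std = sqrt(2 / (fan_in + fan_out)).
        return math.sqrt(2.0 / (fan_in + fan_out))

    d_model, d_key, n_head = 512, 64, 8  # assumed values, not taken from the diff
    std = xavier_normal_std(fan_in=d_model * d_key, fan_out=n_head * d_key)
    print("Q/K/V projection weight std: %.6f" % std)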
70 changes: 48 additions & 22 deletions fluid/neural_machine_translation/transformer/model.py
@@ -1,7 +1,6 @@
 from functools import partial
 import numpy as np
 
-import paddle.v2 as paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 
@@ -31,7 +30,7 @@ def multi_head_attention(queries,
                          d_key,
                          d_value,
                          d_model,
-                         num_heads=1,
+                         n_head=1,
                          dropout_rate=0.):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
@@ -42,41 +41,53 @@ def multi_head_attention(queries,
         raise ValueError(
             "Inputs: queries, keys and values should all be 3-D tensors.")
 
-    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
+    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
         """
         Add linear projection to queries, keys, and values.
         """
         q = layers.fc(input=queries,
-                      size=d_key * num_heads,
+                      size=d_key * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_key,
+                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         k = layers.fc(input=keys,
-                      size=d_key * num_heads,
+                      size=d_key * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_key,
+                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         v = layers.fc(input=values,
-                      size=d_value * num_heads,
+                      size=d_value * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_value,
+                          fan_out=n_head * d_value),
                       bias_attr=False,
                       num_flatten_dims=2)
         return q, k, v
 
-    def __split_heads(x, num_heads):
+    def __split_heads(x, n_head):
         """
         Reshape the last dimension of input tensor x so that it becomes two
         dimensions and then transpose. Specifically, input a tensor with shape
-        [bs, max_sequence_length, num_heads * hidden_dim] then output a tensor
-        with shape [bs, num_heads, max_sequence_length, hidden_dim].
+        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
+        with shape [bs, n_head, max_sequence_length, hidden_dim].
         """
-        if num_heads == 1:
+        if n_head == 1:
             return x
 
         hidden_size = x.shape[-1]
         # FIXME(guosheng): Decouple the program desc with batch_size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
+            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
 
         # permute the dimensions into:
-        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
         return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
 
     def __combine_heads(x):
@@ -95,7 +106,7 @@ def __combine_heads(x):
             shape=map(int,
                       [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
 
-    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
         Scaled Dot-Product Attention
         """
@@ -114,7 +125,7 @@ def __softmax(x, eps=1e-9):
             sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
             return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
 
-        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
+        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
         weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
         if dropout_rate:
@@ -123,20 +134,21 @@ def __softmax(x, eps=1e-9):
         out = layers.matmul(weights, v)
         return out
 
-    q, k, v = __compute_qkv(queries, keys, values, num_heads, d_key, d_value)
+    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
 
-    q = __split_heads(q, num_heads)
-    k = __split_heads(k, num_heads)
-    v = __split_heads(v, num_heads)
+    q = __split_heads(q, n_head)
+    k = __split_heads(k, n_head)
+    v = __split_heads(v, n_head)
 
-    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                   dropout_rate)
 
     out = __combine_heads(ctx_multiheads)
 
     # Project back to the model size.
     proj_out = layers.fc(input=out,
                          size=d_model,
+                         param_attr=fluid.initializer.Xavier(uniform=False),
                          bias_attr=False,
                          num_flatten_dims=2)
     return proj_out
@@ -151,8 +163,14 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
     hidden = layers.fc(input=x,
                        size=d_inner_hid,
                        num_flatten_dims=2,
+                       param_attr=fluid.initializer.Uniform(
+                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                        act="relu")
-    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
+    out = layers.fc(input=hidden,
+                    size=d_hid,
+                    num_flatten_dims=2,
+                    param_attr=fluid.initializer.Uniform(
+                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
     return out
 
 
@@ -168,7 +186,11 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
         if cmd == "a":  # add residual connection
             out = out + prev_out if prev_out else out
         elif cmd == "n":  # add layer normalization
-            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
+            out = layers.layer_norm(
+                out,
+                begin_norm_axis=len(out.shape) - 1,
+                param_attr=fluid.initializer.Constant(1.),
+                bias_attr=fluid.initializer.Constant(0.))
         elif cmd == "d":  # add dropout
             if dropout:
                 out = layers.dropout(out, dropout_prob=dropout, is_test=False)
@@ -195,7 +217,10 @@ def prepare_encoder(src_word,
     This module is used at the bottom of the encoder stacks.
     """
     src_word_emb = layers.embedding(
-        src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
+        src_word,
+        size=[src_vocab_size, src_emb_dim],
+        padding_idx=src_pad_idx,
+        param_attr=fluid.initializer.Normal(0., 1.))
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -462,6 +487,7 @@ def transformer(
     predict = layers.reshape(
         x=layers.fc(input=dec_output,
                     size=trg_vocab_size,
+                    param_attr=fluid.initializer.Xavier(uniform=False),
                     bias_attr=False,
                     num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
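The Uniform initializers added to positionwise_feed_forward above reduce to a range of plus or minus fan_in**-0.5 for each fc: the first layer reads the model-dimension input (d_hid), the second reads the inner hidden state (d_inner_hid). A small sketch of the resulting bounds, with assumed illustrative sizes:

    # Sketch only: the uniform ranges the feed-forward initializers evaluate to.
    d_hid, d_inner_hid = 512, 2048  # assumed values, not taken from the diff

    bound_first = d_hid ** -0.5         # first fc:  U(-1/sqrt(d_hid), +1/sqrt(d_hid))
    bound_second = d_inner_hid ** -0.5  # second fc: U(-1/sqrt(d_inner_hid), +1/sqrt(d_inner_hid))
    print("first fc bound:  +/- %.6f" % bound_first)
    print("second fc bound: +/- %.6f" % bound_second)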
4 changes: 2 additions & 2 deletions fluid/neural_machine_translation/transformer/train.py
@@ -115,7 +115,7 @@ def main():
         paddle.reader.shuffle(
             paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                        ModelHyperParams.trg_vocab_size),
-            buf_size=51200),
+            buf_size=100000),
         batch_size=TrainTaskConfig.batch_size)
 
     # Initialize the parameters.
@@ -143,7 +143,7 @@ def main():
                        fetch_list=[cost])
         cost_val = np.array(outs[0])
         print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
-              " avg_cost = " + str(cost_val))
+              " cost = " + str(cost_val))
 
 
 if __name__ == "__main__":
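In train.py the shuffle buffer grows from 51200 to 100000 samples, which brings the within-epoch ordering closer to a full shuffle. A minimal sketch (not Paddle's implementation) of the buffered-shuffle behaviour that buf_size controls:

    # Sketch only: fill a buffer of buf_size items, emit it in random order, repeat.
    import random

    def buffered_shuffle(reader, buf_size):
        def new_reader():
            buf = []
            for item in reader():
                buf.append(item)
                if len(buf) >= buf_size:
                    random.shuffle(buf)
                    for b in buf:
                        yield b
                    buf = []
            random.shuffle(buf)  # flush the remainder
            for b in buf:
                yield b
        return new_reader

    # Example: shuffle a toy reader with a buffer of 4 items.
    print(list(buffered_shuffle(lambda: iter(range(10)), buf_size=4)()))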
