Skip to content

Commit

Permalink
Merge pull request #98 from luotao1/beam
Browse files Browse the repository at this point in the history
update beam_search and seqToseq config, and add ExpActivation api
  • Loading branch information
lcy-seso committed Sep 20, 2016
2 parents 425e5b0 + d2e1b46 commit 04876d0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 99 deletions.
22 changes: 9 additions & 13 deletions demo/seqToseq/seqToseq_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
return out

decoder_group_name = "decoder_group"
group_inputs=[StaticInput(input=encoded_vector,is_seq=True),
StaticInput(input=encoded_proj,is_seq=True)]

if not is_generating:
trg_embedding = embedding_layer(
input=data_layer(name='target_language_word',
size=target_dict_dim),
size=word_vector_dim,
param_attr=ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding)

# For decoder equipped with attention mechanism, in training,
# target embeding (the groudtruth) is the data input,
Expand All @@ -142,22 +146,13 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
# for the recurrent_group.
decoder = recurrent_group(name=decoder_group_name,
step=gru_decoder_with_attention,
input=[
StaticInput(input=encoded_vector,
is_seq=True),
StaticInput(input=encoded_proj,
is_seq=True), trg_embedding
])
input=group_inputs)

lbl = data_layer(name='target_language_next_word',
size=target_dict_dim)
cost = classification_cost(input=decoder, label=lbl, )
cost = classification_cost(input=decoder, label=lbl)
outputs(cost)
else:
gen_inputs = [StaticInput(input=encoded_vector,
is_seq=True),
StaticInput(input=encoded_proj,
is_seq=True), ]
# In generation, the decoder predicts a next target word based on
# the encoded source sequence and the last generated target word.

Expand All @@ -171,10 +166,11 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
size=target_dict_dim,
embedding_name='_target_language_embedding',
embedding_size=word_vector_dim)
gen_inputs.append(trg_embedding)
group_inputs.append(trg_embedding)

beam_gen = beam_search(name=decoder_group_name,
step=gru_decoder_with_attention,
input=gen_inputs,
input=group_inputs,
id_input=data_layer(name="sent_id",
size=1),
dict_file=trg_dict_path,
Expand Down
7 changes: 7 additions & 0 deletions doc/ui/api/trainer_config_helpers/activations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ AbsActivation
:members: AbsActivation
:noindex:

ExpActivation
===============

.. automodule:: paddle.trainer_config_helpers.activations
:members: ExpActivation
:noindex:

IdentityActivation
==================

Expand Down
127 changes: 42 additions & 85 deletions paddle/trainer/tests/sample_trainer_rnn_gen.conf
Original file line number Diff line number Diff line change
Expand Up @@ -13,96 +13,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.

import math
from paddle.trainer_config_helpers import *

beam_search = get_config_arg('beam_search', bool, False)

model_type("recurrent_nn")

Settings(learning_rate=0, batch_size=15, algorithm='sgd')

Inputs("sent_id", "dummy_data_input")
Outputs("predict_word")
settings(batch_size=15, learning_rate=0)

num_words = 5
beam_flag = get_config_arg('beam_search', bool, False)

DataLayer(name="sent_id", size=1, )
sent_id = data_layer(name="sent_id", size=1)

# This layer has no actual use, but only to decide batch_size in generation.
# When generating, at least one Memory in RecurrentLayer MUST have a boot layer.
DataLayer(name="dummy_data_input", size=2, )

if beam_search:
RecurrentLayerGroupBegin("decoding_layer_group",
in_links=[],
out_links=["predict_word"],
generator=Generator(max_num_frames=10,
beam_size=2,
num_results_per_sample=2, ))
else:
RecurrentLayerGroupBegin("decoding_layer_group",
in_links=[],
out_links=["predict_word"],
generator=Generator(max_num_frames=10, ))
dummy_memory = Memory(name="dummy_memory",
size=2,
boot_layer="dummy_data_input")
MixedLayer(name="dummy_memory",
size=2,
bias=False,
inputs=[IdentityProjection(dummy_memory)], )
state_memory = Memory(name="state",
size=num_words,
#boot_bias=True,
#boot_bias_active_type = "tanh",
)

predict_word_memory = Memory(name="predict_word",
size=num_words,
boot_with_const_id=0, )

MixedLayer(
name = "word_embedding",
size = num_words, # word embedding dim is the same as num_words in this test.
bias = False,
inputs = TableProjection(predict_word_memory,
initial_std=1,
learning_rate=0,
parameter_name="wordvec"))

Layer( # simplified RNN for testing
name="state",
type="mixed",
size=num_words,
bias=False,
inputs=[FullMatrixProjection("word_embedding",
parameter_name="transtable")])

Layer(name="output",
type="mixed",
size=num_words,
active_type="exponential",
bias=False,
inputs=TransposedFullMatrixProjection("state",
initial_std=1,
learning_rate=0,
parameter_name="wordvec"), )

Layer(name="predict_word", type="maxid", inputs=["output"], )

Layer(name="eos_check",
type="eos_id",
eos_id=num_words - 1,
inputs=["predict_word"], )
RecurrentLayerGroupEnd("decoding_layer_group")

Evaluator(name="answer_printer",
type="seq_text_printer",
dict_file="./trainer/tests/test_gen_dict.txt",
result_file="./trainer/tests/dump_text.test",
inputs=[
"sent_id",
"predict_word",
], )
dummy_data = data_layer(name="dummy_data_input", size=2)

gen_inputs = [StaticInput(input=dummy_data, size=2),
GeneratedInput(size=num_words,
embedding_name="wordvec",
embedding_size=num_words)]

def step(dummy_memory, predict_word):

# simplified RNN for testing
with mixed_layer(size=num_words) as layer:
layer += full_matrix_projection(input=predict_word,
param_attr=ParamAttr(name="transtable"))

with mixed_layer(size=num_words, act=ExpActivation()) as out:
out += trans_full_matrix_projection(input=layer,
param_attr=ParamAttr(name="wordvec"))

return out

beam_gen = beam_search(name="rnn_gen",
step=step,
input=gen_inputs,
id_input=sent_id,
dict_file="./trainer/tests/test_gen_dict.txt",
result_file="./trainer/tests/dump_text.test",
bos_id=0,
eos_id=num_words-1,
beam_size=2 if beam_flag else 1,
num_results_per_sample=2 if beam_flag else 1,
max_length=10)

#outputs(beam_gen)
# In this config, as dummy_data_input doesn't work on beam_gen (we can find dummy_memory
# is read-only memory, and isn't used by other layers of step), we show the Inputs and Outputs
# as follows. Note that "__beam_search_predict__" is the default output name of beam_search.
Inputs("sent_id","dummy_data_input")
Outputs("__beam_search_predict__")
11 changes: 10 additions & 1 deletion python/paddle/trainer_config_helpers/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

__all__ = ["TanhActivation", "SigmoidActivation",
"SoftmaxActivation", "IdentityActivation", "LinearActivation",
'SequenceSoftmaxActivation',
'SequenceSoftmaxActivation', 'ExpActivation',
"ReluActivation", "BReluActivation", "SoftReluActivation", "STanhActivation",
"AbsActivation", "SquareActivation", "BaseActivation"]

Expand Down Expand Up @@ -185,3 +185,12 @@ class SquareActivation(BaseActivation):
"""

def __init__(self): BaseActivation.__init__(self, 'square', False)

class ExpActivation(BaseActivation):
"""
Exponential Activation.
.. math::
f(z) = e^z.
"""
def __init__(self): BaseActivation.__init__(self, 'exponential', False)

0 comments on commit 04876d0

Please sign in to comment.