diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2101f604d..554057f13 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,3 +2,4 @@ pyyaml==3.12 mxnet-mkl==1.3.0.post0 numpy>=1.14 typing +parsimonious diff --git a/sockeye/arguments.py b/sockeye/arguments.py index c457d7837..b61940a5f 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -671,6 +671,22 @@ def add_model_parameters(params): 'For example: n:drn ' 'Default: %(default)s.') + # Custom sequence encoder or decoder + model_params.add_argument('--custom-seq-encoder', + default='res(norm->mh_dot_att)->res(norm->ff->linear))', + help='Specify the layers the custom encoder will consist of.') + model_params.add_argument('--custom-seq-decoder', + default='res(norm->mh_dot_self_att)->res(norm->mh_dot_att)->res(norm->ff->linear))', + help='Specify the layers the custom decoder will consist of.') + model_params.add_argument('--custom-seq-num-hidden', + type=int_greater_or_equal(1), + default=512, + help='Number of hidden units for encoder and decoder. Default: %(default)s.') + model_params.add_argument('--custom-seq-dropout', + type=float, + default=.1, + help='Dropout used throughout the custom encoder and decoder. Default: %(default)s.') + # LHUC # TODO: The convolutional model does not support lhuc yet model_params.add_argument('--lhuc', diff --git a/sockeye/constants.py b/sockeye/constants.py index 0473bde73..8bf802c59 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -56,14 +56,15 @@ RNN_WITH_CONV_EMBED_NAME = "rnn-with-conv-embed" TRANSFORMER_TYPE = "transformer" CONVOLUTION_TYPE = "cnn" +CUSTOM_SEQ_TYPE = "custom-seq" TRANSFORMER_WITH_CONV_EMBED_TYPE = "transformer-with-conv-embed" IMAGE_PRETRAIN_TYPE = "image-pretrain-cnn" # available encoders -ENCODERS = [RNN_NAME, RNN_WITH_CONV_EMBED_NAME, TRANSFORMER_TYPE, TRANSFORMER_WITH_CONV_EMBED_TYPE, CONVOLUTION_TYPE, IMAGE_PRETRAIN_TYPE] +ENCODERS = [RNN_NAME, RNN_WITH_CONV_EMBED_NAME, TRANSFORMER_TYPE, TRANSFORMER_WITH_CONV_EMBED_TYPE, CONVOLUTION_TYPE, IMAGE_PRETRAIN_TYPE, CUSTOM_SEQ_TYPE] # available decoder -DECODERS = [RNN_NAME, TRANSFORMER_TYPE, CONVOLUTION_TYPE] +DECODERS = [RNN_NAME, TRANSFORMER_TYPE, CONVOLUTION_TYPE, CUSTOM_SEQ_TYPE] # rnn types LSTM_TYPE = 'lstm' @@ -133,6 +134,7 @@ WEIGHT_TYING_SRC_TRG = 'src_trg' WEIGHT_TYING_SRC_TRG_SOFTMAX = 'src_trg_softmax' +CUSTOM_SEQ_PREFIX = "custom_seq_" # default decoder prefixes RNN_DECODER_PREFIX = DECODER_PREFIX + "rnn_" TRANSFORMER_DECODER_PREFIX = DECODER_PREFIX + "transformer_" @@ -149,7 +151,8 @@ # Swish-1/SiLU (https://arxiv.org/pdf/1710.05941.pdf, https://arxiv.org/pdf/1702.03118.pdf) SWISH1 = "swish1" TANH = "tanh" -TRANSFORMER_ACTIVATION_TYPES = [GELU, RELU, SWISH1] +NO_ACTIVATION = "none" +TRANSFORMER_ACTIVATION_TYPES = [GELU, RELU, SWISH1, NO_ACTIVATION] CNN_ACTIVATION_TYPES = [GLU, RELU, SIGMOID, SOFT_RELU, TANH] # Convolutional block pad types: @@ -394,6 +397,9 @@ PREPARED_DATA_VERSION_FILE = "data.version" PREPARED_DATA_VERSION = 2 +SEQUENCE_LENGTH_MUST_NOT_CHANGE_MSG = "Sequence length may not change within the residual layers." + + # reranking RERANK_BLEU = "bleu" RERANK_CHRF = "chrf" diff --git a/sockeye/convolution.py b/sockeye/convolution.py index 7924e9cbb..d7dc712a8 100644 --- a/sockeye/convolution.py +++ b/sockeye/convolution.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. A copy of the License @@ -18,8 +18,10 @@ from . import utils from . import constants as C from . import layers +from typing import List, Tuple, Dict, Optional, Sequence import mxnet as mx +import math class ConvolutionConfig(Config): @@ -29,19 +31,32 @@ class ConvolutionConfig(Config): :param kernel_width: Kernel size for 1D convolution. :param num_hidden: Size of hidden representation after convolution. :param act_type: The type of activation to use. + :param dropout: The dropout rate. + :param dilate: Dilation rate. + :param stride: Kernel window stride. + :param weight_normalization: If True weight normalization is applied. """ def __init__(self, kernel_width: int, num_hidden: int, act_type: str = C.GLU, + dropout: float = 0.0, + dilate: int = 1, + stride: int = 1, weight_normalization: bool = False) -> None: super().__init__() self.kernel_width = kernel_width self.num_hidden = num_hidden utils.check_condition(act_type in C.CNN_ACTIVATION_TYPES, "Unknown activation %s." % act_type) self.act_type = act_type + self.dropout = dropout self.weight_normalization = weight_normalization + self.dilate = dilate + self.stride = stride + + def effective_kernel_size(self): + return self.kernel_width + (self.dilate - 1) * (self.kernel_width - 1) class ConvolutionBlock: @@ -67,6 +82,7 @@ def __init__(self, self.conv_weight = mx.sym.Variable("%sconv_weight" % prefix, shape=( self._pre_activation_num_hidden(), + # TODO: the next parameter should be input_num_hidden self.config.num_hidden, self.config.kernel_width) ) @@ -89,7 +105,7 @@ def _pre_activation_num_hidden(self): def __call__(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> mx.sym.Symbol: + seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Run the convolutional block. @@ -98,37 +114,71 @@ def __call__(self, :param seq_len: Maximum sequence length. :return: Shape: (batch_size, seq_len, num_hidden). """ + # TODO: rethink when we really need masking... + # Apply masking (so that we properly have zero padding for variable sequence length batches) + # (seq_len, batch_size, num_hidden) + data = mx.sym.SequenceMask(data=data, axis=1, sequence_length=data_length, use_sequence_length=True, value=0) + + # (batch_size, num_hidden, seq_len) + data = mx.sym.transpose(data, axes=(0, 2, 1)) + if self.pad_type == C.CNN_PAD_LEFT: + # TODO (tdomhan): Implement striding with left-padding + assert self.config.stride == 1, "Striding currently not supported with left padding." 
# we pad enough on both sides and later slice the extra padding from the right - padding = (self.config.kernel_width - 1,) + padding = self.config.effective_kernel_size() - 1 + # TODO: potentially remove zero-padding elif self.pad_type == C.CNN_PAD_CENTERED: - # we pad enough so that the output size is equal to the input size and we don't need to slice - utils.check_condition(self.config.kernel_width % 2 == 1, - "Only odd kernel widths supported, but got %d" % self.config.kernel_width) - padding = (int((self.config.kernel_width - 1) / 2),) + # we pad enough so that the output sizeis equal to the input size and we don't need to slice + utils.check_condition(self.config.effective_kernel_size() % 2 == 1, + "Only odd kernel widths supported, but got %d" % self.config.effective_kernel_size()) + padding = int((self.config.effective_kernel_size() - 1) / 2) + + seq_len_padded = seq_len + padding * 2 + + stride = self.config.stride + if stride > 1: + if seq_len_padded % stride != 0: + pad_after = stride - (seq_len_padded % stride) + # pad to the right so that stride int divides the time axis + # temporary 4d due to pad op constraint + data = mx.sym.expand_dims(data, axis=3) + data = mx.sym.pad(data=data, + mode="constant", + constant_value=0, + pad_width=(0, 0, + 0, 0, + 0, pad_after, + 0, 0)) + data = mx.sym.reshape(data, shape=(0, 0, -1)) + data_length = data_length + pad_after + seq_len = seq_len + pad_after + + # formula is: floor((x+2*p-k)/s)+1 + # with 2p = k - 1 we get: floor((x-1)/s)+1 + + data_length = mx.sym.BlockGrad(mx.sym.floor((data_length - 1) / stride) + 1) + seq_len = int(math.floor((seq_len - 1) / stride)) + 1 else: raise ValueError("Unknown pad type %s" % self.pad_type) num_hidden = self._pre_activation_num_hidden() - # Apply masking (so that we properly have zero padding for variable sequence length batches) - data = mx.sym.SequenceMask(data=data, axis=1, sequence_length=data_length, use_sequence_length=True, value=0) - - # (batch_size, num_hidden, seq_len) - data = mx.sym.transpose(data, axes=(0, 2, 1)) data_conv = mx.sym.Convolution(data=data, weight=self.conv_weight, bias=self.conv_bias, - pad=padding, + pad=(padding,), kernel=(self.config.kernel_width,), + stride=self.config.stride, num_filter=num_hidden, + dilate=(self.config.dilate,), layout="NCW") # (batch_size, 2 * num_hidden, seq_len) if self.pad_type == C.CNN_PAD_LEFT: data_conv = mx.sym.slice_axis(data=data_conv, axis=2, begin=0, end=seq_len) - return self._post_convolution(data_conv) + return self._post_convolution(data_conv), data_length, seq_len def step(self, data): """ @@ -137,30 +187,44 @@ def step(self, data): :param data: Shape: (batch_size, kernel_width, num_hidden). :return: Single result of a convolution. Shape: (batch_size, 1, num_hidden). """ - - # As we only run convolution over a single window that is exactly the size of the convolutional filter - # we can use FullyConnected instead of Convolution for efficiency reasons. Additionally we do not need to - # perform any masking. 
+ assert self.config.stride == 1, "Striding not supported on the target side" num_hidden = self._pre_activation_num_hidden() + if self.config.dilate != 1: + # (batch_size, num_hidden, kernel_width) + data = mx.sym.swapaxes(data, dim1=1, dim2=2) - # (batch_size, num_hidden, kernel_width) - data = mx.sym.swapaxes(data, dim1=1, dim2=2) - # (batch_size, num_hidden * kernel_width) - data = mx.sym.reshape(data, shape=(0, -3)) - # (preact_num_hidden, num_hidden * kernel_width) - weight = mx.sym.reshape(self.conv_weight, shape=(0, -3)) - data_conv = mx.sym.FullyConnected(data=data, - weight=weight, - bias=self.conv_bias, - num_hidden=num_hidden) - # (batch_size, num_hidden, 1) - data_conv = mx.sym.expand_dims(data_conv, axis=2) - return self._post_convolution(data_conv) + # (batch_size, num_hidden, 1) + data_conv = mx.sym.Convolution(data=data, + weight=self.conv_weight, + bias=self.conv_bias, + kernel=(self.config.kernel_width,), + num_filter=num_hidden, + dilate=(self.config.dilate,), + layout="NCW") + + return self._post_convolution(data_conv) + else: + # As we only run convolution over a single window that is exactly the size of the convolutional filter + # we can use FullyConnected instead of Convolution for efficiency reasons. Additionally we do not need to + # perform any masking. + + # (batch_size, num_hidden, kernel_width) + data = mx.sym.swapaxes(data, dim1=1, dim2=2) + # (batch_size, num_hidden * kernel_width) + data = mx.sym.reshape(data, shape=(0, -3)) + # (preact_num_hidden, num_hidden * kernel_width) + weight = mx.sym.reshape(self.conv_weight, shape=(0, -3)) + data_conv = mx.sym.FullyConnected(data=data, + weight=weight, + bias=self.conv_bias, + num_hidden=num_hidden) + # (batch_size, num_hidden, 1) + data_conv = mx.sym.expand_dims(data_conv, axis=2) + return self._post_convolution(data_conv) def _post_convolution(self, data_conv): # data_conv: (batch_size, pre_activation_num_hidden, seq_len) - # TODO: add layer norm (can we do this without reshaping?!) 
if self.config.act_type == C.GLU: # GLU @@ -170,11 +234,445 @@ def _post_convolution(self, data_conv): # (batch_size, num_hidden, seq_len) block_output = mx.sym.broadcast_mul(gate_a, mx.sym.Activation(data=gate_b, act_type="sigmoid")) + # TODO: use the activation function from layers.py + elif self.config.act_type == "none": + block_output = data_conv else: # (batch_size, num_hidden, seq_len) block_output = mx.sym.Activation(data_conv, act_type=self.config.act_type) # (batch_size, seq_len, num_hidden) block_output = mx.sym.swapaxes(block_output, dim1=1, dim2=2) + + if self.config.dropout > 0.0: + block_output = mx.sym.Dropout(block_output, p=self.config.dropout) + return block_output +# TODO: encoder side left-padding +class ConvolutionalEncoderLayer(layers.EncoderLayer): + def __init__(self, cnn_config: ConvolutionConfig, prefix: str) -> None: + self.prefix = prefix + self.cnn_config = cnn_config + self.cnn_block = ConvolutionBlock(self.cnn_config, pad_type=C.CNN_PAD_CENTERED, prefix=prefix) + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + return self.cnn_block(source_encoded, source_encoded_lengths, source_encoded_max_length) + + def get_encoded_seq_len(self, seq_len: int): + stride = self.cnn_config.stride + if stride == 1: + return seq_len + else: + padding = int((self.cnn_config.effective_kernel_size() - 1) / 2) + seq_len_padded = seq_len + 2 * padding + if seq_len_padded % stride != 0: + pad_after = stride - (seq_len_padded % stride) + seq_len = seq_len + pad_after + return int(math.floor((seq_len - 1) / stride)) + 1 + + def get_num_hidden(self) -> int: + return self.cnn_config.num_hidden + + +class ConvolutionalDecoderLayer(layers.DecoderLayer): + + def __init__(self, input_num_hidden: int, cnn_config: ConvolutionConfig, prefix: str) -> None: + self.input_num_hidden = input_num_hidden + self.prefix = prefix + self.cnn_config = cnn_config + self.cnn_block = ConvolutionBlock(self.cnn_config, pad_type=C.CNN_PAD_LEFT, prefix=prefix) + + def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + # TODO: for the decoder we don't actually need the masking operation ... + return self.cnn_block(target_encoded, target_encoded_lengths, target_encoded_max_length)[0] + + def decode_step(self, step: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, states: Sequence[mx.sym.Symbol], att_dict) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + # (batch_size, kernel_width - 1, num_hidden) + prev_target = states[0] + + # target: (batch_size, num_hidden) -> (batch_size, 1, num_hidden) + target = mx.sym.expand_dims(target, axis=1) + + # (batch_size, kernel_width, num_hidden) + target = mx.sym.concat(prev_target, target, dim=1) + + # (batch_size, kernel_width, num_hidden) -> (batch_size, 1, num_hidden) + out = self.cnn_block.step(target) + + # arg: (batch_size, kernel_width - 1, num_hidden). 
+ new_prev_target = mx.sym.slice_axis(data=target, axis=1, begin=1, end=self.cnn_config.effective_kernel_size()) + + out = mx.sym.reshape(out, shape=(0, -1)) + + return out, [new_prev_target] + + def num_states(self, step) -> int: + return 1 + + def state_variables(self, step: int): + return [mx.sym.Variable(name="%s_conv_state" % self.prefix)] + + def state_shapes(self, + batch_size: int, + target_max_length: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int): + input_num_hidden = self.input_num_hidden + kernel_width = self.cnn_config.effective_kernel_size() + return [mx.io.DataDesc("%s_conv_state" % self.prefix, + shape=(batch_size, kernel_width - 1, input_num_hidden), + layout="NTW")] + + def init_states(self, + batch_size: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int): + input_num_hidden = self.input_num_hidden + kernel_width = self.cnn_config.effective_kernel_size() + return [mx.sym.zeros(shape=(batch_size, kernel_width - 1, input_num_hidden), + name="%s_conv_state" % self.prefix)] + + def get_num_hidden(self) -> int: + return self.cnn_config.num_hidden + + +class ConvolutionalLayerConfig(layers.LayerConfig): + + def __init__(self, + num_hidden: int, + kernel_width: int = 3, + act_type: str = C.GLU, + dropout: float = 0.0, + dilate: int = 1, + stride: int = 1, + prefix: str="") -> None: + super().__init__() + self.num_hidden = num_hidden + self.kernel_width = kernel_width + self.act_type = act_type + self.dropout=dropout + self.dilate = dilate + self.stride = stride + self.prefix = prefix + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> layers.EncoderLayer: + cnn_config = ConvolutionConfig(kernel_width=self.kernel_width, + num_hidden=self.num_hidden, + dropout=self.dropout, + dilate=self.dilate, + stride=self.stride, + act_type=self.act_type) + return ConvolutionalEncoderLayer(cnn_config, prefix=prefix + "cnn_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> layers.DecoderLayer: + assert self.stride == 1, "Stride only supported on the encoder side." + cnn_config = ConvolutionConfig(kernel_width=self.kernel_width, + num_hidden=self.num_hidden, + dropout=self.dropout, + dilate=self.dilate, + act_type=self.act_type) + return ConvolutionalDecoderLayer(input_num_hidden=input_num_hidden, cnn_config=cnn_config, + prefix=prefix + "cnn_") + + +class PoolingEncoderLayer(layers.EncoderLayer): + """ + Pooling operating with a given stride and kernel. Sequences are extended to the right to cover sequences that + are not a multiple of the stride. 
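+
+ For example, with kernel = stride = 3 and pooling_convention='full', an input of length 7 is reduced to
+ ceil((7 - 3) / 3) + 1 = 3 encoded positions; get_encoded_seq_len below computes exactly this quantity
+ (after padding inputs shorter than the kernel up to the kernel size).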
+ """ + + def __init__(self, num_hidden, stride: int = 3, kernel: int = 3, pool_type: str = "avg") -> None: + self.pool_type = pool_type + self.stride = stride + self.kernel = kernel + self.num_hidden = num_hidden + + def encode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + att_dict: Dict[str, mx.sym.Symbol]) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + # source_encoded: (batch_size, seq_len, num_hidden) -> (batch_size, num_hidden, seq_len) + source_encoded = mx.sym.transpose(source_encoded, axes=(0, 2, 1)) + # (batch_size, num_hidden, seq_len, 1) + source_encoded = mx.sym.expand_dims(source_encoded, axis=3) + + if self.kernel > source_encoded_max_length: + pad_after = self.kernel - source_encoded_max_length + source_encoded = mx.sym.pad(data=source_encoded, + mode="constant", + constant_value=0, + pad_width=(0, 0, + 0, 0, + 0, pad_after, + 0, 0)) + + # (batch_size, num_hidden, seq_len/stride, 1) + source_encoded = mx.sym.Pooling(data=source_encoded, + pool_type=self.pool_type, + pooling_convention='full', + kernel=(self.kernel, 1), + stride=(self.stride, 1)) + + # (batch_size, num_hidden, seq_len/stride) + source_encoded = mx.sym.reshape(source_encoded, shape=(0, 0, -1)) + source_encoded = mx.sym.transpose(source_encoded, axes=(0, 2, 1)) + + source_encoded_lengths = mx.sym.BlockGrad(mx.sym.ceil((source_encoded_lengths - self.kernel) / self.stride) + 1) + source_encoded_max_length = self.get_encoded_seq_len(source_encoded_max_length) + return source_encoded, source_encoded_lengths, source_encoded_max_length + + def get_num_hidden(self): + return self.num_hidden + + def get_encoded_seq_len(self, seq_len: int): + # if the sequence is not as large as the kernel it is padded to the kernel size: + seq_len = max(seq_len, self.kernel) + return int(math.ceil((seq_len - self.kernel) / self.stride)) + 1 + + +class PoolingLayerConfig(layers.LayerConfig): + + def __init__(self, stride: int = 3, kernel: Optional[int] = None, pool_type: str = "avg", prefix: str="") -> None: + super().__init__() + self.stride = stride + self.kernel = kernel if kernel is not None else stride + self.pool_type = pool_type + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> layers.EncoderLayer: + return PoolingEncoderLayer(num_hidden=input_num_hidden, stride=self.stride, kernel=self.kernel, + pool_type=self.pool_type) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> layers.DecoderLayer: + raise NotImplementedError("Pooling only available on the encoder side.") + + +class QRNNBlock: + """ + Implements Quasi-recurrent neural networks as described by Bradbury, James, et al. "Quasi-recurrent neural + networks." arXiv preprint arXiv:1611.01576 (2016). + + QRNNs do not have any recurrency in calculating the gates but rather use convolutions. We implement the f-pooling + variant so that the hidden states are calculated as + h_t = f_t * h_{t-1} + (1 - f_t) * z_t, + where f is the forget gate and z the input. 
+ """ + + def __init__(self, + num_hidden: int, + input_num_hidden: int, + kernel_width: int, + act_type: str = "tanh", + prefix: str = "") -> None: + self.num_hidden = num_hidden + self.kernel_width = kernel_width + self.act_type = act_type + num_out = 2 + self.conv_weight = mx.sym.Variable("%sconv_weight" % prefix, shape=(num_out * num_hidden, + input_num_hidden, + kernel_width)) + self.conv_bias = mx.sym.Variable("%sconv_bias" % prefix) + + def __call__(self, + data: mx.sym.Symbol, + data_length: mx.sym.Symbol, + seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + # (batch_size, seq_len, num_hidden) -> (batch_size, num_hidden, seq_len) + data = mx.sym.transpose(data, axes=(0, 2, 1)) + + padding = self.kernel_width - 1 + num_out = 2 + # (batch_size, 2 * num_hidden, left_pad + seq_len) + data_conv = mx.sym.Convolution(data=data, + weight=self.conv_weight, + bias=self.conv_bias, + pad=(padding,), + kernel=(self.kernel_width,), + num_filter=num_out * self.num_hidden, + layout="NCW") + + # (batch_size, 2 * num_hidden, seq_len) + data_conv = mx.sym.slice_axis(data=data_conv, axis=2, begin=0, end=seq_len) + + # (batch_size, seq_len, 2 * num_hidden) + data_conv = mx.sym.transpose(data_conv, axes=(0, 2, 1)) + + # 2 * (batch_size, seq_len, num_hidden) + # pylint: disable=unbalanced-tuple-unpacking + out, f_gates = mx.sym.split(data_conv, num_outputs=2, axis=2) + out = mx.sym.Activation(data=out, act_type=self.act_type) + f_gates = mx.sym.Activation(data=f_gates, act_type="sigmoid") + + gated_out = mx.sym.broadcast_mul(1 - f_gates, out) + + # accumulate hidden state + hidden = mx.sym.zeros(shape=(0, self.num_hidden)) + hiddens = [] + + for f_gate, out in zip(mx.sym.split(f_gates, num_outputs=seq_len, axis=1, squeeze_axis=True), + mx.sym.split(gated_out, num_outputs=seq_len, axis=1, squeeze_axis=True)): + hidden = f_gate * hidden + out + hiddens.append(mx.sym.expand_dims(hidden, axis=1)) + # (batch_size, seq_len, num_hidden) + hiddens = mx.sym.concat(*hiddens, dim=1) + return hiddens, data_length, seq_len + + def step(self, data: mx.sym.Symbol, prev_h: mx.sym.Symbol): + """ + Run the qrnn cell over a single position. The data must be exactly as wide as the convolution filters. + + :param data: Shape: (batch_size, kernel_width, num_hidden). + :param prev_h: The previous hidden state, Shape: (batch_size, num_hidden). + :return: Single result of a convolution. Shape: (batch_size, 1, num_hidden). + """ + # (batch_size, num_hidden, kernel_width) + data = mx.sym.swapaxes(data, dim1=1, dim2=2) + # (batch_size, num_hidden * kernel_width) + data = mx.sym.reshape(data, shape=(0, -3)) + # (preact_num_hidden, num_hidden * kernel_width) + weight = mx.sym.reshape(self.conv_weight, shape=(0, -3)) + num_out = 2 + + # (batch_size, 2 * num_hidden) + data_conv = mx.sym.FullyConnected(data=data, + weight=weight, + bias=self.conv_bias, + num_hidden=num_out * self.num_hidden) + # TODO: refactor post FC code into a function to be shared by decode_step and decode_sequence + # pylint: disable=unbalanced-tuple-unpacking + out, f_gates = mx.sym.split(data_conv, num_outputs=2, axis=1) + + out = mx.sym.Activation(data=out, act_type=self.act_type) + f_gate = mx.sym.Activation(data=f_gates, act_type="sigmoid") + + curr_h = f_gate * prev_h + (1.0 - f_gate) * out + return curr_h + + +class QRNNDecoderLayer(layers.DecoderLayer): + """ + QRNN implemented with masked (left-padded) convolutions. 
+ """ + + def __init__(self, num_hidden: int, input_num_hidden: int, kernel_width: int, act_type: str = "tanh", + prefix: str = "") -> None: + self.num_hidden = num_hidden + self.input_num_hidden = input_num_hidden + self.prefix = prefix + self.qrnn = QRNNBlock(num_hidden=num_hidden, input_num_hidden=input_num_hidden, + kernel_width=kernel_width, act_type=act_type, prefix=prefix) + + def decode_sequence(self, + source_encoded: Sequence[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + return self.qrnn(target_encoded, target_encoded_lengths, target_encoded_max_length)[0] + + def decode_step(self, step: int, source_encoded: Sequence[mx.sym.Symbol], source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, states: Sequence[mx.sym.Symbol], + att_dict: dict) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + # (batch_size, kernel_width - 1, num_hidden) + prev_h = states[0] + prev_target = states[1] + + # target: (batch_size, num_hidden) -> (batch_size, 1, num_hidden) + target = mx.sym.expand_dims(target, axis=1) + + # (batch_size, kernel_width, num_hidden) + target = mx.sym.concat(prev_target, target, dim=1) + + # (batch_size, kernel_width, num_hidden) -> (batch_size, num_hidden) + out = self.qrnn.step(target, prev_h) + + # arg: (batch_size, kernel_width - 1, num_hidden). + new_prev_target = mx.sym.slice_axis(data=target, axis=1, begin=1, end=self.qrnn.kernel_width) + + return out, [out, new_prev_target] + + def get_num_hidden(self) -> int: + return self.num_hidden + + def reset(self): + pass + + def num_states(self, step: int) -> int: + return 2 + + def state_variables(self, step: int) -> Sequence[mx.sym.Symbol]: + return [mx.sym.Variable(name="%s_qrnn_prev_h" % self.prefix), + mx.sym.Variable(name="%s_qrnn_in_state" % self.prefix)] + + def init_states(self, + batch_size: int, + source_encoded: Sequence[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int) -> Sequence[mx.sym.Symbol]: + input_num_hidden = self.input_num_hidden + kernel_width = self.qrnn.kernel_width + return [mx.sym.zeros(shape=(batch_size, self.qrnn.num_hidden), + name="%s_qrnn_prev_h" % self.prefix), + mx.sym.zeros(shape=(batch_size, kernel_width - 1, input_num_hidden), + name="%s_qrnn_in_state" % self.prefix)] + + def state_shapes(self, + batch_size: int, + target_max_length: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int) -> List[mx.io.DataDesc]: + input_num_hidden = self.input_num_hidden + kernel_width = self.qrnn.kernel_width + return [mx.io.DataDesc("%s_qrnn_prev_h" % self.prefix, + shape=(batch_size, self.qrnn.num_hidden), + layout="NTW"), + mx.io.DataDesc("%s_qrnn_in_state" % self.prefix, + shape=(batch_size, kernel_width - 1, input_num_hidden), + layout="NTW")] + + +class QRNNEncoderLayer(layers.EncoderLayer): + """ + QRNN encoder with f-pooling. 
+ """ + + def __init__(self, num_hidden: int, input_num_hidden: int, + kernel_width: int, act_type: str = "tanh", prefix: str = "") -> None: + self.num_hidden = num_hidden + self.qrnn = QRNNBlock(num_hidden=num_hidden, input_num_hidden=input_num_hidden, + kernel_width=kernel_width, act_type=act_type, prefix=prefix) + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict: Dict[str, mx.sym.Symbol]) -> Tuple[ + mx.sym.Symbol, mx.sym.Symbol, int]: + return self.qrnn(source_encoded, source_encoded_lengths, source_encoded_max_length) + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class QRNNLayerConfig(layers.LayerConfig): + + def __init__(self, num_hidden: int, kernel_width: int = 3, act_type: str = "tanh") -> None: + super().__init__() + self.num_hidden = num_hidden + self.kernel_width = kernel_width + self.act_type = act_type + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> layers.EncoderLayer: + return QRNNEncoderLayer(num_hidden=self.num_hidden, input_num_hidden=input_num_hidden, + kernel_width=self.kernel_width, act_type=self.act_type, prefix=prefix + "qrnn_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> layers.DecoderLayer: + return QRNNDecoderLayer(num_hidden=self.num_hidden, input_num_hidden=input_num_hidden, + kernel_width=self.kernel_width, act_type=self.act_type, prefix=prefix + "qrnn_") + diff --git a/sockeye/coverage.py b/sockeye/coverage.py index 5b0925cee..51c7e4a84 100644 --- a/sockeye/coverage.py +++ b/sockeye/coverage.py @@ -227,7 +227,7 @@ def __init__(self, # optional layer normalization self.layer_norm = None if layer_normalization and not self.num_hidden != 1: - self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix) + self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix, num_hidden=self.num_hidden) def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: """ diff --git a/sockeye/custom_seq_parser.py b/sockeye/custom_seq_parser.py new file mode 100644 index 000000000..82ecf1cfa --- /dev/null +++ b/sockeye/custom_seq_parser.py @@ -0,0 +1,345 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import inspect +from typing import List, Optional, Dict, Tuple, Any, Union + +from parsimonious import Grammar, NodeVisitor + +from . import convolution +from . import layers +from . import rnn +from . import utils + +# TODO: parallel should take a list of parallel layers (instead of separating them by ->)! so basically a list of layer_chains... + +# TODO: boolean argument parsing? 
+ +custom_seq_grammar = Grammar(r""" +network = layer_chain +layer_chain = layer more_layers +more_layers = sep_layer* +sep_layer = sep layer +sep = "->" +layer = meta_layer / parallel_layer / repeat_layer / subsample_layer / standard_layer +open = "(" +close = ")" +empty_paren = open close + +repeat_layer = "repeat" open int comma layer_chain close +subsample_layer = "subsample" open optional_params layer_chain_sep layer_chain close + +standard_layer = standard_layer_name optional_parenthesis_params +standard_layer_name = ~"[a-z_A-Z]+" + +meta_layer = meta_layer_name open layer_chain close +meta_layer_name = "res" / "highway" + +parallel_layer = parallel_name open layer_chain more_layer_chains close +parallel_name = "parallel_add" / "parallel" +layer_chain_sep = "|" +separated_layer_chain = layer_chain_sep layer_chain +more_layer_chains = separated_layer_chain* + +optional_parenthesis_params = parenthesis_params? +parenthesis_params = open param maybe_more_params close +optional_params = params? +params = param maybe_more_params +maybe_more_params = comma_param* +comma_param = comma param +comma = ~", *" +optional_comma = comma? +param = kw_param / arg_param +kw_param = string "=" arg_param +arg_param = float / int / bool / string +string = ~"[a-z_0-9]+" +float = ~"-?[0-9]+\.([0-9]+)?" +int = ~"-?[0-9]+" +bool = "True" / "False" +""") + + +# TODO: better error messages?! +# TODO: create documentation for each layer... + +class CustomSeqParser(NodeVisitor): + def __init__(self): + super().__init__() + + def visit_parenthesis_params(self, node, rest): + open_paran, param, more_params, close_paran = rest + return [param] + more_params + + def visit_params(self, node, rest): + param, maybe_more_params = rest + return [param] + maybe_more_params + + def visit_float(self, node, param): + return float(node.text) + + def visit_int(self, node, param): + return int(node.text) + + def visit_string(self, node, param): + return node.text + + def visit_maybe_more_params(self, node, children): + return children + + def visit_comma_param(self, node, comma_param): + # split off the comma: + comma, param = comma_param + return param + + def visit_kw_param(self, node, kw_param): + key, sep, value = kw_param + return key, value + + def visit_arg_param(self, node, children): + # param always has a single child, which is either numeric or string + return children[0] + + def visit_param(self, node, children): + # param always has a single child, which is either numeric or string + return children[0] + + def visit_bool(self, node, children): + if node.text == "True": + return True + elif node.text == "False": + return False + else: + raise ValueError("%s is not a boolean." 
% node.text) + + def visit_optional_parenthesis_params(self, node, children): + if len(children) == 0: + # no parameters + return None + else: + # parameters are given, return parameter list + return children[0] + + def visit_optional_params(self, node, children): + if len(children) == 0: + # no parameters + return None + else: + # parameters are given, return parameter list + return children[0] + + def visit_standard_layer(self, node, name_and_params): + name, params = name_and_params + return {"name": name, "params": params} + + def visit_standard_layer_name(self, node, children): + return node.text + + def visit_meta_layer(self, node, children): + meta_layer_name, open_paran, layer_chain, close_paran = children + return {"name": meta_layer_name, "layers": layer_chain} + + def visit_meta_layer_name(self, node, children): + return node.text + + def visit_repeat_layer(self, node, children): + name, open_paran, num, comma, layer_chain, close = children + return {"name": "repeat", "num": num, "layers": layer_chain} + + def visit_subsample_layer(self, node, children): + name, open_paran, optional_params, sep, layer_chain, close_paran = children + return {"name": "subsample", "layers": layer_chain, "params": optional_params} + + def visit_layer(self, node, children): + return children[0] + + def visit_layer_chain(self, node, layer_chain): + layer, more_layers = layer_chain + return [layer] + more_layers + + def visit_more_layers(self, node, children): + return children + + def visit_sep_layer(self, node, sep_layer): + sep, layer = sep_layer + return layer + + def visit_more_layer_chains(self, node, children): + return children + + def visit_parallel_layer(self, node, parallel_layer): + name, open_paran, layer_chain, more_layer_chains, close_paran = parallel_layer + return {"name": name[0].text, "layer_chains": [layer_chain] + more_layer_chains} + + def visit_separated_layer_chain(self, node, separated_layer_chain): + layer_chain_sep, layer_chain = separated_layer_chain + return layer_chain + + def generic_visit(self, node, visited_children): + # print("generic_visit", len(visited_children), node, visited_children) + return visited_children or node + + +ParsedLayer = Dict + + +# TODO: wrap parsimonious exceptions and try to make them more useful! +def parse(description: str) -> List[ParsedLayer]: + """ + Parse an architecture definition. + + :param description: A nested layer-chain description such as 'rnn->ff(512)'. + :return: The parsed layer configurations as dictionaries containing keys and values that depend on the individual + layers. The only key common to all layers is the layer name 'name'. + """ + parsed_layers = CustomSeqParser().visit(custom_seq_grammar.parse(description)) + return parsed_layers + + +class KwargsDefaultFiller: + + def __init__(self, default_dict): + self.default_dict = default_dict + + def fill(self, name, func, params: Optional[List[Union[Any, Tuple[str,Any]]]]): + param_names = list(inspect.signature(func).parameters) + param_names = [name for name in param_names if name != 'self'] + if params is not None: + args = [param for param in params if not isinstance(param, tuple)] + kwargs = [param for param in params if isinstance(param, tuple)] + utils.check_condition(len(args) <= len(param_names), + "Too many parameters given. %d were given, but only %d are needed (%s) for layer %s." 
% (len(args), len(param_names), ", ".join(param_names), name)) + param_kwargs = {name: value for name, value in zip(param_names, args)} + for name, value in kwargs: + param_kwargs[name] = value + else: + param_kwargs = {} + for default_name, default_value in self.default_dict.items(): + if default_name in param_names and default_name not in param_kwargs: + param_kwargs[default_name] = default_value + return param_kwargs + + +def _fill_and_create(kwargs_filler, name, func, args): + return func(**kwargs_filler.fill(name, func.__init__, args)) + + +# TODO: create documentation of the different layers we have available +def _create_layer_configs(default_kwargs_filler, parsed_layers: List[Dict]) -> Tuple[List[layers.LayerConfig], bool]: + source_attention_present = False + layer_configs = [] + for layer in parsed_layers: + name = layer['name'] + # TODO: can we simplify this? Maybe have LayerConfigs register themselves + if name == 'ff': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.FeedForwardLayerConfig, layer['params'])) + elif name == 'linear': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.LinearLayerConfig, layer['params'])) + elif name == 'id': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.IdentityLayerConfig, layer['params'])) + elif name == 'mh_dot_att': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.MultiHeadSourceAttentionLayerConfig, layer['params'])) + source_attention_present = True + elif name == 'mh_dot_self_att': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.MultiHeadSelfAttentionLayerConfig, layer['params'])) + elif name == 'cnn': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, convolution.ConvolutionalLayerConfig, layer['params'])) + elif name == 'qrnn': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, convolution.QRNNLayerConfig, layer['params'])) + elif name == 'pool': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, convolution.PoolingLayerConfig, layer['params'])) + elif name == 'rnn': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, rnn.RecurrentLayerConfig, layer['params'])) + elif name == 'birnn': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, rnn.BidirectionalRecurrentLayerConfig, layer['params'])) + elif name == 'dropout': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.DropoutLayerConfig, layer['params'])) + elif name == 'act': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.ActivationLayerConfig, layer['params'])) + elif name == 'res': + sub_layers = layer["layers"] + sub_layer_configs, sub_layers_source_attention_present = _create_layer_configs(default_kwargs_filler, + sub_layers) + source_attention_present = source_attention_present or sub_layers_source_attention_present + layer_configs.append(layers.ResidualLayerConfig(layer_configs=sub_layer_configs)) + elif name == 'highway': + sub_layers = layer["layers"] + sub_layer_configs, sub_layers_source_attention_present = _create_layer_configs(default_kwargs_filler, + sub_layers) + source_attention_present = source_attention_present or sub_layers_source_attention_present + layer_configs.append(layers.HighwayLayerConfig(layer_configs=sub_layer_configs)) + elif name == 'repeat': + num = layer["num"] + sub_layers = layer["layers"] + for i in range(0, num): + sub_layer_configs, 
sub_layers_source_attention_present = _create_layer_configs(default_kwargs_filler, + sub_layers) + source_attention_present = source_attention_present or sub_layers_source_attention_present + layer_configs.extend(sub_layer_configs) + elif name == 'pos': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.SinCosPositionalEmbeddingsLayerConfig, layer['params'])) + elif name == 'learn_pos': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.LearnedPositionalEmbeddingsLayerConfig, layer['params'])) + elif name == 'norm': + layer_configs.append(_fill_and_create(default_kwargs_filler, + name, layers.LayerNormalizationLayerConfig, layer['params'])) + else: + raise ValueError("Unknown layer %s." % name) + return layer_configs, source_attention_present + + +# TODO: adapt doc-string +def parse_custom_seq_layers_description(default_num_hidden: int, + default_num_embed: int, + default_dropout: float, + max_seq_len: int, + description: str, + source_attention_needed: bool, + source_attention_forbidden: bool) -> List[layers.LayerConfig]: + """ + + :param num_hidden: The default number of hidden units, for all layers which do not specify them. + :param description: A custom layer description string such as 'ff->cnn->self_att->ff'. + :return: The parsed list of layer descriptions. + """ + parsed_layers = CustomSeqParser().visit(custom_seq_grammar.parse(description)) + + kwargs_filler = KwargsDefaultFiller({"dropout": default_dropout, + "num_hidden": default_num_hidden, + "num_embed": default_num_embed, + "max_seq_len": max_seq_len}) + + layer_configs, source_attention_present = _create_layer_configs(kwargs_filler, parsed_layers) + + if source_attention_needed: + utils.check_condition(source_attention_present, + "At least one source attention mechanism needed.") + if source_attention_forbidden: + utils.check_condition(not source_attention_present, + "Source attention not allowed on the encoder side.") + + return layer_configs + diff --git a/sockeye/decoder.py b/sockeye/decoder.py index faeeca3f8..090951b75 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. A copy of the License @@ -31,7 +31,8 @@ from .config import Config logger = logging.getLogger(__name__) -DecoderConfig = Union['RecurrentDecoderConfig', transformer.TransformerConfig, 'ConvolutionalDecoderConfig'] +DecoderConfig = Union['RecurrentDecoderConfig', transformer.TransformerConfig, 'ConvolutionalDecoderConfig', + 'CustomSeqDecoderConfig'] def get_decoder(config: DecoderConfig, prefix: str = '') -> 'Decoder': @@ -117,6 +118,7 @@ def decode_step(self, step: int, target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, + att_dict: Dict[str, Dict[str, mx.sym.Symbol]], *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ Decodes a single time step given the current step, the previous embedded target word, @@ -127,6 +129,8 @@ def decode_step(self, :param step: Global step of inference procedure, starts with 1. :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. 
+ :param att_dict: A dictionary of attention matrices used for visualization with separate entries for source and + self-attention. :param states: Arbitrary list of decoder states. :return: logit inputs, attention probabilities, next decoder states. """ @@ -148,6 +152,7 @@ def get_num_hidden(self) -> int: @abstractmethod def init_states(self, + batch_size: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int) -> List[mx.sym.Symbol]: @@ -155,6 +160,7 @@ def init_states(self, Returns a list of symbolic states that represent the initial states of this decoder. Used for inference. + :param batch_size: The batch size. :param source_encoded: Encoded source. Shape: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. @@ -175,7 +181,7 @@ def state_variables(self, target_max_length: int) -> List[mx.sym.Symbol]: @abstractmethod def state_shapes(self, batch_size: int, - target_max_length: int, + step: int, source_encoded_max_length: int, source_encoded_depth: int) -> List[mx.io.DataDesc]: """ @@ -183,7 +189,7 @@ def state_shapes(self, Used for inference. :param batch_size: Batch size during inference. - :param target_max_length: Current target sequence length. + :param step: Current target sequence length. :param source_encoded_max_length: Size of encoder time dimension. :param source_encoded_depth: Depth of encoded source. :return: List of shape descriptions. @@ -221,6 +227,7 @@ def __init__(self, config, prefix="%s%d_" % (prefix, i)) for i in range(config.num_layers)] self.final_process = transformer.TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=self.config.model_size, prefix="%sfinal_process_" % prefix) self.pos_embedding = encoder.get_positional_embedding(config.positional_embedding_type, @@ -260,7 +267,7 @@ def decode_sequence(self, source_bias = mx.sym.expand_dims(source_bias, axis=1) # (1, target_max_length, target_max_length) - target_bias = transformer.get_autoregressive_bias(target_embed_max_length, name="%starget_bias" % self.prefix) + target_bias = layers.get_autoregressive_bias(target_embed_max_length, name="%starget_bias" % self.prefix) # target: (batch_size, target_max_length, model_size) target, _, target_max_length = self.pos_embedding.encode(target_embed, None, target_embed_max_length) @@ -281,6 +288,7 @@ def decode_step(self, step: int, target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, + att_dict: Dict[str, Dict[str, mx.sym.Symbol]], *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ Decodes a single time step given the current step, the previous embedded target word, @@ -291,9 +299,12 @@ def decode_step(self, :param step: Global step of inference procedure, starts with 1. :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. + :param att_dict: A dictionary of attention matrices used for visualization with separate entries for source and + self-attention. :param states: Arbitrary list of decoder states. :return: logit inputs, attention probabilities, next decoder states. """ + # TODO: use att_dict # for step > 1, states contains source_encoded, source_encoded_lengths, and cache tensors. 
source_encoded, source_encoded_lengths, *cache = states # type: ignore @@ -315,7 +326,7 @@ def decode_step(self, # auto-regressive bias for last position in sequence # (1, target_max_length, target_max_length) - target_bias = transformer.get_autoregressive_bias(step, name="%sbias" % self.prefix) + target_bias = layers.get_autoregressive_bias(step, name="%sbias" % self.prefix) target_bias = mx.sym.slice_axis(target_bias, axis=1, begin=-1, end=step) new_states = [source_encoded, source_encoded_lengths] @@ -364,6 +375,7 @@ def get_num_hidden(self) -> int: return self.config.model_size def init_states(self, + batch_size: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int) -> List[mx.sym.Symbol]: @@ -371,6 +383,7 @@ def init_states(self, Returns a list of symbolic states that represent the initial states of this decoder. Used for inference. + :param batch_size: The batch size. :param source_encoded: Encoded source. Shape: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. @@ -542,7 +555,7 @@ def __init__(self, self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix) self.hidden_norm = None if self.config.layer_normalization: - self.hidden_norm = layers.LayerNormalization(prefix="%shidden_norm" % prefix) + self.hidden_norm = layers.LayerNormalization(prefix="%shidden_norm" % prefix, num_hidden=self.num_hidden) def _create_state_init_parameters(self): """ @@ -558,7 +571,8 @@ def _create_state_init_parameters(self): self.init_bs.append(mx.sym.Variable("%senc2decinit_%d_bias" % (self.prefix, state_idx))) if self.config.layer_normalization: self.init_norms.append(layers.LayerNormalization(prefix="%senc2decinit_%d_norm" % (self.prefix, - state_idx))) + state_idx), + num_hidden=init_num_hidden)) def decode_sequence(self, source_encoded: mx.sym.Symbol, @@ -622,6 +636,7 @@ def decode_step(self, step: int, target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, + att_dict: Dict[str, Dict[str, mx.sym.Symbol]], *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ Decodes a single time step given the current step, the previous embedded target word, @@ -632,9 +647,12 @@ def decode_step(self, :param step: Global step of inference procedure, starts with 1. :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. + :param att_dict: A dictionary of attention matrices used for visualization with separate entries for source and + self-attention. :param states: Arbitrary list of decoder states. :return: logit inputs, attention probabilities, next decoder states. """ + # TODO: use att_dict source_encoded, prev_dynamic_source, source_encoded_length, prev_hidden, *layer_states = states # Get last state from source (batch_size, num_target_embed) @@ -690,6 +708,7 @@ def get_num_hidden(self) -> int: return self.num_hidden def init_states(self, + batch_size: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int) -> List[mx.sym.Symbol]: @@ -697,6 +716,7 @@ def init_states(self, Returns a list of symbolic states that represent the initial states of this decoder. Used for inference. + :param batch_size: The batch size. :param source_encoded: Encoded source. 
Shape: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. @@ -1085,8 +1105,8 @@ def _decode(self, for layer, att_layer in zip(self.layers, self.attention_layers): # (batch_size, target_seq_len, num_hidden) - target_hidden = layer(mx.sym.Dropout(target_hidden, p=drop_prob) if drop_prob > 0 else target_hidden, - target_embed_lengths, target_embed_max_length) + target_hidden, _, __ = layer(mx.sym.Dropout(target_hidden, p=drop_prob) if drop_prob > 0 else target_hidden, + target_embed_lengths, target_embed_max_length) # (batch_size, target_seq_len, num_embed) context = att_layer(target_hidden, source_encoded, source_encoded_lengths) @@ -1101,6 +1121,7 @@ def decode_step(self, step: int, target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, + att_dict: Dict[str, Dict[str, mx.sym.Symbol]], *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ Decodes a single time step given the current step, the previous embedded target word, @@ -1111,6 +1132,8 @@ def decode_step(self, :param step: Global step of inference procedure, starts with 1. :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. + :param att_dict: A dictionary of attention matrices used for visualization with separate entries for source and + self-attention. :param states: Arbitrary list of decoder states. :return: logit inputs, attention probabilities, next decoder states. """ @@ -1191,6 +1214,7 @@ def get_num_hidden(self) -> int: return self.config.cnn_config.num_hidden def init_states(self, + batch_size: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int) -> List[mx.sym.Symbol]: @@ -1198,6 +1222,7 @@ def init_states(self, Returns a list of symbolic states that represent the initial states of this decoder. Used for inference. + :param batch_size: The batch size. :param source_encoded: Encoded source. Shape: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. @@ -1259,3 +1284,172 @@ def state_shapes(self, def get_max_seq_len(self) -> Optional[int]: # The positional embeddings potentially pose a limit on the maximum length at inference time. return self.pos_embedding.get_max_seq_len() + + +class CustomSeqDecoderConfig(Config): + """ + Configuration of a custom decoder layer consisting of a list of potentially nested decoder layers. + + :param decoder_layers: A list of layer configurations. + :param num_embed: The size of the target embeddings. + :param dtype: The data type. + """ + def __init__(self, decoder_layers: List[layers.LayerConfig], num_embed: int, dtype: str = C.DTYPE_FP32) -> None: + super().__init__() + self.num_embed = num_embed + self.decoder_layers = decoder_layers + self.dtype = dtype + + +@Decoder.register(CustomSeqDecoderConfig, C.CUSTOM_SEQ_PREFIX + C.DECODER_PREFIX) +class CustomSeqDecoder(Decoder): + """ + Decoder composed of a customizable list of building blocks/layers, such as RNNs, CNNs and attention mechanisms. 
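+
+ The layers are instantiated in order from the given layer configurations; each layer's output size becomes
+ the input size of the next, and get_num_hidden reports the size of the last layer (falling back to the
+ embedding size when no layers are configured).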
+ """ + + def __init__(self, config: CustomSeqDecoderConfig, + prefix: str = C.DECODER_PREFIX) -> None: + super().__init__(config.dtype) + self.config = config + self.prefix = prefix + self.layers = [] # type: List[layers.DecoderLayer] + input_num_hidden = config.num_embed + for idx, layer_config in enumerate(config.decoder_layers): + layer = layer_config.create_decoder_layer(input_num_hidden, "%sl%d_" % (self.prefix, idx)) + input_num_hidden = layer.get_num_hidden() + self.layers.append(layer) + + def get_num_hidden(self) -> int: + return self.layers[-1].get_num_hidden() if len(self.layers) > 0 else self.config.num_embed + + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + target_lengths: mx.sym.Symbol, + target_max_length: int) -> mx.sym.Symbol: + # (1, target_max_length, target_max_length) + target_autoregressive_bias = layers.get_autoregressive_bias(target_max_length, + name="%starget_bias" % self.prefix) + + for layer_idx, layer in enumerate(self.layers): + target = layer.decode_sequence(source_encoded, source_encoded_lengths, source_encoded_max_length, + target, target_lengths, target_max_length, target_autoregressive_bias) + + return target + + def decode_step(self, + step: int, + target_embed_prev: mx.sym.Symbol, + source_encoded_max_length: int, + att_dict: Dict[str, Dict[str, mx.sym.Symbol]], + *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol,mx.sym.Symbol, List[mx.sym.Symbol]]: + # Source_encoded: (batch_size, source_encoded_max_length, encoder_num_hidden) + source_encoded, source_encoded_lengths, *layer_states_flat = states + + new_layer_states_flat = list(source_encoded) + [source_encoded_lengths] + + target = target_embed_prev + + for layer, layer_states in layers.layers_with_states_iter(self.layers, step, layer_states_flat): + target, new_layer_states = layer.decode_step(step, + source_encoded, + source_encoded_lengths, + source_encoded_max_length, + target, + layer_states, + att_dict) + assert isinstance(target, mx.sym.Symbol) and isinstance(new_layer_states, list) and all( + isinstance(s, mx.sym.Symbol) + for s in new_layer_states), "decode_step of layer %s at step %d does not have the correct return " \ + "type. Did you forget to return the new states." % (str(layer), step) + + next_step = step + 1 + assert len(new_layer_states) == layer.num_states(next_step), "Layer %s did not return the correct number " \ + "of next states (got %d, expected %d) at " \ + "step %d" % (str(layer), + len(new_layer_states), + layer.num_states(next_step), step) + new_layer_states_flat.extend(new_layer_states) + + # TODO: how to properly support the attention probabilities?? -> make att_dict the only attention reporting functionality + # (batch_size, source_encoded_max_length) + attention_probs = mx.sym.reshape(mx.sym.slice_axis(mx.sym.zeros_like(source_encoded), + axis=2, begin=0, end=1), + shape=(0, -1)) + + return target, attention_probs, new_layer_states_flat + + def att_names(self): + """ + :return: The names of all attention mechanisms which will be added to `att_dict` in `decode_step`. + """ + att_names = [] + for layer in self.layers: + att_names.extend(layer.att_names()) + return att_names + + def self_att_names(self): + """ + :return: The names of all self attention mechanisms which will be added to `att_dict` in `decode_step`. 
+ """ + att_names = [] + for layer in self.layers: + att_names.extend(layer.self_att_names()) + return att_names + + def reset(self): + for layer in self.layers: + layer.reset() + + # TODO: add test that the custom seq decoder combines states correctly + def state_variables(self, target_max_length: int) -> List[mx.sym.Symbol]: + state_variables = [mx.sym.Variable(C.SOURCE_ENCODED_NAME), mx.sym.Variable(C.SOURCE_LENGTH_NAME)] + + for layer in self.layers: + layer_state_variables = layer.state_variables(target_max_length) + assert len(layer_state_variables) == layer.num_states(target_max_length), \ + "Inconsistent number of layer state variables (%d) and number of state shapes (%d) " \ + "for layer %s." % (len(layer_state_variables), layer.num_states(target_max_length), str(layer)) + state_variables.extend(layer_state_variables) + return state_variables + + def state_shapes(self, batch_size: int, step: int, source_encoded_max_length: int, source_encoded_depth: int) -> List[mx.io.DataDesc]: + state_shapes = [mx.io.DataDesc(C.SOURCE_ENCODED_NAME, + (batch_size, source_encoded_max_length, source_encoded_depth), + layout=C.BATCH_MAJOR)] + state_shapes = state_shapes + [mx.io.DataDesc(C.SOURCE_LENGTH_NAME, (batch_size,), layout="N")] + + for layer in self.layers: + layer_state_shapes = layer.state_shapes(batch_size, step, source_encoded_max_length, source_encoded_depth) + assert len(layer_state_shapes) == layer.num_states(step), \ + "Inconsistent number of layer states (%d) and number of state " \ + "shapes (%d) for layer %s." % (layer.num_states(step), len(layer_state_shapes), str(layer)) + state_shapes.extend(layer_state_shapes) + + return state_shapes + + def init_states(self, + batch_size: int, + source_encoded: List[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int) -> List[mx.sym.Symbol]: + if not isinstance(source_encoded, list): + source_encoded = [source_encoded] + init_states = source_encoded + [source_encoded_lengths] + + for layer in self.layers: + layer_init_states = layer.init_states(batch_size, source_encoded, source_encoded_lengths, + source_encoded_max_length) + assert len(layer_init_states) == layer.num_states(1), \ + "Num states at step 1 (%d) and number of init states (%d) don't match." % (len(layer_init_states), + layer.num_states(1)) + init_states.extend(layer_init_states) + return init_states + + def get_max_seq_len(self): + # The smallest maximum length across layers, as it is the most constraining + return min((layer.get_max_seq_len() for layer in self.layers if layer.get_max_seq_len() is not None), + default=None) + diff --git a/sockeye/encoder.py b/sockeye/encoder.py index 95c313bb6..5ee0a9d8d 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. A copy of the License @@ -25,6 +25,7 @@ from . import config from . import constants as C from . import convolution +from . import layers from . import rnn from . import transformer from . 
import utils
@@ -42,6 +43,8 @@ def get_encoder(config: 'EncoderConfig', prefix: str = '') -> 'Encoder':
         return get_transformer_encoder(config, prefix)
     elif isinstance(config, ConvolutionalEncoderConfig):
         return get_convolutional_encoder(config, prefix)
+    elif isinstance(config, CustomSeqEncoderConfig):
+        return get_custom_seq_encoder(config, prefix)
     elif isinstance(config, EmptyEncoderConfig):
         return EncoderSequence([EmptyEncoder(config)], config.dtype)
     else:
@@ -102,6 +105,21 @@ def __init__(self,
         self.dtype = dtype
 
 
+class CustomSeqEncoderConfig(config.Config):
+    """
+    Configuration of a custom encoder consisting of a list of potentially nested encoder layers.
+
+    :param encoder_layers: A list of layer configurations.
+    :param num_embed: The size of the source embeddings.
+    :param dtype: The data type.
+    """
+    def __init__(self, encoder_layers: List[layers.LayerConfig], num_embed: int, dtype: str = C.DTYPE_FP32) -> None:
+        super().__init__()
+        self.encoder_layers = encoder_layers
+        self.num_embed = num_embed
+        self.dtype = dtype
+
+
 class EmptyEncoderConfig(config.Config):
     """
     Empty encoder configuration.
@@ -129,7 +147,6 @@ def get_recurrent_encoder(config: RecurrentEncoderConfig, prefix: str) -> 'Encod
     :param prefix: Prefix for variable names.
     :return: Encoder instance.
     """
-    # TODO give more control on encoder architecture
     encoder_seq = EncoderSequence([], config.dtype)
 
     if config.conv_config is not None:
@@ -216,10 +233,20 @@ def get_transformer_encoder(config: transformer.TransformerConfig, prefix: str)
                            prefix=prefix + C.CHAR_SEQ_ENCODER_PREFIX)
     encoder_seq.append(TransformerEncoder, config=config, prefix=prefix + C.TRANSFORMER_ENCODER_PREFIX)
-
     return encoder_seq
 
 
+def get_custom_seq_encoder(config: CustomSeqEncoderConfig, prefix: str) -> 'Encoder':
+    """
+    Creates a custom sequence encoder.
+
+    :param config: Configuration for the custom sequence encoder.
+    :param prefix: The prefix.
+    :return: Encoder instance.
+    """
+    return CustomSeqEncoder(config, prefix + C.CUSTOM_SEQ_PREFIX + C.ENCODER_PREFIX)
+
+
 class Encoder(ABC):
     """
     Generic encoder interface.
@@ -236,13 +263,16 @@ def __init__(self, dtype):
     def encode(self,
                data: mx.sym.Symbol,
                data_length: Optional[mx.sym.Symbol],
-               seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
+               seq_len: int,
+               att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
         """
         Encodes data given sequence lengths of individual examples and maximum sequence length.
 
         :param data: Input data.
         :param data_length: Vector with sequence lengths.
         :param seq_len: Maximum sequence length.
+        :param att_dict: An optional dictionary of attention matrices used for visualization.
+                         Each matrix must be of size (batch_size, source_length, source_length).
         :return: Encoded versions of input data (data, data_length, seq_len).
         """
         pass
@@ -285,13 +315,16 @@ def __init__(self, target_layout: str, num_hidden: int, dtype: str = C.DTYPE_FP3
     def encode(self,
                data: mx.sym.Symbol,
                data_length: Optional[mx.sym.Symbol],
-               seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
+               seq_len: int,
+               att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
         """
         Encodes data given sequence lengths of individual examples and maximum sequence length.
 
         :param data: Input data.
         :param data_length: Vector with sequence lengths.
         :param seq_len: Maximum sequence length.
+        :param att_dict: A dictionary of attention matrices used for visualization.
+ Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ with mx.AttrScope(__layout__=self.target_layout): @@ -315,7 +348,8 @@ def __init__(self, num_hidden: int, dtype: str = C.DTYPE_FP32) -> None: def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: data = mx.sym.SequenceReverse(data=data, sequence_length=data_length, use_sequence_length=True) return data, data_length, seq_len @@ -385,13 +419,16 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ factor_embeddings = [] # type: List[mx.sym.Symbol] @@ -453,12 +490,16 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int = 0) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. + :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ return data, data_length, seq_len @@ -512,11 +553,14 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ :param data: (batch_size, source_seq_len, num_embed) :param data_length: (batch_size,) :param seq_len: sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: (batch_size, source_seq_len, num_embed) """ # add positional embeddings to data @@ -600,11 +644,14 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ :param data: (batch_size, source_seq_len, num_embed) :param data_length: (batch_size,) :param seq_len: sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). 
:return: (batch_size, source_seq_len, num_embed) """ @@ -658,7 +705,8 @@ def __init__(self, num_embed, dtype: str = C.DTYPE_FP32) -> None: def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: return data, data_length, seq_len def encode_positions(self, @@ -706,6 +754,48 @@ def get_positional_embedding(positional_embedding_type: str, return cls(**encoder_params) +class CustomSeqEncoder(Encoder): + """ + Encoder consisting of a custom sequence of layers. + """ + def __init__(self, config: CustomSeqEncoderConfig, prefix: str = C.ENCODER_PREFIX) -> None: + super().__init__(config.dtype) + self.config = config + self.prefix = prefix + self.layers = [] # type: List[layers.EncoderLayer] + input_num_hidden = config.num_embed + for idx, layer_config in enumerate(config.encoder_layers): + layer = layer_config.create_encoder_layer(input_num_hidden, "%sl%d_" % (self.prefix, idx)) + input_num_hidden = layer.get_num_hidden() + self.layers.append(layer) + + def encode(self, + data: mx.sym.Symbol, + data_length: Optional[mx.sym.Symbol], + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + for layer in self.layers: + data, data_length, seq_len = layer.encode_sequence(data, data_length, seq_len, att_dict) + return data, data_length, seq_len + + def att_names(self): + """ + :return: The names of all attention mechanisms which will be added to `att_dict` in `decode_step`. + """ + att_names = [] + for layer in self.layers: + att_names.extend(layer.att_names()) + return att_names + + def get_num_hidden(self) -> int: + return self.layers[-1].get_num_hidden() if len(self.layers) > 0 else self.config.num_embed + + def get_encoded_seq_len(self, seq_len: int): + for layer in self.layers: + seq_len = layer.get_encoded_seq_len(seq_len) + return seq_len + + class EncoderSequence(Encoder): """ A sequence of encoders is itself an encoder. @@ -721,13 +811,16 @@ def __init__(self, encoders: List[Encoder], dtype: str = C.DTYPE_FP32) -> None: def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ for encoder in self.encoders: @@ -794,12 +887,15 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: An optional dictionary of attention matrices used for visualization. 
+ Each matrix must be of size (batch_size, source_length, source_length). :return: Expected number of empty states (zero-filled). """ # outputs: (batch_size, seq_len, num_hidden) @@ -834,13 +930,16 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: Optional[mx.sym.Symbol], - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ outputs, _ = self.rnn.unroll(seq_len, inputs=data, merge_outputs=True, layout=self.layout) @@ -897,13 +996,16 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data (data, data_length, seq_len). """ if self.layout[0] == 'N': @@ -972,7 +1074,8 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data with a stack of Convolution+GLU blocks given sequence lengths of individual examples and maximum sequence length. @@ -980,6 +1083,8 @@ def encode(self, :param data: Input data. Shape: (batch_size, seq_len, input_num_hidden). :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded version of the data. """ # data: (batch_size, seq_len, num_hidden) @@ -991,7 +1096,7 @@ def encode(self, # Multiple layers with residual connections: for layer in self.layers: - data = data + layer(data, data_length, seq_len) + data = data + layer(data, data_length, seq_len)[0] return data, data_length, seq_len def get_num_hidden(self) -> int: @@ -1019,18 +1124,22 @@ def __init__(self, config, prefix="%s%d_" % (prefix, i)) for i in range(config.num_layers)] self.final_process = transformer.TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=self.config.model_size, prefix="%sfinal_process_" % prefix) def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. 
:param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data data, data_length, seq_len. """ data = utils.cast_conditionally(data, self.dtype) @@ -1147,13 +1256,16 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Input data. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). :return: Encoded versions of input data data, data_length, seq_len. """ total_num_filters = sum(self.num_filters) @@ -1273,6 +1385,6 @@ def get_encoded_seq_len(self, seq_len: int) -> int: EncoderConfig = Union[RecurrentEncoderConfig, transformer.TransformerConfig, ConvolutionalEncoderConfig, - EmptyEncoderConfig] + CustomSeqEncoderConfig, EmptyEncoderConfig] if ImageEncoderConfig is not None: EncoderConfig = Union[EncoderConfig, ImageEncoderConfig] # type: ignore diff --git a/sockeye/image_captioning/encoder.py b/sockeye/image_captioning/encoder.py index 756f20d34..9bac393d3 100644 --- a/sockeye/image_captioning/encoder.py +++ b/sockeye/image_captioning/encoder.py @@ -16,7 +16,7 @@ """ import logging import mxnet as mx -from typing import List, Tuple +from typing import List, Tuple, Optional, Dict from .. import constants as C from ..config import Config @@ -141,13 +141,15 @@ def __init__(self, def encode(self, data: mx.sym.Symbol, data_length: mx.sym.Symbol, - seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + seq_len: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: """ Encodes data given sequence lengths of individual examples and maximum sequence length. :param data: Ignored. Assume that the input is the image. :param data_length: Vector with sequence lengths. :param seq_len: Maximum sequence length. + :param att_dict: An optional dictionary of attention matrices used for visualization. :return: Encoded versions of input data data, data_length, seq_len. 
""" diff --git a/sockeye/inference.py b/sockeye/inference.py index d3918655c..cf4975b8f 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -171,16 +171,20 @@ def sym_gen(source_seq_len: int): source_embed_length, source_embed_seq_len) = self.embedding_source.encode(source, source_length, source_seq_len) + att_dict = {} # type: Dict[str, mx.sym.Symbol] + # encoder # source_encoded: (source_encoded_length, batch_size, encoder_depth) (source_encoded, source_encoded_length, source_encoded_seq_len) = self.encoder.encode(source_embed, source_embed_length, - source_embed_seq_len) + source_embed_seq_len, + att_dict) # initial decoder states - decoder_init_states = self.decoder.init_states(source_encoded, + decoder_init_states = self.decoder.init_states(self.batch_size, + source_encoded, source_encoded_length, source_encoded_seq_len) @@ -224,6 +228,8 @@ def sym_gen(bucket_key: Tuple[int, int]): # (batch_size, num_embed) target_embed_prev, _, _ = self.embedding_target.encode(data=target_prev, data_length=None, seq_len=1) + att_dict = {"source": {}, "self": {}} # type: Dict[str, Dict[str, mx.sym.Symbol]] + # decoder # target_decoded: (batch_size, decoder_depth) (target_decoded, @@ -231,6 +237,7 @@ def sym_gen(bucket_key: Tuple[int, int]): states) = self.decoder.decode_step(decode_step, target_embed_prev, source_encoded_seq_len, + att_dict, *states) if self.decoder_return_logit_inputs: diff --git a/sockeye/layers.py b/sockeye/layers.py index 096bfb58a..2fd6c6e73 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. A copy of the License @@ -13,11 +13,13 @@ import logging import math -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, List, Iterator, Sequence +from abc import ABC, abstractmethod import mxnet as mx import numpy as np +from .config import Config from . import constants as C from . import utils @@ -46,11 +48,1528 @@ def activation(data: mx.sym.Symbol, act_type: str) -> mx.sym.Symbol: # Approximation of x * gaussian_cdf(x) used by Hendrycks and Gimpel return 0.5 * data * (1 + mx.sym.Activation((math.sqrt(2 / math.pi) * (data + (0.044715 * (data**3)))), act_type="tanh")) + elif act_type == C.NO_ACTIVATION: + return data else: return mx.sym.Activation(data, act_type=act_type) -class LayerNormalization: +class Layer(ABC): + + def get_max_seq_len(self) -> Optional[int]: + """ + :return: The maximum length supported by the layer if such a restriction exists. + """ + return None + + @abstractmethod + def get_num_hidden(self) -> int: + """ + :return: The representation size of this layer. + """ + raise NotImplementedError() + + +class EncoderLayer(Layer): + """ + Generic encoder layer interface for a layer which takes the sequence of previous hidden states and sequence lengths + and produces a new sequence hidden states, potentially changing the sequence length. + """ + + @abstractmethod + def encode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + """ + Encodes data given sequence lengths of individual examples and maximum sequence length. 
+ + :param source_encoded: Input data of size (batch_size, seq_len, num_hidden). + :param source_encoded_lengths: Vector with sequence lengths of size (batch_size,). + :param source_encoded_max_length: Maximum sequence length. + :param att_dict: A dictionary of attention matrices used for visualization. + Each matrix must be of size (batch_size, source_length, source_length). + :return: Encoded versions of the input data, the new sequence lenghts and the new maximum length. + """ + pass + + def att_names(self) -> List[str]: + """Names of attention matrices produced by this layer.""" + return [] + + def get_encoded_seq_len(self, seq_len: int) -> int: + """ + :return: The size of the encoded sequence. + """ + return seq_len + + +class DecoderLayer(Layer): + """ + Generic decoder layer interface. + """ + + # TODO: do we need all of the arguments? + @abstractmethod + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Run the layer on the entire sequence. + + :param source_encoded: Encoded source layer: (source_encoded_max_length, batch_size, encoder_depth). + :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). + :param source_encoded_max_length: Size of encoder time dimension. + :param target_encoded: Input data. Shape: (batch_size, seq_len, num_hidden). + :param target_encoded_lengths: Vector with sequence lengths. Shape: (batch_size,). + :param target_encoded_max_length: Maximum sequence length. + :param target_autoregressive_bias: The auto-regressive bias. + :return: hidden_data, Shape: (batch_size, seq_len, num_hidden). + """ + pass + + # TODO: potentially define a DecoderAttDict class with source and self_att members which are dicts?! + @abstractmethod + def decode_step(self, + step: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + states: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + """ + Run the decoder layer for a single position given the current step, the previous embedded target word, + and previous decoder layer states. + + :param step: Global step of inference procedure, starts with 1. + :param source_encoded: Encoded source layer: (source_encoded_max_length, batch_size, encoder_depth). + :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). + :param source_encoded_max_length: Size of encoder time dimension. + :param target: Shape: (batch_size, input_num_hidden). + :param states: Arbitrary layer states. + :param att_dict: A dictionary of attention matrices used for visualization with separate entries for source and + self attention {'self': Dict, 'source': Dict}. Each source attention matrix must be of size + (batch_size, 1, source_length) and each self-attention of size (batch_size, 1, step). + :return: Step result of the layer (batch_size, num_hidden) and a list of new layer states. + """ + pass + + def att_names(self) -> List[str]: + """Names of attention matrices produced by this layer.""" + return [] + + def self_att_names(self) -> List[str]: + """Names of self attention matrices produced by this layer.""" + return [] + + def reset(self): + """ + Reset decoder method. 
Used for inference. + """ + pass + + def num_states(self, step: int) -> int: + """ + The number of input states at the given step. + :param step: + :return: + """ + return 0 + + def state_variables(self, step: int) -> Sequence[mx.sym.Symbol]: + """ + Returns the list of symbolic variables for this decoder to be used during inference. + + :param step: Current target sequence length. + :return: List of symbolic variables. + """ + return [] + + def init_states(self, + batch_size: int, + source_encoded: Sequence[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int) -> Sequence[mx.sym.Symbol]: + """ + Returns a list of symbolic states that represent the initial states of this decoder. + Used for inference. + + :param batch_size: The batch size. + :param source_encoded: Encoded source. Shape: (batch_size, source_encoded_max_length, encoder_depth). + :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). + :param source_encoded_max_length: Size of encoder time dimension. + :return: List of symbolic initial states. + """ + return [] + + def state_shapes(self, + batch_size: int, + target_max_length: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int) -> List[mx.io.DataDesc]: + """ + Returns a list of shape descriptions given batch size, encoded source max length and encoded source depth. + Used for inference. + + :param batch_size: Batch size during inference. + :param source_encoded_max_length: Size of encoder time dimension. + :param source_encoded_num_hidden: Depth of encoded source. + :return: List of shape descriptions. + """ + return [] + + +class LayerConfig(Config): + """ + A layer config object used for serializing layer parameters. Each layer config object also defines how an encoder + or decoder layer are created from the given parameters. + """ + + @abstractmethod + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + pass + + @abstractmethod + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + pass + + +class SharedEncoderDecoderLayer(DecoderLayer, EncoderLayer): + """ + A layer which does not depend on the source hidden states. On the target side it will be applied on the target + hidden states and on the source side on the source hidden states. Using this class as a base class only the combined + `process_sequence` method needs to be implemented. + """ + + @abstractmethod + def process_sequence(self, + data: mx.sym.Symbol, + lengths: mx.sym.Symbol, + max_length: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + """ + Process either the encoder or the decoder hidden states. + :param data: Encoded source layer: (source_encoded_max_length, batch_size, encoder_depth). + :param lengths: Lengths of hidden sequences. Shape: (batch_size,). + :param max_length: Maximum length. + :return: Encoded versions of the input data, the new sequence lenghts and the new maximum length. 
+ """ + pass + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + return self.process_sequence(source_encoded, source_encoded_lengths, source_encoded_max_length) + + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + new_target_encoded, new_target_encoded_lengths, new_target_encoded_max_length = self.process_sequence( + target_encoded, target_encoded_lengths, target_encoded_max_length) + assert new_target_encoded_lengths is target_encoded_lengths, C.SEQUENCE_LENGTH_MUST_NOT_CHANGE_MSG + assert new_target_encoded_max_length == target_encoded_max_length, C.SEQUENCE_LENGTH_MUST_NOT_CHANGE_MSG + return new_target_encoded + + +def layers_with_states_iter( + layers: List[DecoderLayer], + step: int, + layer_states_flat: Sequence[mx.sym.Symbol]) -> Iterator[Tuple[DecoderLayer, Sequence[mx.sym.Symbol]]]: + """ + A generator for layers and corresponding layer states from a flat list of all layer states. + :param layers: A list of layers. + :param step: The current decoder step. + :param layer_states_flat: A flat list of layer states across all layers, e.g. [l1_state1, l1_state2, l2_state1, ..]. + :return: A generator of tuples of decoder layers and their states. + """ + state_idx = 0 + for layer in layers: + if layer.num_states(step) != 0: + layer_states = layer_states_flat[state_idx:state_idx + layer.num_states(step)] + state_idx += layer.num_states(step) + else: + layer_states = [] + yield layer, layer_states + + +class EncoderLayerChain(EncoderLayer): + + def __init__(self, layers: List[EncoderLayer]) -> None: + assert len(layers) >= 1, "At least one layer needed in layer chain." + self.layers = layers + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + for layer in self.layers: + source_encoded, source_encoded_lengths, source_encoded_max_length = layer.encode_sequence( + source_encoded, source_encoded_lengths, source_encoded_max_length, att_dict) + return source_encoded, source_encoded_lengths, source_encoded_max_length + + def att_names(self): + att_names = [] + for layer in self.layers: + att_names.extend(layer.att_names()) + return att_names + + def get_num_hidden(self) -> int: + return self.layers[-1].get_num_hidden() + + +class NestedDecoderLayer(DecoderLayer): + """ + A decoder layer which combines several other decoder sub-layers. 
+ """ + + @property + @abstractmethod + def layers(self) -> List[DecoderLayer]: + pass + + def layers_with_states_iter(self, + step: int, + layer_states_flat: Sequence[mx.sym.Symbol]) -> Iterator[Tuple[DecoderLayer, + Sequence[mx.sym.Symbol]]]: + return layers_with_states_iter(self.layers, step, layer_states_flat) + + def att_names(self) -> List[str]: + att_names = [] + for layer in self.layers: + att_names.extend(layer.att_names()) + return att_names + + def self_att_names(self) -> List[str]: + att_names = [] + for layer in self.layers: + att_names.extend(layer.self_att_names()) + return att_names + + def reset(self) -> None: + for layer in self.layers: + layer.reset() + + def num_states(self, step) -> int: + return sum(layer.num_states(step) for layer in self.layers) + + def state_variables(self, step: int) -> Sequence[mx.sym.Symbol]: + return [var for layer in self.layers for var in layer.state_variables(step)] + + def state_shapes(self, + batch_size: int, + step: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int) -> List[mx.io.DataDesc]: + return [state_shape for layer in self.layers for state_shape in layer.state_shapes(batch_size, + step, + source_encoded_max_length, + source_encoded_num_hidden)] + + def init_states(self, + batch_size: int, + source_encoded: Sequence[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int) -> Sequence[mx.sym.Symbol]: + return [init_state for layer in self.layers for init_state in layer.init_states(batch_size, + source_encoded, + source_encoded_lengths, + source_encoded_max_length)] + + def get_max_seq_len(self) -> Optional[int]: + return min((layer.get_max_seq_len() for layer in self.layers if layer.get_max_seq_len() is not None), + default=None) + + +class DecoderLayerChain(NestedDecoderLayer): + + def __init__(self, layers: List[DecoderLayer]) -> None: + assert len(layers) >= 1, "At least one layer needed in layer chain." + self._layers = layers + + @property + def layers(self) -> List[DecoderLayer]: + return self._layers + + def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + for layer in self.layers: + target_encoded = layer.decode_sequence(source_encoded, + source_encoded_lengths, + source_encoded_max_length, + target_encoded, + target_encoded_lengths, + target_encoded_max_length, + target_autoregressive_bias) + + return target_encoded + + def decode_step(self, step: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, layer_states_flat: Sequence[mx.sym.Symbol], + att_dict) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + new_layer_states_flat = [] # type: List[mx.sym.Symbol] + for layer, layer_states in self.layers_with_states_iter(step, layer_states_flat): + target, new_layer_states = layer.decode_step(step, source_encoded, source_encoded_lengths, + source_encoded_max_length, target, layer_states, att_dict) + new_layer_states_flat.extend(new_layer_states) + + return target, new_layer_states_flat + + def get_num_hidden(self) -> int: + return self._layers[-1].get_num_hidden() + + +class StatelessBlock(ABC): + """ + A block which does not require to keep any state during inference time so that we are able to call it one + timestep at a time. 
This means that the block can be run as + + concat(block(data[:,t,:]) for t in range(seq_len), dim=1) + + Namely, at time t we only need the data point at time t from the previous layer. + """ + + @abstractmethod + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + """ + Compute the block for the entire sequence. + + :param data: Input data, Shape: (batch_size, seq_len, input_num_hidden). + :param lengths: Optional vector with sequence lengths, Shape: (batch_size,). + :param max_length: Optional maximum sequence length. + :return: new data (batch_size, seq_len, num_hidden). + """ + pass + + def step(self, step: int, data: mx.sym.Symbol) -> mx.sym.Symbol: + """ + Compute the block for a single time step. + + :param step: The current step. + :param data: Data for a single time step. Shape: (batch_size, 1, input_num_hidden). + :return: Shape: (batch_size, 1, num_hidden). + """ + return self.__call__(data) + + @abstractmethod + def get_num_hidden(self) -> int: + """ + :return: The representation size of this block. + """ + pass + + def get_max_seq_len(self) -> Optional[int]: + """ + :return: The maximum length supported by the block if such a restriction exists. + """ + return None + + +class StatelessBlockLayer(SharedEncoderDecoderLayer): + """ + A stateless block layer which can act as both a encoder or a decoder layer, applying the block to either the + source or target hidden state. + """ + + def __init__(self, block: StatelessBlock) -> None: + self.block = block + + def process_sequence(self, + data: mx.sym.Symbol, + lengths: mx.sym.Symbol, + max_length: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + return self.block(data, lengths, max_length), lengths, max_length + + def get_num_hidden(self) -> int: + return self.block.get_num_hidden() + + def decode_step(self, + step: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + states: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + # (batch_size, 1, num_hidden) + target = mx.sym.expand_dims(target, axis=1) + + # (batch_size, 1, num_hidden_block) + hidden = self.block.step(step, target) + + # (batch_size, num_hidden_block) + hidden = mx.sym.reshape(hidden, shape=(0, -1)) + return hidden, [] + + def get_max_seq_len(self): + return self.block.get_max_seq_len() + + +class FeedForwardBlock(StatelessBlock): + """ + Position-wise feed-forward network with activation and dropout. + """ + + def __init__(self, + num_hidden: int, + dropout: float = 0.0, + act_type: str = C.RELU, + prefix: str = "") -> None: + self.num_hidden = num_hidden + self.dropout = dropout + self.prefix = prefix + self.act_type = act_type + self.w_i2h = mx.sym.Variable('%si2h_weight' % prefix) + self.b_i2h = mx.sym.Variable('%si2h_bias' % prefix) + + def _pre_activation_num_hidden(self): + if self.act_type == C.GLU: + return 2 * self.num_hidden + else: + return self.num_hidden + + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + """ + Apply the feed-forward layer. 
+ + :param x: Symbol of shape (batch_size, seq_len, num_hidden) + :return: Symbol of shape (batch_size, seq_len, num_hidden) + """ + h = mx.sym.FullyConnected(data=data, num_hidden=self._pre_activation_num_hidden(), + weight=self.w_i2h, bias=self.b_i2h, + flatten=False, name=self.prefix + "ff") + + if self.act_type == C.GLU: + # GLU + # two times: (batch_size, seq_len, num_hidden) + + # pylint: disable=unbalanced-tuple-unpacking + gate_a, gate_b = mx.sym.split(h, num_outputs=2, axis=2) + # (batch_size, seq_len, num_hidden) + h = mx.sym.broadcast_mul(gate_a, + mx.sym.Activation(data=gate_b, act_type="sigmoid")) + else: + h = activation(h, act_type=self.act_type) + if self.dropout > 0.0: + h = mx.sym.Dropout(h, p=self.dropout) + return h + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class FeedForwardLayerConfig(LayerConfig): + + def __init__(self, + num_hidden: int, + dropout: float, + act_type: str = C.RELU) -> None: + super().__init__() + self.num_hidden = num_hidden + self.act_type = act_type + self.dropout = dropout + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return self.create_block_layer(prefix) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return self.create_block_layer(prefix) + + def create_block_layer(self, prefix: str) -> StatelessBlockLayer: + return StatelessBlockLayer(FeedForwardBlock(self.num_hidden, + self.dropout, + self.act_type, + prefix + "ff_")) + + +class LinearBlock(StatelessBlock): + """ + Linear projection followed by dropout. + """ + + def __init__(self, + num_hidden: int, + dropout: float, + no_bias: bool, + prefix: str) -> None: + self.num_hidden = num_hidden + self.dropout = dropout + self.no_bias = no_bias + self.prefix = prefix + self.w_i2h = mx.sym.Variable('%si2h_weight' % prefix) + self.b_i2h = None if no_bias else mx.sym.Variable('%si2h_bias' % prefix) + + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + """ + Apply the linear projection. 
+ + :param x: Symbol of shape (batch_size, seq_len, num_hidden) + :return: Symbol of shape (batch_size, seq_len, num_hidden) + """ + bias = None if self.no_bias else self.b_i2h + h = mx.sym.FullyConnected(data=data, num_hidden=self.num_hidden, weight=self.w_i2h, bias=bias, + no_bias=self.no_bias, flatten=False) + if self.dropout > 0.0: + h = mx.sym.Dropout(h, p=self.dropout) + return h + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class LinearLayerConfig(LayerConfig): + + def __init__(self, num_hidden: int, dropout: float, no_bias=False) -> None: + super().__init__() + self.num_hidden = num_hidden + self.dropout = dropout + self.no_bias = no_bias + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return self.create_block_layer(prefix) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return self.create_block_layer(prefix) + + def create_block_layer(self, prefix: str) -> StatelessBlockLayer: + return StatelessBlockLayer(LinearBlock(self.num_hidden, self.dropout, no_bias=self.no_bias, + prefix=prefix + "linear_")) + + +class ActivationBlock(StatelessBlock): + + def __init__(self, + act_type: str, + num_hidden: int) -> None: + self.act_type = act_type + self.num_hidden = num_hidden + + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + """ + Apply the activation function. + + :param x: Symbol of shape (batch_size, seq_len, num_hidden) + :return: Symbol of shape (batch_size, seq_len, num_hidden) + """ + data = activation(data, act_type=self.act_type) + return data + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class ActivationLayerConfig(LayerConfig): + + def __init__(self, act_type: str = C.RELU) -> None: + super().__init__() + self.act_type = act_type + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return StatelessBlockLayer(ActivationBlock(num_hidden=input_num_hidden, act_type=self.act_type)) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return StatelessBlockLayer(ActivationBlock(num_hidden=input_num_hidden, act_type=self.act_type)) + + +class DropoutBlock(StatelessBlock): + """ + Position-wise feed-forward network with activation. + """ + + def __init__(self, + dropout: float, + num_hidden: int) -> None: + self.dropout = dropout + self.num_hidden = num_hidden + + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + """ + Position-wise feed-forward network with activation. 
+ + :param x: Symbol of shape (batch_size, seq_len, num_hidden) + :return: Symbol of shape (batch_size, seq_len, num_hidden) + """ + if self.dropout > 0.0: + data = mx.sym.Dropout(data, p=self.dropout) + return data + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class DropoutLayerConfig(LayerConfig): + + def __init__(self, dropout: float) -> None: + super().__init__() + self.dropout = dropout + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return StatelessBlockLayer(DropoutBlock(num_hidden=input_num_hidden, dropout=self.dropout)) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return StatelessBlockLayer(DropoutBlock(num_hidden=input_num_hidden, dropout=self.dropout)) + + +class LearnedAdditivePositionalEmbeddings(StatelessBlock): + + def __init__(self, + num_embed: int, + dropout: float, + max_seq_len: int, + prefix: str) -> None: + self.num_embed = num_embed + self.dropout = dropout + self.max_seq_len = max_seq_len + self.prefix = prefix + self.embed_weight = mx.sym.Variable(prefix + "weight") + + def __call__(self, data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + assert max_length is not None, "max_length needed for positional embeddings." + # (1, source_seq_len) + positions = mx.sym.expand_dims(data=mx.sym.arange(start=0, stop=max_length, step=1), axis=0) + + # (1, source_seq_len, num_embed) + pos_embedding = mx.sym.Embedding(data=positions, + input_dim=self.max_seq_len, + weight=self.embed_weight, + output_dim=self.num_embed, + name=self.prefix + "pos_embed") + data = mx.sym.broadcast_add(data, pos_embedding, name="%s_add" % self.prefix) + if self.dropout > 0.0: + data = mx.sym.Dropout(data, p=self.dropout) + return data + + def step(self, step: int, data: mx.sym.Symbol) -> mx.sym.Symbol: + position = step - 1 + position = position * mx.sym.reshape( + mx.sym.slice_axis(mx.sym.reshape(mx.sym.ones_like(data), shape=(0, -1)), axis=1, begin=0, end=1), + shape=(-1)) + pos_embedding = mx.sym.Embedding(data=position, + input_dim=self.max_seq_len, + weight=self.embed_weight, + output_dim=self.num_embed, + name=self.prefix + "pos_embed") + pos_embedding = mx.sym.expand_dims(pos_embedding, axis=1) + return mx.sym.broadcast_add(data, pos_embedding, name="%s_add" % self.prefix) + + def get_num_hidden(self) -> int: + return self.num_embed + + def get_max_seq_len(self): + return self.max_seq_len + + +class LearnedPositionalEmbeddingsLayerConfig(LayerConfig): + + def __init__(self, + num_embed: int, + dropout: float, + max_seq_len: int) -> None: + super().__init__() + self.num_embed = num_embed + self.dropout = dropout + self.max_seq_len = max_seq_len + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return self.create_layer(prefix) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return self.create_layer(prefix) + + def create_layer(self, prefix: str) -> StatelessBlockLayer: + return StatelessBlockLayer(LearnedAdditivePositionalEmbeddings(num_embed=self.num_embed, + dropout=self.dropout, + max_seq_len=self.max_seq_len, + prefix=prefix + "pos_embed_")) + + +# TODO: share the code with the encoder! +class AdditiveSinCosPositionalEmbeddings(StatelessBlock): + """ + Takes an encoded sequence and adds fixed positional embeddings as in Vaswani et al, 2017 to it. + + :param num_embed: Embedding size. 
+ :param prefix: Name prefix for symbols of this encoder. + :param scale_up_input: If True, scales input data up by num_embed ** 0.5. + :param scale_down_positions: If True, scales positional embeddings down by num_embed ** -0.5. + """ + + def __init__(self, + num_embed: int, + dropout: float, + prefix: str, + scale_up_input: bool, + scale_down_positions: bool) -> None: + utils.check_condition(num_embed % 2 == 0, "Positional embeddings require an even embedding size it " + "is however %d." % num_embed) + self.scale_up_input = scale_up_input + self.scale_down_positions = scale_down_positions + self.num_embed = num_embed + self.dropout = dropout + self.prefix = prefix + + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None) -> mx.sym.Symbol: + assert max_length is not None, "max_length needed for positional embeddings." + # add positional embeddings to data + if self.scale_up_input: + data = data * (self.num_embed ** 0.5) + + positions = mx.sym.BlockGrad(mx.symbol.Custom(length=max_length, + depth=self.num_embed, + name="%spositional_encodings" % self.prefix, + op_type='positional_encodings')) + + if self.scale_down_positions: + positions = positions * (self.num_embed ** -0.5) + + embedding = mx.sym.broadcast_add(data, positions) + if self.dropout > 0.0: + embedding = mx.sym.Dropout(embedding, p=self.dropout) + return embedding + + def step(self, step: int, data: mx.sym.Symbol) -> mx.sym.Symbol: + position = step - 1 + # (batch_size, num_hidden) -> (batch_size,) + positions = position * mx.sym.reshape( + mx.sym.slice_axis(mx.sym.reshape(mx.sym.ones_like(data), shape=(0, -1)), axis=1, begin=0, end=1), + shape=(-1)) + # (batch_size, 1) + positions = mx.sym.expand_dims(positions, axis=1) + # (num_embed,) + channels = mx.sym.arange(0, self.num_embed // 2) + # (1, num_embed,) + scaling = mx.sym.expand_dims(1. 
/ mx.sym.pow(10000, (2 * channels) / self.num_embed), axis=0)
+
+        # (batch_size, num_embed/2)
+        scaled_positions = mx.sym.dot(positions, scaling)
+
+        sin = mx.sym.sin(scaled_positions)
+        cos = mx.sym.cos(scaled_positions)
+
+        # (batch_size, num_embed)
+        pos_embedding = mx.sym.concat(sin, cos, dim=1)
+
+        if self.scale_up_input:
+            data = data * (self.num_embed ** 0.5)
+
+        if self.scale_down_positions:
+            pos_embedding = pos_embedding * (self.num_embed ** -0.5)
+
+        # (batch_size, 1, num_embed)
+        pos_embedding = mx.sym.expand_dims(pos_embedding, axis=1)
+        return mx.sym.broadcast_add(data, pos_embedding, name="%s_add" % self.prefix)
+
+    def get_num_hidden(self) -> int:
+        return self.num_embed
+
+
+class SinCosPositionalEmbeddingsLayerConfig(LayerConfig):
+
+    def __init__(self,
+                 num_embed: int,
+                 dropout: float = 0.0,
+                 scale_inputs: bool = True) -> None:
+        super().__init__()
+        self.num_embed = num_embed
+        self.dropout = dropout
+        if scale_inputs:
+            # Transformer default (seems to work better for the transformer)
+            self.scale_up_input = True
+            self.scale_down_positions = False
+        else:
+            self.scale_up_input = False
+            self.scale_down_positions = True
+
+    def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer:
+        return self.create_layer(prefix)
+
+    def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer:
+        return self.create_layer(prefix)
+
+    def create_layer(self, prefix) -> StatelessBlockLayer:
+        return StatelessBlockLayer(AdditiveSinCosPositionalEmbeddings(num_embed=self.num_embed,
+                                                                      dropout=self.dropout,
+                                                                      prefix=prefix + "pos_embed_",
+                                                                      scale_up_input=self.scale_up_input,
+                                                                      scale_down_positions=self.scale_down_positions))
+
+
+def get_autoregressive_bias(max_length: int, name: str) -> mx.sym.Symbol:
+    """
+    Returns bias/mask to ensure position i can only attend to positions <i.
+
+    :param max_length: Sequence length.
+    :param name: Name of symbol.
+    :return: Bias symbol of shape (1, max_length, max_length).
+    """
+    return mx.sym.BlockGrad(mx.symbol.Custom(length=max_length,
+                                             name=name,
+                                             op_type='auto_regressive_bias'))
+
+
+class AutoRegressiveBias(mx.operator.CustomOp):
+
+    def __init__(self, length: int, dtype: str, ctx: mx.Context) -> None:
+        super().__init__()
+        self.bias = self.get_bias(length, dtype, ctx)
+
+    @staticmethod
+    def get_bias(length: int, dtype: str, ctx: mx.Context):
+        # matrix with lower triangle and main diagonal set to 0, upper triangle set to 1
+        upper_triangle = np.triu(np.ones((length, length), dtype=dtype), k=1)
+        # (1, length, length)
+        bias = -C.LARGE_VALUES[dtype] * np.reshape(upper_triangle, (1, length, length))
+        return mx.nd.array(bias, ctx=ctx)
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        self.assign(out_data[0], req[0], self.bias)
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        pass
+
+
+@mx.operator.register("auto_regressive_bias")
+class AutoRegressiveBiasProp(mx.operator.CustomOpProp):
+
+    def __init__(self, length: str, dtype: str = C.DTYPE_FP32) -> None:
+        super().__init__()
+        self.length = int(length)
+        self.dtype = dtype
+
+    def list_arguments(self):
+        return []
+
+    def list_outputs(self):
+        return ['output']
+
+    def infer_shape(self, in_shape):
+        return [], [(1, self.length, self.length)], []
+
+    def infer_type(self, in_type):
+        return [], [np.dtype(self.dtype).type], []
+
+    def create_operator(self, ctx, shapes, dtypes):
+        return AutoRegressiveBias(length=self.length, dtype=self.dtype, ctx=ctx)
+
+
+class MultiHeadSourceAttentionDecoderLayer(DecoderLayer):
+
+    def __init__(self, num_hidden, att_num_hidden, heads: int, dropout: float, dropout_attention: float,
+                 prefix: str = "") -> None:
+        self.prefix = prefix
+        self.num_hidden = num_hidden
+        self.att_num_hidden = att_num_hidden if att_num_hidden is not None else num_hidden
+        self.dropout = dropout
+        self.att = MultiHeadAttention(prefix=prefix, heads=heads,
depth_att=self.att_num_hidden, depth_out=num_hidden, + dropout=dropout_attention) + + def att_names(self): + return [self.prefix + ("h%d" % i) for i in range(1, self.att.heads+1)] + + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + contexts, probs = self.att(target_encoded, source_encoded, source_encoded_lengths) + + if self.dropout > 0.0: + contexts = mx.sym.Dropout(contexts, p=self.dropout) + return contexts + + def decode_step(self, + step: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + states: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + # (batch_size, 1, num_hidden) + target = mx.sym.expand_dims(target, axis=1) + + context, probs = self.att(target, source_encoded, source_encoded_lengths) + + # probs is of shape (batch, heads, 1, target_length) + for head_idx, head_prob in enumerate(mx.sym.split(probs, axis=1, squeeze_axis=True, + num_outputs=self.att.heads), 1): + name = self.prefix + ("h%d" % head_idx) + att_dict["source"][name] = mx.sym.identity(head_prob, name=name) + + # (batch_size, num_hidden) + context = mx.sym.reshape(context, shape=(0, -1)) + return context, [] + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class MultiHeadSourceAttentionLayerConfig(LayerConfig): + def __init__(self, + heads: int = 8, + dropout: float = 0.0, + dropout_attention: Optional[float] = None, + num_hidden: int = None, + att_num_hidden: Optional[int] = None) -> None: + super().__init__() + assert num_hidden is not None + self.num_hidden = num_hidden + self.att_num_hidden = att_num_hidden if att_num_hidden is not None else num_hidden + self.dropout = dropout + self.dropout_attention = dropout_attention if dropout_attention is not None else dropout + self.heads = heads + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + raise NotImplementedError("Source attention is only availabe on the decoder side.") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return MultiHeadSourceAttentionDecoderLayer(num_hidden=self.num_hidden, + att_num_hidden=self.att_num_hidden, + heads=self.heads, + dropout=self.dropout, + dropout_attention=self.dropout_attention, + prefix=prefix + "mh_att_") + + +class MultiHeadSelfAttentionEncoderLayer(EncoderLayer): + + def __init__(self, num_hidden, att_num_hidden: Optional[int], heads: int, dropout: float, dropout_attention: float, + prefix: str) -> None: + if att_num_hidden is None: + att_num_hidden = num_hidden + self.prefix = prefix + self.num_hidden = num_hidden + self.dropout = dropout + self.att = MultiHeadSelfAttention(prefix=prefix, + dropout=dropout_attention, + heads=heads, + depth_att=att_num_hidden, + depth_out=num_hidden) + + def encode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + contexts, probs = self.att(source_encoded, source_encoded_lengths) + + if att_dict is not None: + # probs is of shape (batch, heads, source_length, source_length) + for head_idx, head_prob in 
enumerate(mx.sym.split(probs, axis=1, squeeze_axis=True, + num_outputs=self.att.heads), 1): + name = self.prefix + ("h%d" % head_idx) + att_dict[name] = mx.sym.identity(head_prob, name=name) + + if self.dropout > 0.0: + contexts = mx.sym.Dropout(contexts, p=self.dropout) + return contexts, source_encoded_lengths, source_encoded_max_length + + def att_names(self): + return [self.prefix + ("h%d" % i) for i in range(1, self.att.heads+1)] + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class MultiHeadSelfAttentionDecoderLayer(DecoderLayer): + + def __init__(self, num_hidden, att_num_hidden: int, heads: int, dropout: float, dropout_attention: float, + prefix: str = "") -> None: + self.prefix = prefix + self.num_hidden = num_hidden + self.att_num_hidden = att_num_hidden if att_num_hidden is not None else num_hidden + self.dropout = dropout + self.att = MultiHeadSelfAttention(prefix=prefix, + dropout=dropout_attention, + heads=heads, + depth_att=self.att_num_hidden, + depth_out=num_hidden) + + def decode_sequence(self, + source_encoded: Sequence[mx.sym.Symbol], + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + + contexts, _ = self.att(target_encoded, bias=target_autoregressive_bias) + if self.dropout > 0.0: + contexts = mx.sym.Dropout(contexts, p=self.dropout) + return contexts + + def decode_step(self, + step: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + states: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + target = mx.sym.expand_dims(target, axis=1) + + if step > 1: + prev_keys, prev_values = states + cache = {'k': prev_keys, 'v': prev_values} + else: + cache = {'k': None, 'v': None} + + context, probs = self.att(target, cache=cache) + + new_states = [cache['k'], cache['v']] # type: Sequence[mx.sym.Symbol] + + # Fill the attention dictionary + # probs has shape (batch, heads, 1, target_length) + for head_idx, head_prob in enumerate(mx.sym.split(probs, axis=1, squeeze_axis=True, + num_outputs=self.att.heads), 1): + name = self.prefix + ("h%d" % head_idx) + att_dict["self"][name] = mx.sym.identity(head_prob, name=name) + + # context: (batch_size, 1, dv) -> (batch_size, num_hidden) + context = mx.sym.reshape(context, shape=(0, -1)) + + # note: self.att has updated the cache + + return context, new_states + + def self_att_names(self): + return [self.prefix + ("h%d" % i) for i in range(1, self.att.heads+1)] + + def num_states(self, step): + if step == 1: + return 0 + else: + return 2 + + def state_variables(self, step: int): + if step == 1: + return [] + else: + return [mx.sym.Variable("%sself_att_state0" % self.prefix), + mx.sym.Variable("%sself_att_state1" % self.prefix)] + + def state_shapes(self, + batch_size: int, + step: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int): + if step == 1: + return [] + else: + return [mx.io.DataDesc(name="%sself_att_state0" % self.prefix, + shape=(batch_size, + (step - 1), + self.att_num_hidden), + layout=C.BATCH_MAJOR), + mx.io.DataDesc(name="%sself_att_state1" % self.prefix, + shape=(batch_size, + (step - 1), + self.att_num_hidden), + layout=C.BATCH_MAJOR)] + + def init_states(self, + batch_size: int, + source_encoded: Sequence[mx.sym.Symbol], + 
source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int): + return [] + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class MultiHeadSelfAttentionLayerConfig(LayerConfig): + + def __init__(self, + heads: int = 8, + dropout: float = 0.0, + dropout_attention: Optional[float] = None, + num_hidden: int = None, + att_num_hidden: Optional[int] = None) -> None: + super().__init__() + assert num_hidden is not None, "num_hidden required" + self.num_hidden = num_hidden + self.att_num_hidden = att_num_hidden if att_num_hidden is not None else num_hidden + self.heads = heads + self.dropout = dropout + self.dropout_attention = dropout_attention if dropout_attention is not None else dropout + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + # TODO: can we simplify this? (e.g. by using all attributes of the config object) + return MultiHeadSelfAttentionEncoderLayer(num_hidden=self.num_hidden, + att_num_hidden=self.att_num_hidden, + dropout=self.dropout, + dropout_attention=self.dropout_attention, + heads=self.heads, + prefix=prefix + "mh_self_att_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return MultiHeadSelfAttentionDecoderLayer(num_hidden=self.num_hidden, + att_num_hidden=self.att_num_hidden, + dropout=self.dropout, + dropout_attention=self.dropout_attention, + heads=self.heads, + prefix=prefix + "mh_self_att_") + + +# TODO: make sure the number of hidden units does not change! +class ResidualEncoderLayer(EncoderLayer): + def __init__(self, layers: List[EncoderLayer]) -> None: + self.layers = layers + + def encode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + att_dict: Optional[Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + new_source_encoded = source_encoded + for layer in self.layers: + new_source_encoded, new_source_encoded_lengths, new_source_encoded_max_length = layer.encode_sequence( + new_source_encoded, + source_encoded_lengths, + source_encoded_max_length, + att_dict) + assert source_encoded_max_length == new_source_encoded_max_length, C.SEQUENCE_LENGTH_MUST_NOT_CHANGE_MSG + assert source_encoded_lengths is new_source_encoded_lengths, C.SEQUENCE_LENGTH_MUST_NOT_CHANGE_MSG + + return source_encoded + new_source_encoded, source_encoded_lengths, source_encoded_max_length + + def att_names(self): + att_names = [] + for layer in self.layers: + att_names.extend(layer.att_names()) + return att_names + + def get_num_hidden(self) -> int: + return self.layers[-1].get_num_hidden() + + +# TODO: potentially add a projection layer (for when the shapes don't match up). 
Alternative: check that the input num hidden matches the output num_hidden (maybe add a get_input_num_hidden()) +# TODO: consider inheriting from both NestedDecoderLayer and SharedEncoderDecoderLayer to just have a single implementation +class ResidualDecoderLayer(NestedDecoderLayer): + + def __init__(self, layers: List[DecoderLayer]) -> None: + self._layers = layers + + @property + def layers(self) -> List[DecoderLayer]: + return self._layers + + def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + target_encoded_input = target_encoded + for layer in self.layers: + target_encoded = layer.decode_sequence(source_encoded, source_encoded_lengths, source_encoded_max_length, + target_encoded, target_encoded_lengths, target_encoded_max_length, + target_autoregressive_bias) + + return target_encoded_input + target_encoded + + def decode_step(self, step: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, layer_states_flat: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + + new_layer_states_flat = [] # type: List[mx.sym.Symbol] + target_input = target + + for layer, layer_states in self.layers_with_states_iter(step, layer_states_flat): + target, new_layer_states = layer.decode_step(step, source_encoded, source_encoded_lengths, + source_encoded_max_length, target, layer_states, att_dict) + new_layer_states_flat.extend(new_layer_states) + return target_input + target, new_layer_states_flat + + def get_num_hidden(self) -> int: + return self._layers[-1].get_num_hidden() + + +class ResidualLayerConfig(LayerConfig): + def __init__(self, layer_configs: List[LayerConfig]) -> None: + super().__init__() + self.layer_configs = layer_configs + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + layers = [] + original_input_num_hidden = input_num_hidden + for idx, layer in enumerate(self.layer_configs): + new_layer = layer.create_encoder_layer(input_num_hidden, "%sres%d_" % (prefix, idx)) + input_num_hidden = new_layer.get_num_hidden() + layers.append(new_layer) + utils.check_condition(original_input_num_hidden == input_num_hidden, + "The input and output number of hidden units of the residual connection must match (%d vs %d)" % ( + original_input_num_hidden, input_num_hidden)) + return ResidualEncoderLayer(layers) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + layers = [] + original_input_num_hidden = input_num_hidden + for idx, layer in enumerate(self.layer_configs): + new_layer = layer.create_decoder_layer(input_num_hidden, "%sres%d_" % (prefix, idx)) + input_num_hidden = new_layer.get_num_hidden() + layers.append(new_layer) + utils.check_condition(original_input_num_hidden == input_num_hidden, + "The input and output number of hidden units of the residual connection must match (%d vs %d)" % ( + original_input_num_hidden, input_num_hidden)) + return ResidualDecoderLayer(layers) + + +# TODO: make this a block!? 
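For orientation, ResidualLayerConfig is the piece that turns a sub-sequence such as res(norm->mh_dot_self_att) into a residual block: it builds each sub-layer config in order, checks that the hidden size is preserved, and the resulting layer adds the block input to the block output. A minimal composition sketch (hypothetical usage, not part of this change set; it assumes these configs are instantiated directly from sockeye.layers):

    from sockeye.layers import (ResidualLayerConfig, LayerNormalizationLayerConfig,
                                MultiHeadSelfAttentionLayerConfig)

    # Mirrors the 'res(norm->mh_dot_self_att)' block from the parser examples below:
    # layer normalization, then multi-head self-attention, wrapped in a residual connection.
    block = ResidualLayerConfig(layer_configs=[
        LayerNormalizationLayerConfig(),
        MultiHeadSelfAttentionLayerConfig(num_hidden=512, heads=8, dropout=0.1),
    ])
    encoder_layer = block.create_encoder_layer(input_num_hidden=512, prefix="encoder_custom_seq_")
    assert encoder_layer.get_num_hidden() == 512  # residual blocks must preserve num_hidden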
+class HighwayLayer: + + def __init__(self, num_hidden: int, gate_input: str, gated: str, prefix: str) -> None: + self.gate_input = gate_input + self.gated = gated + self.ff = FeedForwardBlock(num_hidden=num_hidden, + dropout=0.0, + act_type=C.SIGMOID, + prefix=prefix) + + def highway(self, input: mx.sym.Symbol, gate_input: mx.sym.Symbol, input_lengths: mx.sym.Symbol, + input_max_length: int, output: mx.sym.Symbol): + if self.gate_input == 'input': + gate = self.ff(gate_input, input_lengths, input_max_length) + elif self.gate_input == 'output': + gate = self.ff(output, input_lengths, input_max_length) + elif self.gate_input == 'both': + gate = self.ff(mx.sym.concat(gate_input, output, dim=2), input_lengths, input_max_length) + else: + raise ValueError("unknown gate input %s" % self.gate_input) + if self.gated == "both": + return gate * input + (1. - gate) * output + elif self.gated == "output": + return input + gate * output + else: + raise ValueError("unknown gate method %s" % self.gated) + + def highway_step(self, step: int, input: mx.sym.Symbol, gate_input: mx.sym.Symbol, output: mx.sym.Symbol): + if self.gate_input == 'input': + gate = self.ff.step(step, gate_input) + elif self.gate_input == 'output': + gate = self.ff.step(step, output) + elif self.gate_input == 'both': + gate = self.ff.step(step, mx.sym.concat(gate_input, output, dim=1)) + else: + raise ValueError("unknown gate input %s" % self.gate_input) + if self.gated == "both": + return gate * input + (1. - gate) * output + elif self.gated == "output": + return input + gate * output + else: + raise ValueError("unknown gate method %s" % self.gated) + + +class HighwayEncoderLayer(EncoderLayer, HighwayLayer): + + def __init__(self, + layers: List[EncoderLayer], + gate_input: str, + gated: str, + prefix: str = "") -> None: + super().__init__(num_hidden=layers[-1].get_num_hidden(), gate_input=gate_input, gated=gated, prefix=prefix) + self._layers = layers + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + # TODO: make sure input num hidden equals output num hidden + highway_input = source_encoded + gate_input = source_encoded + + new_source_encoded = source_encoded + for layer in self._layers: + new_source_encoded = layer.encode_sequence(new_source_encoded, source_encoded_lengths, + source_encoded_max_length, att_dict)[0] + + return self.highway(highway_input, + gate_input, + source_encoded_lengths, + source_encoded_max_length, + new_source_encoded), source_encoded_lengths, source_encoded_max_length + + def att_names(self): + att_names = [] + for layer in self._layers: + att_names.extend(layer.att_names()) + return att_names + + def get_num_hidden(self) -> int: + return self._layers[-1].get_num_hidden() + + +class HighwayDecoderLayer(NestedDecoderLayer, HighwayLayer): + + def __init__(self, + layers: List[DecoderLayer], + gate_input: str, + gated: str, + prefix: str = "") -> None: + # TODO: make sure input num hidden equals output num hidden + super().__init__(num_hidden=layers[-1].get_num_hidden(), gate_input=gate_input, gated=gated, prefix=prefix) + self._layers = layers + num_hidden = layers[-1].get_num_hidden() + + @property + def layers(self) -> List[DecoderLayer]: + return self._layers + + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: 
mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + highway_input = target_encoded + gate_input = target_encoded + + target_encoded_input = target_encoded + for layer in self.layers: + target_encoded = layer.decode_sequence(source_encoded, source_encoded_lengths, source_encoded_max_length, + target_encoded, target_encoded_lengths, target_encoded_max_length, + target_autoregressive_bias) + + return self.highway(highway_input, + gate_input, + target_encoded_lengths, + target_encoded_max_length, + target_encoded) + + def decode_step(self, step: int, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, layer_states_flat: Sequence[mx.sym.Symbol], + att_dict: Dict[str, Dict[str, mx.sym.Symbol]]) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + highway_input = target + gate_input = target + + new_layer_states_flat = [] # type: List[mx.sym.Symbol] + + for layer, layer_states in self.layers_with_states_iter(step, layer_states_flat): + target, new_layer_states = layer.decode_step(step, source_encoded, source_encoded_lengths, + source_encoded_max_length, target, layer_states, att_dict) + new_layer_states_flat.extend(new_layer_states) + + return self.highway_step(step, highway_input, gate_input, target), new_layer_states_flat + + def get_num_hidden(self) -> int: + return self._layers[-1].get_num_hidden() + + +class HighwayLayerConfig(LayerConfig): + + def __init__(self, layer_configs: List[LayerConfig], gate_input="input", gated: str = "both") -> None: + super().__init__() + self.layer_configs = layer_configs + self.gate_input = gate_input + self.gated = gated + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + original_input_num_hidden = input_num_hidden + layers = [] + for idx, layer_config in enumerate(self.layer_configs): + layer = layer_config.create_encoder_layer(input_num_hidden, "%shighway%d_" % (prefix, idx)) + input_num_hidden = layer.get_num_hidden() + layers.append(layer) + assert original_input_num_hidden == layers[-1].get_num_hidden(), "The number of hidden units of the output of the highway must be equal to its input." + return HighwayEncoderLayer(layers=layers, gate_input=self.gate_input, + gated=self.gated, + prefix=prefix + "highway_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + original_input_num_hidden = input_num_hidden + layers = [] + for idx, layer_config in enumerate(self.layer_configs): + layer = layer_config.create_decoder_layer(input_num_hidden, "%s_highway%d" % (prefix, idx)) + input_num_hidden = layer.get_num_hidden() + layers.append(layer) + assert original_input_num_hidden == layers[-1].get_num_hidden(), "The number of hidden units of the output of the highway must be equal to its input." + return HighwayDecoderLayer(layers=layers, gate_input=self.gate_input, + gated=self.gated, + prefix=prefix + "highway_") + + +# TODO: implement this as a stateless layer instead?! 
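The highway wiring above is easiest to read as two small formulas. With g = sigmoid(FF(gate_input)), gated='both' mixes the block input and the sub-layer output, while gated='output' keeps a plain residual path and only gates the sub-layer output. A plain-Python sketch of just that arithmetic (illustrative only, with the gate assumed to be precomputed by the FeedForwardBlock):

    def highway_both(inp, out, gate):
        # gated == "both": convex combination of block input and sub-layer output
        return gate * inp + (1.0 - gate) * out

    def highway_output(inp, out, gate):
        # gated == "output": residual connection with a gated sub-layer output
        return inp + gate * out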
+class IdentityEncoderLayer(EncoderLayer): + def __init__(self, num_hidden): + self.num_hidden = num_hidden + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + return source_encoded, source_encoded_lengths, source_encoded_max_length + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class IdentityDecoderLayer(DecoderLayer): + def __init__(self, num_hidden): + self.num_hidden = num_hidden + + def decode_sequence(self, source_encoded: Sequence[mx.sym.Symbol], source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + return target_encoded + + def decode_step(self, step: int, source_encoded: Sequence[mx.sym.Symbol], source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, target: mx.sym.Symbol, states: Sequence[mx.sym.Symbol], att_dict) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + return target, [] + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class IdentityLayerConfig(LayerConfig): + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return IdentityEncoderLayer(num_hidden=input_num_hidden) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return IdentityDecoderLayer(num_hidden=input_num_hidden) + + +class LayerNormalization(StatelessBlock): """ Implements Ba et al, Layer Normalization (https://arxiv.org/abs/1607.06450). @@ -61,18 +1580,24 @@ class LayerNormalization: :param shift_init: Initial value of shift variable if shift is None. Default 0.0. 
""" def __init__(self, + num_hidden: int, prefix: str = 'layernorm', scale: Optional[mx.sym.Symbol] = None, shift: Optional[mx.sym.Symbol] = None, scale_init: float = 1.0, shift_init: float = 0.0) -> None: self.prefix = prefix + self.num_hidden = num_hidden self.scale = scale if scale is not None else mx.sym.Variable('%s_gamma' % prefix, init=mx.init.Constant(value=scale_init)) self.shift = shift if shift is not None else mx.sym.Variable('%s_beta' % prefix, init=mx.init.Constant(value=shift_init)) - def __call__(self, data: mx.sym.Symbol, eps: float = 1e-06) -> mx.sym.Symbol: + def __call__(self, + data: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + max_length: Optional[int] = None, + eps: float = 1e-06) -> mx.sym.Symbol: """ Normalizes hidden units of data as follows: @@ -87,6 +1612,24 @@ def __call__(self, data: mx.sym.Symbol, eps: float = 1e-06) -> mx.sym.Symbol: return mx.sym.LayerNorm(data=data, gamma=self.scale, beta=self.shift, axis=-1, eps=eps, output_mean_var=False, name=self.prefix) + def get_num_hidden(self): + return self.num_hidden + + +class LayerNormalizationLayerConfig(LayerConfig): + def __init__(self): + super().__init__() + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> EncoderLayer: + return self.create_layer(input_num_hidden, prefix) + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> DecoderLayer: + return self.create_layer(input_num_hidden, prefix) + + def create_layer(self, num_hidden: int, prefix) -> StatelessBlockLayer: + return StatelessBlockLayer(LayerNormalization(num_hidden=num_hidden, + prefix=prefix + "norm_")) + class LHUC: """ @@ -303,9 +1846,6 @@ def dot_attention(queries: mx.sym.Symbol, :param prefix: Optional prefix :return: 'Context' vectors for each query. Shape: (n, lq, dv). """ - utils.check_condition(lengths is not None or bias is not None, - "Must provide either length or bias argument for masking") - # (n, lq, lk) logits = mx.sym.batch_dot(lhs=queries, rhs=keys, transpose_b=True, name='%sdot' % prefix) @@ -327,7 +1867,7 @@ def dot_attention(queries: mx.sym.Symbol, probs = mx.sym.Dropout(probs, p=dropout) if dropout > 0.0 else probs # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv) - return mx.sym.batch_dot(lhs=probs, rhs=values, name='%scontexts' % prefix) + return mx.sym.batch_dot(lhs=probs, rhs=values, name='%scontexts' % prefix), probs class MultiHeadAttentionBase: @@ -362,7 +1902,7 @@ def _attend(self, keys: mx.sym.Symbol, values: mx.sym.Symbol, lengths: Optional[mx.sym.Symbol] = None, - bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + bias: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]: """ Returns context vectors of multi-head dot attention. @@ -371,7 +1911,8 @@ def _attend(self, :param values: Values. Shape: (batch_size, memory_max_length, depth). :param lengths: Optional lengths of keys. Shape: (batch_size,). :param bias: Optional 3d bias. - :return: Context vectors. Shape: (batch_size, query_max_length, output_depth). + :return: Context vectors. Shape: (batch_size, query_max_length, output_depth) and attention probabilities. + Shape: (batch_size, query_max_length, memory_max_length). 
""" # scale by sqrt(depth_per_head) queries = queries * (self.depth_per_head ** -0.5) @@ -383,8 +1924,8 @@ def _attend(self, lengths = broadcast_to_heads(lengths, self.heads, ndim=1, fold_heads=True) if lengths is not None else lengths # (batch*heads, query_max_length, depth_per_head) - contexts = dot_attention(queries, keys, values, - lengths=lengths, dropout=self.dropout, bias=bias, prefix=self.prefix) + contexts, probs = dot_attention(queries, keys, values, + lengths=lengths, dropout=self.dropout, bias=bias, prefix=self.prefix) # (batch, query_max_length, depth) contexts = combine_heads(contexts, self.depth_per_head, self.heads) @@ -396,7 +1937,7 @@ def _attend(self, num_hidden=self.depth_out, flatten=False) - return contexts + return contexts, probs class MultiHeadSelfAttention(MultiHeadAttentionBase): @@ -423,7 +1964,7 @@ def __call__(self, inputs: mx.sym.Symbol, input_lengths: Optional[mx.sym.Symbol] = None, bias: Optional[mx.sym.Symbol] = None, - cache: Optional[Dict[str, Optional[mx.sym.Symbol]]] = None) -> mx.sym.Symbol: + cache: Optional[Dict[str, Optional[mx.sym.Symbol]]] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]: """ Computes multi-head attention on a set of inputs, serving as queries, keys, and values. If sequence lengths are provided, they will be used to mask the attention scores. @@ -487,7 +2028,7 @@ def __call__(self, queries: mx.sym.Symbol, memory: mx.sym.Symbol, memory_lengths: Optional[mx.sym.Symbol] = None, - bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + bias: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]: """ Computes multi-head attention for queries given a memory tensor. If sequence lengths are provided, they will be used to mask the attention scores. @@ -584,7 +2125,7 @@ def __call__(self, queries = queries * (self.num_hidden ** -0.5) # (batch, queries_max_length, num_hidden) - contexts = dot_attention(queries, keys, values, memory_lengths) + contexts, probs = dot_attention(queries, keys, values, memory_lengths) return contexts @@ -608,7 +2149,7 @@ def __call__(self, """ # (batch*heads, queries_max_length, depth_per_head) - contexts = dot_attention(queries, memory, memory, memory_lengths) + contexts, probs = dot_attention(queries, memory, memory, memory_lengths) return contexts diff --git a/sockeye/rnn.py b/sockeye/rnn.py index 9837ed48b..a4112bcf7 100644 --- a/sockeye/rnn.py +++ b/sockeye/rnn.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. A copy of the License @@ -12,12 +12,13 @@ # permissions and limitations under the License. # List is needed for mypy, but not used in the code, only in special comments -from typing import Optional, List, Iterable # NOQA pylint: disable=unused-import +from typing import Optional, List, Iterable, Tuple, Sequence # NOQA pylint: disable=unused-import import mxnet as mx +import numpy as np from sockeye.config import Config -from sockeye.layers import LayerNormalization, LHUC +from . import layers from . import constants as C from . 
import utils @@ -114,7 +115,65 @@ def __call__(self, inputs, parallel_inputs, states): return output, states -def get_stacked_rnn(config: RNNConfig, prefix: str, +def get_rnn_cell( + cell_type: str, + num_hidden: int, + dropout_inputs: float, + dropout_states: float, + dropout_recurrent: float = 0, + forget_bias: float = 0.0, + lhuc: bool = False, + dtype: str = C.DTYPE_FP32, + prefix: str = ''): + """ + Create a single rnn cell. + :param cell_type: RNN cell type. + :param num_hidden: Number of RNN hidden units. + :param dropout_inputs: Dropout probability on RNN inputs (Gal, 2015). + :param dropout_states: Dropout probability on RNN states (Gal, 2015). + :param dropout_recurrent: Dropout probability on cell update (Semeniuta, 2016). + :param forget_bias: Initial value of forget biases. + :param lhuc: Apply LHUC (Vilar 2018) to the hidden units of the RNN. + :param dtype: Data type. + :param prefix: Variable name prefix. + """ + if cell_type == C.LSTM_TYPE: + if dropout_recurrent > 0.0: + cell = RecurrentDropoutLSTMCell(num_hidden=num_hidden, + prefix=prefix, + forget_bias=forget_bias, + dropout=dropout_recurrent) + else: + cell = mx.rnn.LSTMCell(num_hidden=num_hidden, prefix=prefix, forget_bias=forget_bias) + elif cell_type == C.LNLSTM_TYPE: + cell = LayerNormLSTMCell(num_hidden=num_hidden, prefix=prefix, forget_bias=forget_bias) + elif cell_type == C.LNGLSTM_TYPE: + cell = LayerNormPerGateLSTMCell(num_hidden=num_hidden, prefix=prefix, + forget_bias=forget_bias) + elif cell_type == C.GRU_TYPE: + cell = mx.rnn.GRUCell(num_hidden=num_hidden, prefix=prefix) + elif cell_type == C.LNGRU_TYPE: + cell = LayerNormGRUCell(num_hidden=num_hidden, prefix=prefix) + elif cell_type == C.LNGGRU_TYPE: + cell = LayerNormPerGateGRUCell(num_hidden=num_hidden, prefix=prefix) + elif cell_type == "simple": + cell = JanetCell(num_hidden=num_hidden, prefix=prefix, + forget_bias=forget_bias) + else: + raise NotImplementedError("Unknown cell type %s" % cell_type) + + if dropout_inputs > 0 or dropout_states > 0: + cell = VariationalDropoutCell(cell, + dropout_inputs=dropout_inputs, + dropout_states=dropout_states) + if lhuc: + cell = LHUCCell(cell, num_hidden, dtype) + + return cell + + +def get_stacked_rnn(config: RNNConfig, + prefix: str, parallel_inputs: bool = False, layers: Optional[Iterable[int]] = None) -> mx.rnn.SequentialRNNCell: """ @@ -135,33 +194,10 @@ def get_stacked_rnn(config: RNNConfig, prefix: str, # fhieber: the 'l' in the prefix does NOT stand for 'layer' but for the direction 'l' as in mx.rnn.rnn_cell::517 # this ensures parameter name compatibility of training w/ FusedRNN and decoding with 'unfused' RNN. 
cell_prefix = "%sl%d_" % (prefix, layer_idx) - if config.cell_type == C.LSTM_TYPE: - if config.dropout_recurrent > 0.0: - cell = RecurrentDropoutLSTMCell(num_hidden=config.num_hidden, prefix=cell_prefix, - forget_bias=config.forget_bias, dropout=config.dropout_recurrent) - else: - cell = mx.rnn.LSTMCell(num_hidden=config.num_hidden, prefix=cell_prefix, forget_bias=config.forget_bias) - elif config.cell_type == C.LNLSTM_TYPE: - cell = LayerNormLSTMCell(num_hidden=config.num_hidden, prefix=cell_prefix, forget_bias=config.forget_bias) - elif config.cell_type == C.LNGLSTM_TYPE: - cell = LayerNormPerGateLSTMCell(num_hidden=config.num_hidden, prefix=cell_prefix, - forget_bias=config.forget_bias) - elif config.cell_type == C.GRU_TYPE: - cell = mx.rnn.GRUCell(num_hidden=config.num_hidden, prefix=cell_prefix) - elif config.cell_type == C.LNGRU_TYPE: - cell = LayerNormGRUCell(num_hidden=config.num_hidden, prefix=cell_prefix) - elif config.cell_type == C.LNGGRU_TYPE: - cell = LayerNormPerGateGRUCell(num_hidden=config.num_hidden, prefix=cell_prefix) - else: - raise NotImplementedError() - - if config.dropout_inputs > 0 or config.dropout_states > 0: - cell = VariationalDropoutCell(cell, - dropout_inputs=config.dropout_inputs, - dropout_states=config.dropout_states) - - if config.lhuc: - cell = LHUCCell(cell, config.num_hidden, config.dtype) + cell = get_rnn_cell(cell_type=config.cell_type, num_hidden=config.num_hidden, + dropout_inputs=config.dropout_inputs, dropout_states=config.dropout_states, + dropout_recurrent=config.dropout_recurrent, forget_bias=config.forget_bias, + lhuc=config.lhuc, dtype=config.dtype, prefix=cell_prefix) # layer_idx is 0 based, whereas first_residual_layer is 1-based if config.residual and layer_idx + 1 >= config.first_residual_layer: @@ -195,15 +231,15 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias) - self._iN = LayerNormalization(prefix="%si2h" % self._prefix, - scale=self.params.get('i2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('i2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) - self._hN = LayerNormalization(prefix="%sh2h" % self._prefix, - scale=self.params.get('h2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('h2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) - self._cN = LayerNormalization(prefix="%sc" % self._prefix, - scale=self.params.get('c_scale', shape=(num_hidden,), init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('c_shift', shape=(num_hidden,), init=mx.init.Constant(value=norm_shift))) + self._iN = layers.LayerNormalization(num_hidden=num_hidden, prefix="%si2h" % self._prefix, + scale=self.params.get('i2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('i2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) + self._hN = layers.LayerNormalization(num_hidden=num_hidden, prefix="%sh2h" % self._prefix, + scale=self.params.get('h2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('h2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) + self._cN = layers.LayerNormalization(num_hidden=num_hidden, prefix="%sc" % self._prefix, + scale=self.params.get('c_scale', shape=(num_hidden,), init=mx.init.Constant(value=norm_scale)), + 
shift=self.params.get('c_shift', shape=(num_hidden,), init=mx.init.Constant(value=norm_shift))) def __call__(self, inputs, states): self._counter += 1 @@ -257,14 +293,15 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormPerGateLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias) - self._norm_layers = list() # type: List[LayerNormalization] + self._norm_layers = list() # type: List[layers.LayerNormalization] for name in ['i', 'f', 'c', 'o', 's']: scale = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift)) shift = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale if name != "f" else forget_bias)) self._norm_layers.append( - LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) + layers.LayerNormalization(prefix="%s%s" % (self._prefix, name), num_hidden=num_hidden, + scale=scale, shift=shift)) def __call__(self, inputs, states): self._counter += 1 @@ -309,7 +346,7 @@ def __init__(self, base_cell, num_hidden, dtype) -> None: super().__init__(base_cell) self.num_hidden = num_hidden self.lhuc_params = self.params.get(C.LHUC_NAME, shape=(num_hidden,), dtype=dtype, init=mx.init.Uniform(0.1)) - self.lhuc = LHUC(num_hidden, self.lhuc_params) + self.lhuc = layers.LHUC(num_hidden, self.lhuc_params) def __call__(self, inputs, states): output, states = self.base_cell(inputs, states) @@ -377,12 +414,16 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormGRUCell, self).__init__(num_hidden, prefix, params) - self._iN = LayerNormalization(prefix="%si2h" % self._prefix, - scale=self.params.get('i2h_scale', init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('i2h_shift', init=mx.init.Constant(value=norm_shift))) - self._hN = LayerNormalization(prefix="%sh2h" % self._prefix, - scale=self.params.get('h2h_scale', init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('h2h_shift', init=mx.init.Constant(value=norm_shift))) + self._iN = layers.LayerNormalization( + prefix="%si2h" % self._prefix, + num_hidden=num_hidden, + scale=self.params.get('i2h_scale', init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('i2h_shift', init=mx.init.Constant(value=norm_shift))) + self._hN = layers.LayerNormalization( + prefix="%sh2h" % self._prefix, + num_hidden=num_hidden, + scale=self.params.get('h2h_scale', init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('h2h_shift', init=mx.init.Constant(value=norm_shift))) def __call__(self, inputs, states): self._counter += 1 @@ -442,11 +483,12 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormPerGateGRUCell, self).__init__(num_hidden, prefix, params) - self._norm_layers = list() # type: List[LayerNormalization] + self._norm_layers = list() # type: List[layers.LayerNormalization] for name in ['r', 'z', 'o']: scale = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift)) shift = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale)) - self._norm_layers.append(LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) + self._norm_layers.append(layers.LayerNormalization( + prefix="%s%s" % (self._prefix, name), num_hidden=num_hidden, scale=scale, shift=shift)) def __call__(self, inputs, states): self._counter += 1 @@ -522,3 +564,298 @@ def reset(self): super(VariationalDropoutCell, self).reset() self.mask_inputs = None self.mask_states = None + + 
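With the refactoring above, get_rnn_cell() is the single place where cell construction, variational dropout and LHUC wrapping happen, shared by get_stacked_rnn() and the new recurrent custom-seq layers below. A hypothetical call for a layer-normalized LSTM with input/state dropout (the argument values are illustrative, not defaults):

    from sockeye import constants as C
    from sockeye.rnn import get_rnn_cell

    cell = get_rnn_cell(cell_type=C.LNLSTM_TYPE,
                        num_hidden=512,
                        dropout_inputs=0.1,
                        dropout_states=0.1,
                        dropout_recurrent=0.0,
                        forget_bias=1.0,
                        lhuc=False,
                        prefix="encoder_rnn_l0_")
    # The returned cell is already wrapped in VariationalDropoutCell (and in LHUCCell if lhuc=True).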
+class JanetCell(mx.rnn.BaseRNNCell): + """Janet cell, as described in: + https://arxiv.org/pdf/1804.04849.pdf + + Parameters + ---------- + num_hidden : int + Number of units in output symbol. + prefix : str, default 'lstm_' + Prefix for name of layers (and name of weight if params is None). + params : RNNParams, default None + Container for weight sharing between cells. Created if None. + forget_bias : bias added to forget gate, default 1.0. + Jozefowicz et al. 2015 recommends setting this to 1.0 + """ + def __init__(self, num_hidden, prefix='lstm_', params=None, forget_bias=1.0): + super().__init__(prefix=prefix, params=params) + + self._num_hidden = num_hidden + self._iW = self.params.get('i2h_weight') + self._hW = self.params.get('h2h_weight') + # we add the forget_bias to i2h_bias, this adds the bias to the forget gate activation + self._iB = self.params.get('i2h_bias', init=mx.init.LSTMBias(forget_bias=forget_bias)) + self._hB = self.params.get('h2h_bias') + + @property + def state_info(self): + return [{'shape': (0, self._num_hidden), '__layout__': 'NC'}] + + @property + def _gate_names(self): + return ['_f', '_o'] + + def __call__(self, inputs, states): + self._counter += 1 + name = '%st%d_'%(self._prefix, self._counter) + i2h = mx.sym.FullyConnected(data=inputs, weight=self._iW, bias=self._iB, + num_hidden=self._num_hidden*2, + name='%si2h'%name) + h2h = mx.sym.FullyConnected(data=states[0], weight=self._hW, bias=self._hB, + num_hidden=self._num_hidden*2, + name='%sh2h'%name) + gates = i2h + h2h + slice_gates = mx.sym.SliceChannel(gates, num_outputs=2, + name="%sslice"%name) + forget_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid", + name='%sf'%name) + in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") + + next_h = mx.sym._internal._plus(forget_gate * states[0], (1. - forget_gate) * in_transform, + name='%sstate' % name) + + return next_h, [next_h] + + +class RecurrentLayerRNNConfig(Config): + """ + :param num_hidden: Number of RNN hidden units. + :param dropout_recurrent: Dropout probability on cell update (Semeniuta, 2016). + :param dropout_inputs: Dropout probability on RNN inputs (Gal, 2015). + :param dropout_states: Dropout probability on RNN states (Gal, 2015). + :param cell_type: RNN cell type. + :param forget_bias: Initial value of forget biases. + :param lhuc: Apply LHUC (Vilar 2018) to the hidden units of the RNN. + :param dtype: Data type. 
+ """ + + def __init__(self, + num_hidden: int, + dropout_recurrent: float = 0.0, + dropout_inputs: float = 0.0, + dropout_states: float = 0.0, + norm_states: bool = True, + norm_first_step: bool = True, + cell_type: str = C.LSTM_TYPE, + forget_bias: float = 0.0, + lhuc: bool = False, + dtype: str = C.DTYPE_FP32) -> None: + super().__init__() + self.num_hidden = num_hidden + # recurrent/inputs/states is for "old" cells and just "dropout" for "new states" + self.dropout_recurrent = dropout_recurrent + self.dropout_inputs = dropout_inputs + self.dropout_states = dropout_states + self.norm_states = norm_states + self.norm_first_step = norm_first_step + self.cell_type = cell_type + self.forget_bias = forget_bias + self.lhuc = lhuc + self.dtype = dtype + + def create_rnn_cell(self, prefix: str): + cell = get_rnn_cell(cell_type=self.cell_type, num_hidden=self.num_hidden, + dropout_inputs=self.dropout_inputs, dropout_states=self.dropout_states, + dropout_recurrent=self.dropout_recurrent, forget_bias=self.forget_bias, + lhuc=self.lhuc, dtype=self.dtype, prefix=prefix) + # Note: get_rnn_cell() already wraps the cell in VariationalDropoutCell (and LHUCCell) when requested, + # so no additional wrapping is needed here. + return cell + + +class RecurrentEncoderLayer(layers.EncoderLayer): + + def __init__(self, + rnn_config: RecurrentLayerRNNConfig, + prefix: str = "") -> None: + self.rnn_cell = rnn_config.create_rnn_cell(prefix) + self.num_hidden = rnn_config.num_hidden + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + outputs, _ = self.rnn_cell.unroll(length=source_encoded_max_length, + inputs=source_encoded, + merge_outputs=True, + layout=C.BATCH_MAJOR) + return outputs, source_encoded_lengths, source_encoded_max_length + + def get_num_hidden(self) -> int: + return self.num_hidden + + +class RecurrentDecoderLayer(layers.DecoderLayer): + + def __init__(self, + rnn_config: RecurrentLayerRNNConfig, + prefix: str = "") -> None: + self.prefix = prefix + self.rnn_cell = rnn_config.create_rnn_cell(prefix) + self.num_hidden = rnn_config.num_hidden + + def decode_sequence(self, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target_encoded: mx.sym.Symbol, + target_encoded_lengths: mx.sym.Symbol, + target_encoded_max_length: int, + target_autoregressive_bias: mx.sym.Symbol) -> mx.sym.Symbol: + outputs, _ = self.rnn_cell.unroll(length=target_encoded_max_length, + inputs=target_encoded, + merge_outputs=True, + layout=C.BATCH_MAJOR) + + return outputs + + def decode_step(self, step: int, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, + target: mx.sym.Symbol, + states: Sequence[mx.sym.Symbol], + att_dict) -> Tuple[mx.sym.Symbol, Sequence[mx.sym.Symbol]]: + return self.rnn_cell(target, states) + + def reset(self): + # TODO remove this once mxnet.rnn.ModifierCell.reset() invokes reset() of base_cell + cell = self.rnn_cell + if isinstance(cell, mx.rnn.ModifierCell): + cell.base_cell.reset() + cell.reset() + + def get_num_hidden(self) -> int: + return self.num_hidden + + def num_states(self, step: int) -> int: + return len(self.rnn_cell.state_info) + + def state_variables(self, step: int) -> Sequence[mx.sym.Symbol]: + return [mx.sym.Variable("%srnn_state_%d" % (self.prefix, i)) + for i, state_info in enumerate(self.rnn_cell.state_info)]
+ + def init_states(self, + batch_size, + source_encoded: mx.sym.Symbol, + source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int) -> Sequence[mx.sym.Symbol]: + return [mx.sym.zeros(shape=(batch_size, num_hidden)) for (_, num_hidden) in self.rnn_cell.state_shape] + + def state_shapes(self, + batch_size: int, + target_max_length: int, + source_encoded_max_length: int, + source_encoded_num_hidden: int) -> List[mx.io.DataDesc]: + return [mx.io.DataDesc("%srnn_state_%d" % (self.prefix, i), + (batch_size, num_hidden), + layout=C.BATCH_MAJOR) for i, (_, num_hidden) in enumerate(self.rnn_cell.state_shape)] + + +class RecurrentLayerConfig(layers.LayerConfig): + def __init__(self, + num_hidden: int, + cell_type: str = C.LSTM_TYPE, + dropout_inputs: float = 0.0, + dropout_states: float = 0.0, + dropout_recurrent: float = 0.0, + norm_states: bool = True, + norm_first_step: bool = True, + forget_bias: float = 0.0) -> None: + super().__init__() + self.rnn_config = RecurrentLayerRNNConfig(num_hidden=num_hidden, + dropout_recurrent=dropout_recurrent, + dropout_inputs=dropout_inputs, + dropout_states=dropout_states, + norm_states=norm_states, + norm_first_step=norm_first_step, + cell_type=cell_type, + forget_bias=forget_bias) + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> layers.EncoderLayer: + return RecurrentEncoderLayer(rnn_config=self.rnn_config, prefix=prefix + "rnn_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> layers.DecoderLayer: + return RecurrentDecoderLayer(rnn_config=self.rnn_config, prefix=prefix + "rnn_") + + +class BidirectionalRecurrentEncoderLayer(layers.EncoderLayer): + def __init__(self, + rnn_config: RecurrentLayerRNNConfig, + prefix: str = "") -> None: + self.prefix = prefix + utils.check_condition(rnn_config.num_hidden % 2 == 0, + "num_hidden must be a multiple of 2 for BiDirectionalRNNEncoders.") + self.rnn_config = rnn_config + self.internal_rnn_config = rnn_config.copy(num_hidden=rnn_config.num_hidden // 2) + + self.forward_rnn_cell = self.internal_rnn_config.create_rnn_cell(prefix + C.FORWARD_PREFIX) + self.backward_rnn_cell = self.internal_rnn_config.create_rnn_cell(prefix + C.REVERSE_PREFIX) + + def encode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, + source_encoded_max_length: int, att_dict) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]: + # Batch major to time major for sequence reverse + # (batch_size, seq_len, num_hidden) -> (seq_len, batch_size, num_hidden) + data = mx.sym.transpose(data=source_encoded, axes=(1, 0, 2)) + + # (seq_len, batch_size, num_embed) + data_reverse = mx.sym.SequenceReverse(data=data, + sequence_length=source_encoded_lengths, + use_sequence_length=True) + # (seq_length, batch, cell_num_hidden) + hidden_forward, _ = self.forward_rnn_cell.unroll(length=source_encoded_max_length, + inputs=data, + merge_outputs=True, + layout=C.TIME_MAJOR) + # (seq_length, batch, cell_num_hidden) + hidden_reverse, _ = self.backward_rnn_cell.unroll(length=source_encoded_max_length, + inputs=data_reverse, + merge_outputs=True, + layout=C.TIME_MAJOR) + + # (seq_length, batch, cell_num_hidden) + hidden_reverse = mx.sym.SequenceReverse(data=hidden_reverse, + sequence_length=source_encoded_lengths, + use_sequence_length=True) + + # (seq_length, batch, 2 * cell_num_hidden) + hidden_concat = mx.sym.concat(hidden_forward, hidden_reverse, dim=2, name="%s_rnn" % self.prefix) + + # Time major back to batch major + # (seq_len, batch_size, num_hidden) ->
(batch_size, seq_len, num_hidden) + hidden_concat = mx.sym.transpose(data=hidden_concat, axes=(1, 0, 2)) + return hidden_concat, source_encoded_lengths, source_encoded_max_length + + def get_num_hidden(self) -> int: + return self.rnn_config.num_hidden + + +class BidirectionalRecurrentLayerConfig(layers.LayerConfig): + def __init__(self, + num_hidden: int, + cell_type: str = C.LSTM_TYPE, + dropout_inputs: float = 0.0, + dropout_states: float = 0.0, + dropout_recurrent: float = 0.0, + norm_states: bool = True, + forget_bias: float = 0.0) -> None: + super().__init__() + self.rnn_config = RecurrentLayerRNNConfig(num_hidden=num_hidden, + dropout_recurrent=dropout_recurrent, + dropout_inputs=dropout_inputs, + dropout_states=dropout_states, + norm_states=norm_states, + cell_type=cell_type, + forget_bias=forget_bias) + + def create_encoder_layer(self, input_num_hidden: int, prefix: str) -> layers.EncoderLayer: + return BidirectionalRecurrentEncoderLayer(rnn_config=self.rnn_config, prefix=prefix + "birnn_") + + def create_decoder_layer(self, input_num_hidden: int, prefix: str) -> layers.DecoderLayer: + raise ValueError("Bi-directional RNN can only be used on the encoder side.") + diff --git a/sockeye/rnn_attention.py b/sockeye/rnn_attention.py index 88c6e6f20..c13377ac2 100644 --- a/sockeye/rnn_attention.py +++ b/sockeye/rnn_attention.py @@ -645,7 +645,7 @@ def __init__(self, # layer normalization self._ln = None if layer_normalization: - self._ln = layers.LayerNormalization(prefix="%snorm" % self.prefix) + self._ln = layers.LayerNormalization(prefix="%snorm" % self.prefix, num_hidden=num_hidden) def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: """ diff --git a/sockeye/train.py b/sockeye/train.py index 49bd06d4d..789199dab 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -46,6 +46,7 @@ from .log import setup_main_logger from .optimizers import OptimizerConfig from .utils import check_condition +from . import custom_seq_parser # Temporary logger, the real one (logging to a file probably, will be created in the main function) logger = setup_main_logger(__name__, file_logging=False, console=True) @@ -422,6 +423,22 @@ def create_encoder_config(args: argparse.Namespace, positional_embedding_type=args.cnn_positional_embedding_type) encoder_num_hidden = args.cnn_num_hidden + elif args.encoder == C.CUSTOM_SEQ_TYPE: + logger.info("Creating encoder from configuration '%s'." % args.custom_seq_encoder) + utils.check_condition(args.custom_seq_encoder is not None, + "Please specify the custom encoder layer sequence using --custom-seq-encoder.") + layer_configs = custom_seq_parser.parse_custom_seq_layers_description(default_dropout=args.custom_seq_dropout, + default_num_hidden=args.custom_seq_num_hidden, + default_num_embed=num_embed_source, + max_seq_len=max_seq_len_source, + description=args.custom_seq_encoder, + source_attention_needed=False, + source_attention_forbidden=True) + config_encoder = encoder.CustomSeqEncoderConfig(encoder_layers=layer_configs, + num_embed=num_embed_source) + + # TODO: how to set this correctly!? 
+ encoder_num_hidden = None else: encoder_rnn_dropout_inputs, _ = args.rnn_dropout_inputs encoder_rnn_dropout_states, _ = args.rnn_dropout_states @@ -458,7 +475,7 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int, _, decoder_num_layers = args.num_layers _, num_embed_target = args.num_embed - config_decoder = None # type: Optional[Config] + config_decoder = None # type: Optional[decoder.DecoderConfig] if args.decoder == C.TRANSFORMER_TYPE: if args.decoder_only: @@ -499,7 +516,7 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int, project_qkv=args.cnn_project_qkv, hidden_dropout=args.cnn_hidden_dropout) - else: + elif args.decoder == C.RNN_NAME: if args.decoder_only: args.rnn_decoder_state_init = C.RNN_DEC_INIT_ZERO args.rnn_context_gating = False @@ -548,7 +565,22 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int, attention_in_upper_layers=args.rnn_attention_in_upper_layers, state_init_lhuc=args.lhuc is not None and (C.LHUC_STATE_INIT in args.lhuc or C.LHUC_ALL in args.lhuc), enc_last_hidden_concat_to_embedding=args.rnn_enc_last_hidden_concat_to_embedding) - + elif args.decoder == C.CUSTOM_SEQ_TYPE: + logger.info("Creating decoder from configuration '%s'." % args.custom_seq_decoder) + # TODO: move argument to constant + utils.check_condition(args.custom_seq_decoder is not None, + "Please specify the custom decoder layer sequence using --custom-seq-decoder.") + layer_configs = custom_seq_parser.parse_custom_seq_layers_description(default_dropout=args.custom_seq_dropout, + default_num_hidden=args.custom_seq_num_hidden, + default_num_embed=num_embed_target, + max_seq_len=max_seq_len_target, + description=args.custom_seq_decoder, + source_attention_needed=True, + source_attention_forbidden=False) + config_decoder = decoder.CustomSeqDecoderConfig(decoder_layers=layer_configs, + num_embed=num_embed_target) + else: + raise ValueError("Unknown decoder type %s" % args.decoder) return config_decoder diff --git a/sockeye/transformer.py b/sockeye/transformer.py index 7d49b6a1e..ac701d537 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -1,4 +1,4 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not # use this file except in compliance with the License. 
A copy of the License @@ -73,6 +73,7 @@ def __init__(self, prefix: str) -> None: self.pre_self_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_self_pre_" % prefix) self.self_attention = layers.MultiHeadSelfAttention(depth_att=config.model_size, heads=config.attention_heads, @@ -81,10 +82,12 @@ def __init__(self, prefix="%satt_self_" % prefix) self.post_self_attention = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_self_post_" % prefix) self.pre_ff = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%sff_pre_" % prefix) self.ff = TransformerFeedForward(num_hidden=config.feed_forward_num_hidden, num_model=config.model_size, @@ -93,6 +96,7 @@ def __init__(self, prefix="%sff_" % prefix) self.post_ff = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%sff_post_" % prefix) self.lhuc = None if config.use_lhuc: @@ -100,9 +104,9 @@ def __init__(self, def __call__(self, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol: # self-attention - data_self_att = self.self_attention(inputs=self.pre_self_attention(data, None), - bias=bias, - cache=None) + data_self_att, _ = self.self_attention(inputs=self.pre_self_attention(data, None), + bias=bias, + cache=None) data = self.post_self_attention(data_self_att, data) # feed-forward @@ -127,6 +131,7 @@ def __init__(self, self.prefix = prefix self.pre_self_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_self_pre_" % prefix) self.self_attention = layers.MultiHeadSelfAttention(depth_att=config.model_size, heads=config.attention_heads, @@ -135,10 +140,12 @@ def __init__(self, prefix="%satt_self_" % prefix) self.post_self_attention = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_self_post_" % prefix) self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_enc_pre_" % prefix) self.enc_attention = layers.MultiHeadAttention(depth_att=config.model_size, heads=config.attention_heads, @@ -147,10 +154,12 @@ def __init__(self, prefix="%satt_enc_" % prefix) self.post_enc_attention = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%satt_enc_post_" % prefix) self.pre_ff = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%sff_pre_" % prefix) self.ff = TransformerFeedForward(num_hidden=config.feed_forward_num_hidden, num_model=config.model_size, @@ -159,6 +168,7 @@ def __init__(self, prefix="%sff_" % prefix) self.post_ff = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, + model_size=config.model_size, prefix="%sff_post_" % prefix) self.lhuc = None @@ -172,15 +182,15 @@ def __call__(self, source_bias: mx.sym.Symbol, cache: Optional[Dict[str, Optional[mx.sym.Symbol]]] = None) -> mx.sym.Symbol: # self-attention - target_self_att = self.self_attention(inputs=self.pre_self_attention(target, None), - 
bias=target_bias, - cache=cache) + target_self_att, _ = self.self_attention(inputs=self.pre_self_attention(target, None), + bias=target_bias, + cache=cache) target = self.post_self_attention(target_self_att, target) # encoder attention - target_enc_att = self.enc_attention(queries=self.pre_enc_attention(target, None), - memory=source, - bias=source_bias) + target_enc_att, _ = self.enc_attention(queries=self.pre_enc_attention(target, None), + memory=source, + bias=source_bias) target = self.post_enc_attention(target_enc_att, target) # feed-forward @@ -205,13 +215,14 @@ class TransformerProcessBlock: def __init__(self, sequence: str, dropout: float, + model_size: int, prefix: str) -> None: self.sequence = sequence self.dropout = dropout self.prefix = prefix self.layer_norm = None if "n" in sequence: - self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix) + self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix, num_hidden=model_size) def __call__(self, data: mx.sym.Symbol, @@ -356,70 +367,3 @@ def get_variable_length_bias(lengths: mx.sym.Symbol, x = layers.broadcast_to_heads(x, num_heads, ndim=2, fold_heads=fold_heads) return mx.sym.BlockGrad(x, name='%sbias' % name) - -def get_autoregressive_bias(max_length: int, name: str) -> mx.sym.Symbol: - """ - Returns bias/mask to ensure position i can only attend to positions None: - super().__init__() - self.bias = self.get_bias(length, dtype, ctx) - - @staticmethod - def get_bias(length: int, dtype: str, ctx: mx.Context): - # matrix with lower triangle and main diagonal set to 0, upper triangle set to 1 - upper_triangle = np.triu(np.ones((length, length), dtype=dtype), k=1) - # (1, length, length) - bias = -C.LARGE_VALUES[dtype] * np.reshape(upper_triangle, (1, length, length)) - return mx.nd.array(bias, ctx=ctx) - - def forward(self, is_train, req, in_data, out_data, aux): - self.assign(out_data[0], req[0], self.bias) - - def backward(self, req, out_grad, in_data, out_data, in_grad, aux): - pass - - -@mx.operator.register("auto_regressive_bias") -class AutoRegressiveBiasProp(mx.operator.CustomOpProp): - - def __init__(self, length: str, dtype: str = C.DTYPE_FP32) -> None: - super().__init__() - self.length = int(length) - self.dtype = dtype - - def list_arguments(self): - return [] - - def list_outputs(self): - return ['output'] - - def infer_shape(self, in_shape): - return [], [(1, self.length, self.length)], [] - - def infer_type(self, in_type): - return [], [np.dtype(self.dtype).type], [] - - def create_operator(self, ctx, shapes, dtypes): - return AutoRegressiveBias(length=self.length, dtype=self.dtype, ctx=ctx) diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index f5fd212a8..7508d35ea 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -15,7 +15,6 @@ import pytest import tempfile import os -import yaml import sockeye.arguments as arguments import sockeye.constants as C @@ -96,6 +95,10 @@ def test_device_args(test_params, expected_params): transformer_positional_embedding_type="fixed", transformer_preprocess=('n', 'n'), transformer_postprocess=('dr', 'dr'), + custom_seq_encoder='res(norm->mh_dot_att)->res(norm->ff->linear))', + custom_seq_decoder='res(norm->mh_dot_self_att)->res(norm->mh_dot_att)->res(norm->ff->linear))', + custom_seq_num_hidden=512, + custom_seq_dropout=0.1, rnn_attention_use_prev_word=False, rnn_decoder_state_init="last", rnn_encoder_reverse_input=False, diff --git a/test/unit/test_custom_seq_parser.py 
b/test/unit/test_custom_seq_parser.py new file mode 100644 index 000000000..20457fe6d --- /dev/null +++ b/test/unit/test_custom_seq_parser.py @@ -0,0 +1,45 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import pytest + +from sockeye import custom_seq_parser + +descriptions_with_parses = [ + ("rnn", [{'name': 'rnn', 'params': None}]), + ("rnn->rnn", [{'name': 'rnn', 'params': None}, {'name': 'rnn', 'params': None}]), + # Transformer + ("pos->repeat(6,res(norm->mh_dot_self_att)->res(norm->mh_dot_att)->res(norm->ff(2048)->linear(512)))->norm", + [ + {"name": 'pos', 'params': None}, + {"name": "repeat", "num": 6, 'layers': [ + {'name': 'res', 'layers': [{'name': 'norm', 'params': None}, {'name': 'mh_dot_self_att', 'params': None}]}, + {'name': 'res', 'layers': [{'name': 'norm', 'params': None}, {'name': 'mh_dot_att', 'params': None}]}, + {'name': 'res', 'layers': [{'name': 'norm', 'params': None}, + {'name': 'ff', 'params': [2048]}, + {'name': 'linear', 'params': [512]}]}]}, + {'name': 'norm', 'params': None} + ]), + # keyword args + ("ff(1,2,key1=3,key2=4)", [{"name": "ff", "params": [1, 2, ('key1', 3), ('key2', 4)]}]) +] +# TODO: add tests for parallel layers and other meta layers + + +@pytest.mark.parametrize("description, expected_parsed_layers", descriptions_with_parses) +def test_parser(description, expected_parsed_layers): + parsed_layers = custom_seq_parser.parse(description) + print(parsed_layers) + + assert parsed_layers == expected_parsed_layers diff --git a/test/unit/test_layers.py b/test/unit/test_layers.py index 2391ebcf0..493acc000 100644 --- a/test/unit/test_layers.py +++ b/test/unit/test_layers.py @@ -26,15 +26,15 @@ def test_layer_normalization(): x_nd = mx.nd.uniform(0, 10, (batch_size, other_dim, num_hidden)) x_np = x_nd.asnumpy() - ln = sockeye.layers.LayerNormalization(prefix="") + ln = sockeye.layers.LayerNormalization(num_hidden=num_hidden, prefix="") expected_mean = np.mean(x_np, axis=-1, keepdims=True) expected_var = np.var(x_np, axis=-1, keepdims=True) expected_norm = (x_np - expected_mean) / np.sqrt(expected_var) norm = ln(x).eval(x=x_nd, - _gamma=mx.nd.ones((num_hidden,)), - _beta=mx.nd.zeros((num_hidden,)))[0] + _gamma=mx.nd.ones((num_hidden,)), + _beta=mx.nd.zeros((num_hidden,)))[0] assert np.isclose(norm.asnumpy(), expected_norm, atol=1.e-6).all()
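To complement the parser cases in test_custom_seq_parser.py above, here is how a smaller description should parse, by analogy with the transformer example (a sketch of expected behaviour, not an additional test case):

    from sockeye import custom_seq_parser

    parsed = custom_seq_parser.parse("res(norm->ff(2048)->linear(512))")
    # Expected, following the structure asserted in descriptions_with_parses:
    # [{'name': 'res', 'layers': [{'name': 'norm', 'params': None},
    #                             {'name': 'ff', 'params': [2048]},
    #                             {'name': 'linear', 'params': [512]}]}]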