diff --git a/requirements/requirements.gpu-cu90.txt b/requirements/requirements.gpu-cu90.txt
index a97c79120..86c40a5bf 100644
--- a/requirements/requirements.gpu-cu90.txt
+++ b/requirements/requirements.gpu-cu90.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
 mxnet-cu90mkl==1.2.0
-numpy>=1.12
+numpy>=1.12,<1.15.0
 typing
diff --git a/sockeye/decoder.py b/sockeye/decoder.py
index 97a28672c..78305661f 100644
--- a/sockeye/decoder.py
+++ b/sockeye/decoder.py
@@ -62,6 +62,7 @@ def register(cls, config_type: Type[DecoderConfig], suffix: str):
         :return: Class decorator.
         """
+
         def wrapper(target_cls):
             cls.__registry[config_type] = (target_cls, suffix)
             return target_cls
@@ -269,10 +270,10 @@ def decode_sequence(self,
             target = mx.sym.Dropout(data=target, p=self.config.dropout_prepost)
 
         for layer in self.layers:
-            target = layer(target=target,
-                           target_bias=target_bias,
-                           source=source_encoded,
-                           source_bias=source_bias)
+            target, layer_probs = layer(target=target,
+                                        target_bias=target_bias,
+                                        source=source_encoded,
+                                        source_bias=source_bias)
         target = self.final_process(data=target, prev=None)
 
         return target
@@ -320,23 +321,29 @@ def decode_step(self,
         new_states = [source_encoded, source_encoded_lengths]
         layer_caches = self._get_cache_per_layer(cast(List[mx.sym.Symbol], cache))
+
+        attention_probs = []
+
         for layer, layer_cache in zip(self.layers, layer_caches):
-            target = layer(target=target,
-                           target_bias=target_bias,
-                           source=source_encoded,
-                           source_bias=source_bias,
-                           cache=layer_cache)
+            target, layer_probs = layer(target=target,
+                                        target_bias=target_bias,
+                                        source=source_encoded,
+                                        source_bias=source_bias,
+                                        cache=layer_cache)
             # store updated keys and values in states list.
             # (layer.__call__() has the side-effect of updating contents of layer_cache)
             new_states += [layer_cache['k'], layer_cache['v']]
+            attention_probs.append(layer_probs)
 
         # (batch_size, 1, model_size)
         target = self.final_process(data=target, prev=None)
         # (batch_size, model_size)
         target = mx.sym.reshape(target, shape=(-3, -1))
 
-        # TODO(fhieber): no attention probs for now
-        attention_probs = mx.sym.sum(mx.sym.zeros_like(source_encoded), axis=2, keepdims=False)
+        # (layers, batch_size, heads, source_length)
+        attention_probs = mx.sym.stack(*attention_probs, axis=0)
+        # (batch_size, source_length)
+        attention_probs = mx.sym.mean(attention_probs, axis=(0, 2), keepdims=False)
 
         return target, attention_probs, new_states
@@ -462,7 +469,7 @@ def __init__(self,
                  max_seq_len_source: int,
                  rnn_config: rnn.RNNConfig,
                  attention_config: rnn_attention.AttentionConfig,
-                 hidden_dropout: float = .0,
+                 hidden_dropout: float = .0,  # TODO: move this dropout functionality to OutputLayer
                  state_init: str = C.RNN_DEC_INIT_LAST,
                  state_init_lhuc: bool = False,
                  context_gating: bool = False,
@@ -470,7 +477,6 @@ def __init__(self,
                  attention_in_upper_layers: bool = False,
                  dtype: str = C.DTYPE_FP32,
                  enc_last_hidden_concat_to_embedding: bool = False) -> None:
-
         super().__init__()
         self.max_seq_len_source = max_seq_len_source
@@ -587,9 +593,9 @@ def decode_sequence(self,
         enc_last_hidden = None
         if self.config.enc_last_hidden_concat_to_embedding:
             enc_last_hidden = mx.sym.SequenceLast(data=source_encoded,
-                                                   sequence_length=source_encoded_lengths,
-                                                   axis=1,
-                                                   use_sequence_length=True)
+                                                  sequence_length=source_encoded_lengths,
+                                                  axis=1,
+                                                  use_sequence_length=True)
 
         # get recurrent attention function conditioned on source
         attention_func = self.attention.on(source_encoded, source_encoded_lengths,
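The decode_step hunk above replaces the previous all-zero placeholder with real probabilities: one (batch_size, heads, source_length) distribution per decoder layer is stacked and then averaged over the layer and head axes. A minimal NumPy sketch of that shape bookkeeping (sizes and the np calls are illustrative stand-ins for the mx.sym.stack / mx.sym.mean symbols in the patch):

```python
import numpy as np

layers, batch_size, heads, source_length = 6, 2, 8, 5

# one (batch_size, heads, source_length) distribution per decoder layer
layer_probs = [np.random.dirichlet(np.ones(source_length), size=(batch_size, heads))
               for _ in range(layers)]

# (layers, batch_size, heads, source_length), as produced by mx.sym.stack(..., axis=0)
stacked = np.stack(layer_probs, axis=0)

# (batch_size, source_length): average over layers (axis 0) and heads (axis 2)
attention_probs = stacked.mean(axis=(0, 2))

# each row still sums to 1, so it can serve as a single source distribution per step
assert np.allclose(attention_probs.sum(axis=1), 1.0)
```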
@@ -832,7 +838,8 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
               attention_func: Callable,
               attention_state: rnn_attention.AttentionState,
               seq_idx: int = 0,
-              enc_last_hidden: Optional[mx.sym.Symbol] = None) -> Tuple[RecurrentDecoderState, rnn_attention.AttentionState]:
+              enc_last_hidden: Optional[mx.sym.Symbol] = None) -> Tuple[
+                  RecurrentDecoderState, rnn_attention.AttentionState]:
         """
         Performs single-time step in the RNN, given previous word vector, previous hidden state, attention function,
@@ -849,7 +856,7 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
         # concat previous word embedding and previous hidden state
         if enc_last_hidden is not None:
             word_vec_prev = mx.sym.concat(word_vec_prev, enc_last_hidden, dim=1,
-                                           name="%sconcat_target_encoder_t%d" % (self.prefix, seq_idx))
+                                          name="%sconcat_target_encoder_t%d" % (self.prefix, seq_idx))
         rnn_input = mx.sym.concat(word_vec_prev, state.hidden, dim=1,
                                   name="%sconcat_target_context_t%d" % (self.prefix, seq_idx))
         # rnn_pre_attention_output: (batch_size, rnn_num_hidden)
diff --git a/sockeye/layers.py b/sockeye/layers.py
index 096bfb58a..ad3247f40 100644
--- a/sockeye/layers.py
+++ b/sockeye/layers.py
@@ -44,7 +44,7 @@ def activation(data: mx.sym.Symbol, act_type: str) -> mx.sym.Symbol:
         return data * mx.sym.Activation(data, act_type="sigmoid")
     elif act_type == C.GELU:
         # Approximation of x * gaussian_cdf(x) used by Hendrycks and Gimpel
-        return 0.5 * data * (1 + mx.sym.Activation((math.sqrt(2 / math.pi) * (data + (0.044715 * (data**3)))),
+        return 0.5 * data * (1 + mx.sym.Activation((math.sqrt(2 / math.pi) * (data + (0.044715 * (data ** 3)))),
                                                    act_type="tanh"))
     else:
         return mx.sym.Activation(data, act_type=act_type)
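The activation() hunk above only changes spacing around the exponent, but the expression it touches is the tanh-based GELU approximation x * gaussian_cdf(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A small NumPy sketch of the same formula (gelu_approx is an illustrative name, not part of the patch):

```python
import math
import numpy as np

def gelu_approx(x: np.ndarray) -> np.ndarray:
    # tanh approximation of x * gaussian_cdf(x) (Hendrycks & Gimpel),
    # mirroring the mx.sym expression in activation() above
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-3.0, 3.0, 7)
print(gelu_approx(x))  # ~0 for large negative x, ~x for large positive x
```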
+ """ + utils.check_condition(lengths is not None or bias is not None, + "Must provide either length or bias argument for masking") + + # (n, lq, lk) + logits = mx.sym.batch_dot(lhs=queries, rhs=keys, transpose_b=True, name='%sdot' % prefix) + + if lengths is not None: + # mask lk dimension + # (lk, n, lq) + logits = mx.sym.transpose(data=logits, axes=(2, 0, 1)) + logits = mx.sym.SequenceMask(data=logits, + use_sequence_length=True, + sequence_length=lengths, + value=C.LARGE_NEGATIVE_VALUE) + # (n, lq, lk) + logits = mx.sym.transpose(data=logits, axes=(1, 2, 0)) + + if bias is not None: + logits = mx.sym.broadcast_add(logits, bias, name='%sbias_add' % prefix) + + probs = mx.sym.softmax(logits, axis=-1) + probs_with_dropout = mx.sym.Dropout(probs, p=dropout) if dropout > 0.0 else probs + + # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv) + return mx.sym.batch_dot(lhs=probs_with_dropout, rhs=values, name='%scontexts' % prefix), probs + + class MultiHeadAttentionBase: """ Base class for Multi-head attention. @@ -340,6 +388,7 @@ class MultiHeadAttentionBase: :param depth_out: Output depth / number of output units. :param dropout: Dropout probability on attention scores """ + def __init__(self, prefix: str, depth_att: int = 512, @@ -410,6 +459,7 @@ class MultiHeadSelfAttention(MultiHeadAttentionBase): :param depth_out: Output depth / number of output units. :param dropout: Dropout probability on attention scores """ + def __init__(self, prefix: str, depth_att: int = 512, @@ -531,6 +581,77 @@ def __call__(self, lengths=memory_lengths) +class MultiHeadAttentionWithProbs(MultiHeadAttention): + """ + Multi-head attention layer for queries independent from keys/values. + + :param prefix: Attention prefix. + :param depth_att: Attention depth / number of hidden units. + :param heads: Number of attention heads. + :param depth_out: Output depth / number of output units. + :param dropout: Dropout probability on attention scores + """ + + def __init__(self, + prefix: str, + depth_att: int = 512, + heads: int = 8, + depth_out: int = 512, + dropout: float = 0.0) -> None: + super().__init__(prefix, depth_att, heads, depth_out, dropout) + self.w_q2h = mx.sym.Variable("%sq2h_weight" % prefix) + self.w_k2h = mx.sym.Variable("%sk2h_weight" % prefix) + self.w_v2h = mx.sym.Variable("%sv2h_weight" % prefix) + + def _attend(self, + queries: mx.sym.Symbol, + keys: mx.sym.Symbol, + values: mx.sym.Symbol, + lengths: Optional[mx.sym.Symbol] = None, + bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + """ + Returns context vectors of multi-head dot attention. + + :param queries: Query tensor. Shape: (batch_size, query_max_length, depth). + :param keys: Keys. Shape: (batch_size, memory_max_length, depth). + :param values: Values. Shape: (batch_size, memory_max_length, depth). + :param lengths: Optional lengths of keys. Shape: (batch_size,). + :param bias: Optional 3d bias. + :return: Context vectors: Shape: (batch_size, query_max_length, output_depth). + :return: Attention probabilities: Shape: (batch_size, heads, source_length). 
+ """ + # scale by sqrt(depth_per_head) + queries = queries * (self.depth_per_head ** -0.5) + + # (batch*heads, length, depth/heads) + queries = split_heads(queries, self.depth_per_head, self.heads) + keys = split_heads(keys, self.depth_per_head, self.heads) + values = split_heads(values, self.depth_per_head, self.heads) + lengths = broadcast_to_heads(lengths, self.heads, ndim=1, fold_heads=True) if lengths is not None else lengths + + # (batch*heads, query_max_length, depth_per_head), (batch*heads, query_max_length, source_length) + contexts, attention_probs = dot_attention_with_probs(queries, keys, values, + lengths=lengths, dropout=self.dropout, bias=bias, + prefix=self.prefix) + + # (batch, query_max_length, depth) + contexts = combine_heads(contexts, self.depth_per_head, self.heads) + # (batch, length, heads, source_length) + attention_probs = mx.sym.reshape(data=attention_probs, shape=(-4, -1, self.heads, 0, 0)) + # MultiHeadAttentionWithProbs is only used in Encoder therefore length=1 so we can drop it + # (batch, heads, source_length) + attention_probs = mx.sym.reshape(data=attention_probs, shape=(0, -3, 0)) + + # contexts: (batch, query_max_length, output_depth) + contexts = mx.sym.FullyConnected(data=contexts, + weight=self.w_h2o, + no_bias=True, + num_hidden=self.depth_out, + flatten=False) + + return contexts, attention_probs + + class ProjectedDotAttention: """ Dot attention layer for queries independent from keys/values. diff --git a/sockeye/transformer.py b/sockeye/transformer.py index 7d49b6a1e..f09db0a1b 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -95,8 +95,9 @@ def __init__(self, dropout=config.dropout_prepost, prefix="%sff_post_" % prefix) self.lhuc = None - if config.use_lhuc: - self.lhuc = layers.LHUC(config.model_size, prefix=prefix) + if hasattr(config, 'use_lhuc'): + if config.use_lhuc: + self.lhuc = layers.LHUC(config.model_size, prefix=prefix) def __call__(self, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol: # self-attention @@ -140,11 +141,11 @@ def __init__(self, self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence, dropout=config.dropout_prepost, prefix="%satt_enc_pre_" % prefix) - self.enc_attention = layers.MultiHeadAttention(depth_att=config.model_size, - heads=config.attention_heads, - depth_out=config.model_size, - dropout=config.dropout_attention, - prefix="%satt_enc_" % prefix) + self.enc_attention = layers.MultiHeadAttentionWithProbs(depth_att=config.model_size, + heads=config.attention_heads, + depth_out=config.model_size, + dropout=config.dropout_attention, + prefix="%satt_enc_" % prefix) self.post_enc_attention = TransformerProcessBlock(sequence=config.postprocess_sequence, dropout=config.dropout_prepost, prefix="%satt_enc_post_" % prefix) @@ -162,8 +163,9 @@ def __init__(self, prefix="%sff_post_" % prefix) self.lhuc = None - if config.use_lhuc: - self.lhuc = layers.LHUC(config.model_size, prefix=prefix) + if hasattr(config, 'use_lhuc'): + if config.use_lhuc: + self.lhuc = layers.LHUC(config.model_size, prefix=prefix) def __call__(self, target: mx.sym.Symbol, @@ -178,9 +180,9 @@ def __call__(self, target = self.post_self_attention(target_self_att, target) # encoder attention - target_enc_att = self.enc_attention(queries=self.pre_enc_attention(target, None), - memory=source, - bias=source_bias) + target_enc_att, attention_probs = self.enc_attention(queries=self.pre_enc_attention(target, None), + memory=source, + bias=source_bias) target = self.post_enc_attention(target_enc_att, 
diff --git a/sockeye/transformer.py b/sockeye/transformer.py
index 7d49b6a1e..f09db0a1b 100644
--- a/sockeye/transformer.py
+++ b/sockeye/transformer.py
@@ -95,8 +96,9 @@ def __init__(self,
                                                dropout=config.dropout_prepost,
                                                prefix="%sff_post_" % prefix)
         self.lhuc = None
-        if config.use_lhuc:
-            self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
+        if hasattr(config, 'use_lhuc'):
+            if config.use_lhuc:
+                self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
 
     def __call__(self, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol:
         # self-attention
@@ -140,11 +141,11 @@ def __init__(self,
         self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence,
                                                          dropout=config.dropout_prepost,
                                                          prefix="%satt_enc_pre_" % prefix)
-        self.enc_attention = layers.MultiHeadAttention(depth_att=config.model_size,
-                                                       heads=config.attention_heads,
-                                                       depth_out=config.model_size,
-                                                       dropout=config.dropout_attention,
-                                                       prefix="%satt_enc_" % prefix)
+        self.enc_attention = layers.MultiHeadAttentionWithProbs(depth_att=config.model_size,
+                                                                heads=config.attention_heads,
+                                                                depth_out=config.model_size,
+                                                                dropout=config.dropout_attention,
+                                                                prefix="%satt_enc_" % prefix)
         self.post_enc_attention = TransformerProcessBlock(sequence=config.postprocess_sequence,
                                                           dropout=config.dropout_prepost,
                                                           prefix="%satt_enc_post_" % prefix)
@@ -162,8 +163,9 @@ def __init__(self,
                                          prefix="%sff_post_" % prefix)
 
         self.lhuc = None
-        if config.use_lhuc:
-            self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
+        if hasattr(config, 'use_lhuc'):
+            if config.use_lhuc:
+                self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
 
     def __call__(self,
                  target: mx.sym.Symbol,
@@ -178,9 +180,9 @@ def __call__(self,
         target = self.post_self_attention(target_self_att, target)
 
         # encoder attention
-        target_enc_att = self.enc_attention(queries=self.pre_enc_attention(target, None),
-                                            memory=source,
-                                            bias=source_bias)
+        target_enc_att, attention_probs = self.enc_attention(queries=self.pre_enc_attention(target, None),
+                                                             memory=source,
+                                                             bias=source_bias)
         target = self.post_enc_attention(target_enc_att,
                                          target)
 
         # feed-forward
@@ -190,7 +192,7 @@ def __call__(self,
         if self.lhuc:
             target = self.lhuc(target)
 
-        return target
+        return target, attention_probs
 
 
 class TransformerProcessBlock:
@@ -382,7 +384,7 @@ class AutoRegressiveBias(mx.operator.CustomOp):
        0 0 0 0
     """
-    def __init__(self, length: int, dtype:str, ctx: mx.Context) -> None:
+    def __init__(self, length: int, dtype: str, ctx: mx.Context) -> None:
         super().__init__()
         self.bias = self.get_bias(length, dtype, ctx)
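Taken together, the Transformer decoder now returns encoder-attention probabilities averaged over layers and heads instead of the previous all-zero placeholder. A hedged sketch, not part of the patch, of the kind of post-processing this enables once one (source_length,) vector has been collected per decode step:

```python
import numpy as np

# one attention_probs row per decode step for a single sentence,
# e.g. gathered from successive decode_step outputs
step_probs = [np.array([0.7, 0.2, 0.1]),
              np.array([0.1, 0.8, 0.1]),
              np.array([0.2, 0.2, 0.6])]

# (target_length, source_length) soft alignment matrix
alignment = np.stack(step_probs, axis=0)

# hard alignment: most-attended source position for each target position
print(alignment.argmax(axis=1))  # [0 1 2]
```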