Transformer Attention Probabilities #504

Closed
wants to merge 14 commits
2 changes: 1 addition & 1 deletion requirements/requirements.gpu-cu90.txt
@@ -1,4 +1,4 @@
pyyaml==3.12
mxnet-cu90mkl==1.2.0
numpy>=1.12
numpy>=1.12,<1.15.0
typing
43 changes: 25 additions & 18 deletions sockeye/decoder.py
@@ -62,6 +62,7 @@ def register(cls, config_type: Type[DecoderConfig], suffix: str):

:return: Class decorator.
"""

def wrapper(target_cls):
cls.__registry[config_type] = (target_cls, suffix)
return target_cls
@@ -269,10 +270,10 @@ def decode_sequence(self,
target = mx.sym.Dropout(data=target, p=self.config.dropout_prepost)

for layer in self.layers:
target = layer(target=target,
target_bias=target_bias,
source=source_encoded,
source_bias=source_bias)
target, layer_probs = layer(target=target,
target_bias=target_bias,
source=source_encoded,
source_bias=source_bias)
target = self.final_process(data=target, prev=None)

return target
@@ -320,23 +321,29 @@ def decode_step(self,

new_states = [source_encoded, source_encoded_lengths]
layer_caches = self._get_cache_per_layer(cast(List[mx.sym.Symbol], cache))

attention_probs = []

for layer, layer_cache in zip(self.layers, layer_caches):
target = layer(target=target,
target_bias=target_bias,
source=source_encoded,
source_bias=source_bias,
cache=layer_cache)
target, layer_probs = layer(target=target,
target_bias=target_bias,
source=source_encoded,
source_bias=source_bias,
cache=layer_cache)
# store updated keys and values in states list.
# (layer.__call__() has the side-effect of updating contents of layer_cache)
new_states += [layer_cache['k'], layer_cache['v']]
attention_probs.append(layer_probs)

# (batch_size, 1, model_size)
target = self.final_process(data=target, prev=None)
# (batch_size, model_size)
target = mx.sym.reshape(target, shape=(-3, -1))

# TODO(fhieber): no attention probs for now
attention_probs = mx.sym.sum(mx.sym.zeros_like(source_encoded), axis=2, keepdims=False)
# (layers, batch_size, heads, source_length)
attention_probs = mx.sym.stack(*attention_probs, axis=0)
# (batch_size, source_length)
attention_probs = mx.sym.mean(attention_probs, axis=(0, 2), keepdims=False)

return target, attention_probs, new_states
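
Illustration (not part of the diff): decode_step now collects one (batch_size, heads, source_length) probability matrix per decoder layer from the encoder attention, stacks them, and averages over layers and heads to obtain a single (batch_size, source_length) alignment. A minimal NumPy sketch of that reduction, with hypothetical sizes:

import numpy as np

num_layers, batch_size, heads, source_length = 6, 2, 8, 7
# one (batch_size, heads, source_length) matrix per decoder layer, as appended in decode_step
layer_probs = [np.random.rand(batch_size, heads, source_length) for _ in range(num_layers)]
stacked = np.stack(layer_probs, axis=0)      # (layers, batch_size, heads, source_length)
attention_probs = stacked.mean(axis=(0, 2))  # (batch_size, source_length)
print(attention_probs.shape)                 # (2, 7)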

@@ -462,15 +469,14 @@ def __init__(self,
max_seq_len_source: int,
rnn_config: rnn.RNNConfig,
attention_config: rnn_attention.AttentionConfig,
hidden_dropout: float = .0,
hidden_dropout: float = .0, # TODO: move this dropout functionality to OutputLayer
state_init: str = C.RNN_DEC_INIT_LAST,
state_init_lhuc: bool = False,
context_gating: bool = False,
layer_normalization: bool = False,
attention_in_upper_layers: bool = False,
dtype: str = C.DTYPE_FP32,
enc_last_hidden_concat_to_embedding: bool = False) -> None:

super().__init__()
self.max_seq_len_source = max_seq_len_source
self.rnn_config = rnn_config
@@ -587,9 +593,9 @@ def decode_sequence(self,
enc_last_hidden = None
if self.config.enc_last_hidden_concat_to_embedding:
enc_last_hidden = mx.sym.SequenceLast(data=source_encoded,
sequence_length=source_encoded_lengths,
axis=1,
use_sequence_length=True)
sequence_length=source_encoded_lengths,
axis=1,
use_sequence_length=True)

# get recurrent attention function conditioned on source
attention_func = self.attention.on(source_encoded, source_encoded_lengths,
@@ -832,7 +838,8 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
attention_func: Callable,
attention_state: rnn_attention.AttentionState,
seq_idx: int = 0,
enc_last_hidden: Optional[mx.sym.Symbol] = None) -> Tuple[RecurrentDecoderState, rnn_attention.AttentionState]:
enc_last_hidden: Optional[mx.sym.Symbol] = None) -> Tuple[
RecurrentDecoderState, rnn_attention.AttentionState]:

"""
Performs single-time step in the RNN, given previous word vector, previous hidden state, attention function,
@@ -849,7 +856,7 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
# concat previous word embedding and previous hidden state
if enc_last_hidden is not None:
word_vec_prev = mx.sym.concat(word_vec_prev, enc_last_hidden, dim=1,
name="%sconcat_target_encoder_t%d" % (self.prefix, seq_idx))
name="%sconcat_target_encoder_t%d" % (self.prefix, seq_idx))
rnn_input = mx.sym.concat(word_vec_prev, state.hidden, dim=1,
name="%sconcat_target_context_t%d" % (self.prefix, seq_idx))
# rnn_pre_attention_output: (batch_size, rnn_num_hidden)
123 changes: 122 additions & 1 deletion sockeye/layers.py
@@ -44,7 +44,7 @@ def activation(data: mx.sym.Symbol, act_type: str) -> mx.sym.Symbol:
return data * mx.sym.Activation(data, act_type="sigmoid")
elif act_type == C.GELU:
# Approximation of x * gaussian_cdf(x) used by Hendrycks and Gimpel
return 0.5 * data * (1 + mx.sym.Activation((math.sqrt(2 / math.pi) * (data + (0.044715 * (data**3)))),
return 0.5 * data * (1 + mx.sym.Activation((math.sqrt(2 / math.pi) * (data + (0.044715 * (data ** 3)))),
act_type="tanh"))
else:
return mx.sym.Activation(data, act_type=act_type)
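
For reference (a sketch, not part of the diff): the GELU branch above implements the Hendrycks and Gimpel tanh approximation of x * gaussian_cdf(x). A small check of the approximation against the exact Gaussian CDF, using only the standard library:

import math

def gelu_exact(x):
    # x * Phi(x), where Phi is the standard normal CDF
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation used in activation() above
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(gelu_exact(1.0), gelu_tanh(1.0))  # both approximately 0.841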
@@ -60,6 +60,7 @@ class LayerNormalization:
:param scale_init: Initial value of scale variable if scale is None. Default 1.0.
:param shift_init: Initial value of shift variable if shift is None. Default 0.0.
"""

def __init__(self,
prefix: str = 'layernorm',
scale: Optional[mx.sym.Symbol] = None,
@@ -99,6 +100,7 @@ class LHUC:
:param weight: Optional parameter vector.
:param prefix: Optional prefix for created parameters (if not given as weight).
"""

def __init__(self,
num_hidden: int,
weight: Optional[mx.sym.Symbol] = None,
@@ -330,6 +332,52 @@ def dot_attention(queries: mx.sym.Symbol,
return mx.sym.batch_dot(lhs=probs, rhs=values, name='%scontexts' % prefix)


def dot_attention_with_probs(queries: mx.sym.Symbol,
keys: mx.sym.Symbol,
values: mx.sym.Symbol,
lengths: Optional[mx.sym.Symbol] = None,
dropout: float = 0.0,
bias: Optional[mx.sym.Symbol] = None,
prefix: Optional[str] = ''):
"""
    Computes dot attention for a set of queries, keys, and values, and additionally returns the attention probabilities.

:param queries: Attention queries. Shape: (n, lq, d).
:param keys: Attention keys. Shape: (n, lk, d).
:param values: Attention values. Shape: (n, lk, dv).
:param lengths: Optional sequence lengths of the keys. Shape: (n,).
:param dropout: Dropout probability.
:param bias: Optional 3d bias tensor.
    :param prefix: Optional prefix.
    :return: 'Context' vectors for each query (n, lq, dv) and attention probabilities (n, lq, lk).
"""
utils.check_condition(lengths is not None or bias is not None,
"Must provide either length or bias argument for masking")

# (n, lq, lk)
logits = mx.sym.batch_dot(lhs=queries, rhs=keys, transpose_b=True, name='%sdot' % prefix)

if lengths is not None:
# mask lk dimension
# (lk, n, lq)
logits = mx.sym.transpose(data=logits, axes=(2, 0, 1))
logits = mx.sym.SequenceMask(data=logits,
use_sequence_length=True,
sequence_length=lengths,
value=C.LARGE_NEGATIVE_VALUE)
# (n, lq, lk)
logits = mx.sym.transpose(data=logits, axes=(1, 2, 0))

if bias is not None:
logits = mx.sym.broadcast_add(logits, bias, name='%sbias_add' % prefix)

probs = mx.sym.softmax(logits, axis=-1)
probs_with_dropout = mx.sym.Dropout(probs, p=dropout) if dropout > 0.0 else probs

# (n, lq, lk) x (n, lk, dv) -> (n, lq, dv)
return mx.sym.batch_dot(lhs=probs_with_dropout, rhs=values, name='%scontexts' % prefix), probs
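
A reference sketch (not part of the diff) of what dot_attention_with_probs computes, written in NumPy for clarity; the bias and dropout paths are omitted and -1e8 stands in for C.LARGE_NEGATIVE_VALUE:

import numpy as np

def dot_attention_with_probs_np(queries, keys, values, lengths=None, large_negative=-1e8):
    # queries: (n, lq, d), keys: (n, lk, d), values: (n, lk, dv), lengths: (n,)
    logits = queries @ keys.transpose(0, 2, 1)                     # (n, lq, lk)
    if lengths is not None:
        positions = np.arange(keys.shape[1])
        mask = positions[None, None, :] >= lengths[:, None, None]  # True beyond each key length
        logits = np.where(mask, large_negative, logits)
    logits = logits - logits.max(axis=-1, keepdims=True)           # numerically stable softmax
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    contexts = probs @ values                                      # (n, lq, dv)
    return contexts, probs

q, k, v = np.random.rand(2, 3, 4), np.random.rand(2, 5, 4), np.random.rand(2, 5, 4)
ctx, p = dot_attention_with_probs_np(q, k, v, lengths=np.array([5, 2]))
print(ctx.shape, p.shape)  # (2, 3, 4) (2, 3, 5)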


class MultiHeadAttentionBase:
"""
Base class for Multi-head attention.
@@ -340,6 +388,7 @@ class MultiHeadAttentionBase:
:param depth_out: Output depth / number of output units.
:param dropout: Dropout probability on attention scores
"""

def __init__(self,
prefix: str,
depth_att: int = 512,
@@ -410,6 +459,7 @@ class MultiHeadSelfAttention(MultiHeadAttentionBase):
:param depth_out: Output depth / number of output units.
:param dropout: Dropout probability on attention scores
"""

def __init__(self,
prefix: str,
depth_att: int = 512,
@@ -531,6 +581,77 @@ def __call__(self,
lengths=memory_lengths)


class MultiHeadAttentionWithProbs(MultiHeadAttention):
"""
    Multi-head attention layer for queries independent from keys/values that additionally returns the attention probabilities.

:param prefix: Attention prefix.
:param depth_att: Attention depth / number of hidden units.
:param heads: Number of attention heads.
:param depth_out: Output depth / number of output units.
:param dropout: Dropout probability on attention scores
"""

def __init__(self,
prefix: str,
depth_att: int = 512,
heads: int = 8,
depth_out: int = 512,
dropout: float = 0.0) -> None:
super().__init__(prefix, depth_att, heads, depth_out, dropout)
self.w_q2h = mx.sym.Variable("%sq2h_weight" % prefix)
self.w_k2h = mx.sym.Variable("%sk2h_weight" % prefix)
self.w_v2h = mx.sym.Variable("%sv2h_weight" % prefix)

def _attend(self,
queries: mx.sym.Symbol,
keys: mx.sym.Symbol,
values: mx.sym.Symbol,
lengths: Optional[mx.sym.Symbol] = None,
bias: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:
"""
        Returns context vectors and attention probabilities of multi-head dot attention.

:param queries: Query tensor. Shape: (batch_size, query_max_length, depth).
:param keys: Keys. Shape: (batch_size, memory_max_length, depth).
:param values: Values. Shape: (batch_size, memory_max_length, depth).
:param lengths: Optional lengths of keys. Shape: (batch_size,).
:param bias: Optional 3d bias.
:return: Context vectors: Shape: (batch_size, query_max_length, output_depth).
:return: Attention probabilities: Shape: (batch_size, heads, source_length).
"""
# scale by sqrt(depth_per_head)
queries = queries * (self.depth_per_head ** -0.5)

# (batch*heads, length, depth/heads)
queries = split_heads(queries, self.depth_per_head, self.heads)
keys = split_heads(keys, self.depth_per_head, self.heads)
values = split_heads(values, self.depth_per_head, self.heads)
lengths = broadcast_to_heads(lengths, self.heads, ndim=1, fold_heads=True) if lengths is not None else lengths

# (batch*heads, query_max_length, depth_per_head), (batch*heads, query_max_length, source_length)
contexts, attention_probs = dot_attention_with_probs(queries, keys, values,
lengths=lengths, dropout=self.dropout, bias=bias,
prefix=self.prefix)

# (batch, query_max_length, depth)
contexts = combine_heads(contexts, self.depth_per_head, self.heads)
        # (batch, heads, query_max_length, source_length)
attention_probs = mx.sym.reshape(data=attention_probs, shape=(-4, -1, self.heads, 0, 0))
        # the attention probabilities are only consumed during single-step decoding (query_max_length == 1), so the query axis can be folded away
# (batch, heads, source_length)
attention_probs = mx.sym.reshape(data=attention_probs, shape=(0, -3, 0))

# contexts: (batch, query_max_length, output_depth)
contexts = mx.sym.FullyConnected(data=contexts,
weight=self.w_h2o,
no_bias=True,
num_hidden=self.depth_out,
flatten=False)

return contexts, attention_probs
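
The two reshape calls above use MXNet's special reshape codes: 0 copies an input dimension, -1 infers one, -3 merges two consecutive input dimensions, and -4 splits one input dimension into the two sizes that follow. A NumPy sketch (not part of the diff, with hypothetical sizes) of the same two steps:

import numpy as np

batch, heads, query_len, source_len = 2, 8, 1, 5
# shape coming out of dot_attention_with_probs after split_heads: (batch*heads, query_len, source_len)
probs = np.random.rand(batch * heads, query_len, source_len)
# shape=(-4, -1, heads, 0, 0): split axis 0 into (batch, heads), copy the remaining axes
probs = probs.reshape(batch, heads, query_len, source_len)
# shape=(0, -3, 0): merge axes 1 and 2; with query_len == 1 this yields (batch, heads, source_len)
probs = probs.reshape(batch, heads * query_len, source_len)
print(probs.shape)  # (2, 8, 5)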


class ProjectedDotAttention:
"""
Dot attention layer for queries independent from keys/values.
30 changes: 16 additions & 14 deletions sockeye/transformer.py
@@ -95,8 +95,9 @@ def __init__(self,
dropout=config.dropout_prepost,
prefix="%sff_post_" % prefix)
self.lhuc = None
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
if hasattr(config, 'use_lhuc'):
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
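
The hasattr guard above (and the identical one added in TransformerDecoderBlock below) presumably keeps the block constructible from older serialized TransformerConfig objects that predate the use_lhuc attribute. A more compact equivalent, shown here only as a sketch and not as part of the diff:

self.lhuc = layers.LHUC(config.model_size, prefix=prefix) if getattr(config, 'use_lhuc', False) else None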

def __call__(self, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol:
# self-attention
@@ -140,11 +141,11 @@ def __init__(self,
self.pre_enc_attention = TransformerProcessBlock(sequence=config.preprocess_sequence,
dropout=config.dropout_prepost,
prefix="%satt_enc_pre_" % prefix)
self.enc_attention = layers.MultiHeadAttention(depth_att=config.model_size,
heads=config.attention_heads,
depth_out=config.model_size,
dropout=config.dropout_attention,
prefix="%satt_enc_" % prefix)
self.enc_attention = layers.MultiHeadAttentionWithProbs(depth_att=config.model_size,
heads=config.attention_heads,
depth_out=config.model_size,
dropout=config.dropout_attention,
prefix="%satt_enc_" % prefix)
self.post_enc_attention = TransformerProcessBlock(sequence=config.postprocess_sequence,
dropout=config.dropout_prepost,
prefix="%satt_enc_post_" % prefix)
@@ -162,8 +163,9 @@
prefix="%sff_post_" % prefix)

self.lhuc = None
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size, prefix=prefix)
if hasattr(config, 'use_lhuc'):
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size, prefix=prefix)

def __call__(self,
target: mx.sym.Symbol,
@@ -178,9 +180,9 @@ def __call__(self,
target = self.post_self_attention(target_self_att, target)

# encoder attention
target_enc_att = self.enc_attention(queries=self.pre_enc_attention(target, None),
memory=source,
bias=source_bias)
target_enc_att, attention_probs = self.enc_attention(queries=self.pre_enc_attention(target, None),
memory=source,
bias=source_bias)
target = self.post_enc_attention(target_enc_att, target)

# feed-forward
@@ -190,7 +192,7 @@ def __call__(self,
if self.lhuc:
target = self.lhuc(target)

return target
return target, attention_probs


class TransformerProcessBlock:
Expand Down Expand Up @@ -382,7 +384,7 @@ class AutoRegressiveBias(mx.operator.CustomOp):
0 0 0 0
"""

def __init__(self, length: int, dtype:str, ctx: mx.Context) -> None:
def __init__(self, length: int, dtype: str, ctx: mx.Context) -> None:
super().__init__()
self.bias = self.get_bias(length, dtype, ctx)
