Script MultiheadAttention (#1002)
Summary:
Pull Request resolved: fairinternal/fairseq-py#1002

Pull Request resolved: pytorch/translate#681

Pull Request resolved: #1524

Make fairseq MultiheadAttention scriptable. Looking for feedback.

1. Add types
2. Move incremental state management logic from util functions into module initializers. TorchScript generally doesn't support global dicts, so each module that contains multi-head attention assigns itself a fairseq_instance_id in its initializer (see the sketch after this list).
3. There might be opportunities to make assertions and annotations cleaner.
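
For illustration, a minimal sketch of how a module could carry its incremental state under this scheme. Only `with_incremental_state`, `module_name`, and `_fairseq_instance_id` come from the diff below; the `MyAttention` module, its layers, and the `_state_key` helper are hypothetical, and the sketch is not guaranteed to compile under TorchScript as written.

```python
from typing import Dict, Optional

import torch
from torch import Tensor, nn

from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state  # sets module_name and _fairseq_instance_id at init time
class MyAttention(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, embed_dim: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, embed_dim)

    def _state_key(self, key: str) -> str:
        # Per-instance key built from attributes assigned in the initializer,
        # so forward() never has to consult a global dict.
        return "{}.{}.{}".format(self.module_name, self._fairseq_instance_id, key)

    def forward(
        self,
        x: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ) -> Tensor:
        # The nested-dict annotation above is the kind of explicit typing
        # TorchScript needs ("Add types" in item 1).
        if incremental_state is not None:
            saved = incremental_state.get(self._state_key("attn_state"), {})
            prev_key = saved.get("prev_key")
            if prev_key is not None:
                x = torch.cat([prev_key, x], dim=0)
            incremental_state[self._state_key("attn_state")] = {"prev_key": x}
        return self.proj(x)
```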

Reviewed By: myleott

Differential Revision: D18772594

fbshipit-source-id: 377aef4bbb7ef51da5b6bac9a87a6f7b03b16fe1
cndn authored and facebook-github-bot committed Jan 22, 2020
1 parent 2535cab commit 4e48c4a
Showing 14 changed files with 178 additions and 81 deletions.
29 changes: 29 additions & 0 deletions fairseq/incremental_decoding_utils.py
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import utils


class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
init_incremental_state(self)


def with_incremental_state(cls):
cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
return cls


# In most cases we should register incremental states using @with_incremental_state decorator
# instead of calling into this explicitly in initializer.
def init_incremental_state(obj):
obj.module_name = obj.__class__.__name__
utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = (
utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1
)
obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[
obj.module_name
]
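
For context, a quick sketch of what this per-class counter does at construction time; the `Toy` module is hypothetical and exists only to show that `module_name` is the class name and that instance IDs increment per class:

```python
from torch import nn

from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class Toy(nn.Module):  # hypothetical module, for illustration only
    pass


a, b = Toy(), Toy()
print(a.module_name, a._fairseq_instance_id)  # e.g. "Toy" 1
print(b.module_name, b._fairseq_instance_id)  # e.g. "Toy" 2
```
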
10 changes: 7 additions & 3 deletions fairseq/iterative_refinement_generator.py
@@ -223,7 +223,7 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out.attn is None else decoder_out.attn[terminated]
None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated]
)

if self.retain_history:
@@ -259,8 +259,12 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
history=[h[not_terminated] for h in decoder_out.history] if decoder_out.history is not None else None
attn=decoder_out.attn[not_terminated]
if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
else None,
history=[h[not_terminated] for h in decoder_out.history]
if decoder_out.history is not None
else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]
11 changes: 6 additions & 5 deletions fairseq/models/fairseq_decoder.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq import utils


@@ -29,7 +28,9 @@ def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
x = self.output_layer(x)
return x, extra

@@ -54,10 +55,10 @@ def output_layer(self, features, **kwargs):
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""

if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
assert 'target' in sample
target = sample['target']
assert "target" in sample
target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
2 changes: 2 additions & 0 deletions fairseq/models/fairseq_incremental_decoder.py
@@ -4,8 +4,10 @@
# LICENSE file in the root directory of this source tree.

from fairseq.models import FairseqDecoder
from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders.
3 changes: 2 additions & 1 deletion fairseq/models/fconv_self_att.py
@@ -27,7 +27,7 @@
LearnedPositionalEmbedding,
LinearizedConvolution,
)

from fairseq.incremental_decoding_utils import with_incremental_state

logger = logging.getLogger(__name__)

@@ -291,6 +291,7 @@ def max_positions(self):
return self.embed_positions.max_positions()


@with_incremental_state
class FConvDecoder(FairseqDecoder):
"""Convolutional decoder"""
def __init__(
2 changes: 2 additions & 0 deletions fairseq/modules/dynamic_convolution.py
@@ -9,6 +9,7 @@

from fairseq import utils
from .unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state


def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
@@ -38,6 +39,7 @@ def Linear(in_features, out_features, bias=True):
return m


@with_incremental_state
class DynamicConv1dTBC(nn.Module):
'''Dynamic lightweight convolution taking T x B x C inputs
Args:
3 changes: 2 additions & 1 deletion fairseq/modules/lightweight_convolution.py
@@ -9,6 +9,7 @@

from fairseq import utils
from fairseq.modules.unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state


def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
@@ -99,6 +100,7 @@ def forward(self, input):
return output


@with_incremental_state
class LightweightConv1dTBC(nn.Module):
'''Lightweight Convolution assuming the input is TxBxC
Args:
@@ -136,7 +138,6 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
self.bias = None

self.reset_parameters()

self.onnx_trace = False

def reset_parameters(self):
3 changes: 2 additions & 1 deletion fairseq/modules/linearized_convolution.py
@@ -7,10 +7,11 @@
import torch.nn.functional as F

from fairseq import utils

from .conv_tbc import ConvTBC
from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.