Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Adding NeuralStack and NeuralQueue models under /research
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 247806737
  • Loading branch information
Shawn Simister authored and copybara-github committed May 13, 2019
1 parent dad03c9 commit 838aca4
Show file tree
Hide file tree
Showing 4 changed files with 805 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tensor2tensor/layers/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,21 @@ def length_from_embedding(emb):
return tf.cast(tf.reduce_sum(mask_from_embedding(emb), [1, 2, 3]), tf.int32)


def mask_pos_gt(source_length, target_length):
"""A mask with 1.0 wherever source_pos > target_pos and 0.0 elsewhere.
Args:
source_length: an integer
target_length: an integer
Returns:
a Tensor with shape [1, target_length, source_length]
"""
return tf.expand_dims(
tf.cast(tf.greater(tf.expand_dims(tf.range(target_length), axis=0),
tf.expand_dims(tf.range(source_length), axis=1)),
dtype=tf.float32), axis=0)


def mask_leq(target_length, source_length):
"""A mask with 1.0 wherever source_pos <= target_pos and 0.0 elsewhere.
Expand All @@ -1268,6 +1283,21 @@ def mask_leq(target_length, source_length):
out_shape=[1, target_length, source_length])


def mask_pos_lt(source_length, target_length):
"""A mask with 1.0 wherever source_pos < target_pos and 0.0 elsewhere.
Args:
source_length: an integer
target_length: an integer
Returns:
a Tensor with shape [1, target_length, source_length]
"""
return tf.expand_dims(
tf.cast(tf.less(tf.expand_dims(tf.range(target_length), axis=0),
tf.expand_dims(tf.range(source_length), axis=1)),
dtype=tf.float32), axis=0)


def relu_density_logit(x, reduce_dims):
"""logit(density(x)).
Expand Down
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from tensor2tensor.models.research import lm_experiments
from tensor2tensor.models.research import moe_experiments
from tensor2tensor.models.research import multiquery_paper
from tensor2tensor.models.research import neural_stack
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import similarity_transformer
from tensor2tensor.models.research import super_lm
Expand Down
Loading

0 comments on commit 838aca4

Please sign in to comment.