Lifted the T >= S requirement for the regular case
durson committed Aug 23, 2023
1 parent 82983b5 commit 5c47e2d
Showing 2 changed files with 154 additions and 28 deletions.
86 changes: 58 additions & 28 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -22,6 +22,23 @@
 from .mutual_information import mutual_information_recursion
 
 
+def validate_st_lengths(
+    S: int, T: int, is_rnnt_type_regular: bool, boundary: Optional[Tensor] = None
+):
+    if boundary is None:
+        assert S >= 1, S
+        assert (
+            is_rnnt_type_regular or T >= S
+        ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
+    else:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        assert (Ss >= 1).all(), Ss
+        assert (
+            is_rnnt_type_regular or (Ts >= Ss).all()
+        ), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"
+
+
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
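For orientation: the new helper centralizes the S/T length checks that were previously inlined as `assert S >= 1` / `assert T >= S` at each call site; the `is_rnnt_type_regular or ...` short-circuit is what lifts the T >= S requirement for the regular transducer while keeping it for modified/constrained. A minimal sketch of its behavior, assuming the helper is importable from fast_rnnt.rnnt_loss and using the boundary layout [begin_symbol, begin_frame, end_symbol, end_frame] documented elsewhere in this file (the values below are hypothetical):

    import torch
    from fast_rnnt.rnnt_loss import validate_st_lengths

    # Hypothetical boundary: [begin_symbol, begin_frame, end_symbol, end_frame]
    boundary = torch.tensor([[0, 0, 10, 2]])  # S=10 symbols, T=2 frames

    # Regular transducer: T >= S is no longer required, so this passes.
    validate_st_lengths(10, 2, True, boundary)

    # Modified/constrained transducer: T >= S is still required, so this
    # raises an AssertionError (T=2 < S=10).
    validate_st_lengths(10, 2, False, boundary)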
@@ -145,8 +162,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # subtracting am_max and lm_max is to ensure the probs are in a good range
@@ -389,8 +406,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -523,7 +540,7 @@ def rnnt_loss(
     return (loss, scores_and_grads[1]) if return_grad else loss
 
 
-def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+def _monotonic_lower_bound(x: Tensor) -> Tensor:
     """Compute a monotonically increasing lower bound of the tensor `x` on the
     last dimension. The basic idea is: we traverse the tensor in reverse order,
     and update current element with the following statement,
@@ -558,8 +575,8 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
 
 
 def _adjust_pruning_lower_bound(
-    s_begin: torch.Tensor, s_range: int
-) -> torch.Tensor:
+    s_begin: Tensor, s_range: int
+) -> Tensor:
     """Adjust s_begin (pruning lower bounds) to make it satisfy the following
     constraints
@@ -614,11 +631,11 @@ def _adjust_pruning_lower_bound(
 # chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
 # (https://arxiv.org/pdf/2206.13236.pdf)
 def get_rnnt_prune_ranges(
-    px_grad: torch.Tensor,
-    py_grad: torch.Tensor,
-    boundary: torch.Tensor,
+    px_grad: Tensor,
+    py_grad: Tensor,
+    boundary: Tensor,
     s_range: int,
-) -> torch.Tensor:
+) -> Tensor:
     """Get the pruning ranges of normal rnnt loss according to the grads
     of px and py returned by mutual_information_recursion.
@@ -662,28 +679,41 @@ def get_rnnt_prune_ranges(
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
+
+    is_regular = T1 != T
+
     assert T1 in [T, T + 1], T1
     S1 = S + 1
     assert py_grad.shape == (B, S + 1, T), py_grad.shape
     assert boundary.shape == (B, 4), boundary.shape
 
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, is_regular, boundary)
+
+    # adjust s_range if S >> T in regular case
+    if is_regular:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        s_range_min = (Ss - 2).div(Ts, rounding_mode="trunc").add(3).max().item()
+        if s_range < s_range_min:
+            print(
+                f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+                f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
+            )
+            s_range = s_range_min
 
     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1
 
-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
-
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
+    else:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
 
     (B_stride, S_stride, T_stride) = py_grad.stride()
     blk_grad = torch.as_strided(
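The adjustment above enforces a minimum pruning window when S is much larger than T in the regular case: a path that must emit all S symbols in T frames needs roughly S/T symbols per frame to survive pruning. A worked sketch of the bound with hypothetical lengths (not from the commit):

    import torch

    # Hypothetical per-sequence lengths (boundary[:, 2] and boundary[:, 3]).
    Ss = torch.tensor([10, 5])  # symbols per sequence
    Ts = torch.tensor([2, 4])   # frames per sequence

    # Same arithmetic as the commit: floor((S - 2) / T) + 3, maxed over the batch.
    s_range_min = (Ss - 2).div(Ts, rounding_mode="trunc").add(3).max().item()
    print(s_range_min)  # (10 - 2) // 2 + 3 = 7 dominates (5, 4), which gives 3

So for S=10, T=2 any requested s_range below 7 is bumped up to 7 with the printed warning.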
@@ -740,8 +770,8 @@ def get_rnnt_prune_ranges(
 
 
 def do_rnnt_pruning(
-    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    am: Tensor, lm: Tensor, ranges: Tensor
+) -> Tuple[Tensor, Tensor]:
     """Prune the output of encoder(am) and prediction network(lm) with ranges
     generated by `get_rnnt_prune_ranges`.
@@ -782,7 +812,7 @@ def do_rnnt_pruning(
     return am_pruning, lm_pruning
 
 
-def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
+def _roll_by_shifts(src: Tensor, shifts: torch.LongTensor):
     """Roll tensor with different shifts for each row.
 
     Note:
@@ -822,7 +852,7 @@ def get_rnnt_logprobs_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
@@ -893,8 +923,8 @@ def get_rnnt_logprobs_pruned(
     (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range), f"{ranges.shape} == ({B}, {T}, {s_range})"
     (B, S) = symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -989,7 +1019,7 @@ def rnnt_loss_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor = None,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
@@ -1208,8 +1238,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # Caution: some parts of this code are a little less clear than they could
96 changes: 96 additions & 0 deletions fast_rnnt/python/tests/rnnt_loss_test.py
File mode changed: 100644 → 100755
@@ -609,6 +609,102 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
                     )
                     print(f"Pruned loss with range {r} : {pruned_loss}")

+    # Test low s_range values with large S and small T; in this situation
+    # s_range is not enough to cover the whole sequence length (in regular
+    # rnnt mode) and would result in inf loss
+    def test_rnnt_loss_pruned_small_s_range(self):
+        B = 2
+        T = 2
+        S = 10
+        C = 10
+
+        frames = torch.randint(1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for rnnt_type in ["regular"]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    reduction="none",
+                )
+
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+
+                for r in range(S0, S + 2):
+                    ranges = fast_rnnt.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        rnnt_type=rnnt_type,
+                        reduction="none",
+                    )
+                    assert (
+                        not pruned_loss.isinf().any()
+                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}"
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
 
 
 if __name__ == "__main__":
     unittest.main()
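The test passes when no pruned loss is inf for any s_range down to 2; with the new adjustment, too-small requests trigger the warning from get_rnnt_prune_ranges instead of an inf loss. Assuming fast_rnnt is importable in the current environment, the file runs directly, and unittest's -k flag (Python >= 3.7) selects just the new test by substring:

    python3 fast_rnnt/python/tests/rnnt_loss_test.py -k small_s_range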
