fix gradient checkpoint
Mddct committed Apr 17, 2024
1 parent a88f4db commit b6cecfe
Showing 2 changed files with 104 additions and 43 deletions.
70 changes: 61 additions & 9 deletions wenet/branchformer/encoder.py
@@ -16,6 +16,7 @@
"""Encoder definition."""

import torch

from typing import List, Optional, Union

from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
@@ -117,12 +118,63 @@ def __init__(
f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
f"should be equal to num_blocks ({num_blocks})")

self.encoders = torch.nn.ModuleList([
BranchformerEncoderLayer(
output_size, WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate, merge_method, cgmlp_weight[lnum],
attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum])
for lnum in range(num_blocks)
])
self.encoders = LayerDropModuleList(
p=stochastic_depth_rate,
modules=[
BranchformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate, merge_method, cgmlp_weight[lnum],
attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum],
gradient_checkpointing) for lnum in range(num_blocks)
])

@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
return self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
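forward_layers itself is not visible in this hunk; forward_layers_checkpointed simply delegates to it because the checkpoint decision is now passed into each BranchformerEncoderLayer via the gradient_checkpointing argument above. For context, a minimal sketch of the loop forward_layers is assumed to run (an illustration based on the call above, not the exact wenet implementation):

import torch

def forward_layers_sketch(encoder, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
    # Each BranchformerEncoderLayer returns (x, mask, att_cache, cnn_cache);
    # only the features and the chunk mask are threaded through the loop.
    # Because encoder.encoders is a LayerDropModuleList, iterating over it
    # may skip some layers at training time.
    for layer in encoder.encoders:
        xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    return xs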


# modified from: https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa
class LayerDropModuleList(torch.nn.ModuleList):
"""
A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
We refresh the choice of which layers to drop every time we iterate
over the LayerDropModuleList instance. During evaluation we always
iterate over all layers.
Usage::
layers = LayerDropModuleList(p=[0.5, 0.5, 0.5], modules=[layer1, layer2, layer3])
for layer in layers: # this might iterate over layers 1 and 3
x = layer(x)
for layer in layers: # this might iterate over all layers
x = layer(x)
for layer in layers: # this might not iterate over any layers
x = layer(x)
Args:
p (List[float]): per-layer probability of dropping out that layer
modules (iterable, optional): an iterable of modules to add
Limitations:
1. works with DDP when the layers' gradient checkpointing is disabled
2. does not work with DDP when the layers' gradient checkpointing is enabled
3. works with FSDP
"""

def __init__(self, p: List[float], modules=None):
super().__init__(modules)
assert len(p) == len(self)
self.p = p

def __iter__(self):
dropout_probs = torch.empty(len(self)).uniform_()
for i, m in enumerate(super().__iter__()):
if not self.training or (dropout_probs[i] > self.p[i]):
yield m
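A small, self-contained usage sketch of this class (toy Linear layers and hypothetical drop probabilities, not part of the commit) showing the train/eval behaviour described in the docstring:

import torch

# Three toy layers with per-layer drop probabilities; len(p) must equal
# the number of modules.
layers = LayerDropModuleList(p=[0.0, 0.1, 0.2],
                             modules=[torch.nn.Linear(8, 8) for _ in range(3)])
x = torch.randn(2, 8)

layers.train()          # training: the second and third layer may be skipped
for layer in layers:
    x = layer(x)

layers.eval()           # evaluation: every layer always runs
for layer in layers:
    x = layer(x)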
77 changes: 43 additions & 34 deletions wenet/branchformer/encoder_layer.py
@@ -46,6 +46,7 @@ def __init__(
cgmlp_weight: float = 0.5,
attn_branch_drop_rate: float = 0.0,
stochastic_depth_rate: float = 0.0,
gradient_checkpointing: bool = False,
):
super().__init__()
assert (attn is not None) or (
@@ -105,49 +106,18 @@ def __init__(
raise ValueError(f"unknown merge method: {merge_method}")
else:
self.merge_proj = torch.nn.Identity()
self.gradient_checkpointing = gradient_checkpointing

def forward(
def _forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
stoch_layer_coeff: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
pos_emb (torch.Tensor): positional encoding, must not be None
for BranchformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1, time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""

stoch_layer_coeff = 1.0
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

if skip_layer:
return x, mask, att_cache, cnn_cache

# Two branches
x1 = x
x2 = x
@@ -232,3 +202,42 @@ def forward(
x = self.norm_final(x)

return x, mask, new_att_cache, new_cnn_cache

def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
pos_emb (torch.Tensor): positional encoding, must not be None
for BranchformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1, time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""

stoch_layer_coeff = 1.0
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
if self.training:
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
stoch_layer_coeff)
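The hunk above stores self.gradient_checkpointing in __init__ but does not show where the layer consumes it. One common pattern (an assumption for illustration, not necessarily this commit's exact wiring) is to wrap _forward in torch.utils.checkpoint during training:

import torch.utils.checkpoint as ckpt

def run_branchformer_layer(layer, x, mask, pos_emb, mask_pad,
                           att_cache, cnn_cache, stoch_layer_coeff=1.0):
    # Hypothetical helper: when the layer was built with
    # gradient_checkpointing=True and is in training mode, recompute
    # activations of _forward during the backward pass instead of storing them.
    if layer.gradient_checkpointing and layer.training:
        return ckpt.checkpoint(layer._forward, x, mask, pos_emb, mask_pad,
                               att_cache, cnn_cache, stoch_layer_coeff,
                               use_reentrant=False)
    return layer._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
                          stoch_layer_coeff)

Note that the LayerDropModuleList docstring above flags the combination of DDP and per-layer gradient checkpointing as unsupported, while FSDP is listed as working.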
