Skip to content

Commit

Permalink
change window axis to spatial_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 17, 2022
1 parent 59efe66 commit 8a99df4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
17 changes: 15 additions & 2 deletions nn/array_.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def split(source: nn.Tensor, *,

def window(
source: nn.Tensor, *,
axis: nn.Dim,
spatial_dim: nn.Dim,
window_dim: nn.Dim,
window_left: Optional[int] = NotSpecified,
window_right: Optional[int] = NotSpecified,
Expand All @@ -185,12 +185,25 @@ def window(
layer, (window_dim, out_spatial_dim), state = rec_window(
source,
window_dim=window_dim, window_left=window_left, window_right=window_right,
axis=axis, padding=padding, stride=stride,
axis=spatial_dim, padding=padding, stride=stride,
name=name)
del state
return layer, out_spatial_dim


def window_direct(
source: nn.Tensor, *,
axis: nn.Dim,
window_dim: nn.Dim,
padding: str = "same",
stride: int = 1,
) -> Tuple[nn.Tensor, nn.Dim]:
"""
:param source:
:param axis:
"""


def window_step(
source: nn.Tensor, *, state: nn.LayerState,
window_dim: nn.Dim,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nn_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_window():
x = nn.Data("data", dim_tags=[time_dim, nn.batch_dim, in_dim], available_for_inference=True)
x = nn.get_extern_data(x)
win_dim = nn.SpatialDim("window", 3)
out, _ = nn.window(x, axis=time_dim, window_dim=win_dim)
out, _ = nn.window(x, spatial_dim=time_dim, window_dim=win_dim)
out.mark_as_default_output()
config_str = nn.get_returnn_config().get_config_raw_dict(nn.Module())
res = dummy_run_net_single_custom(config_str, default_out_dim_tag_order=[time_dim, win_dim, nn.batch_dim, in_dim])
Expand All @@ -101,7 +101,7 @@ def test_window_stride():
x = nn.Data("data", dim_tags=[time_dim, nn.batch_dim, in_dim], available_for_inference=True)
x = nn.get_extern_data(x)
win_dim = nn.SpatialDim("window", 5)
out, time_dim_ = nn.window(x, axis=time_dim, window_dim=win_dim, stride=3)
out, time_dim_ = nn.window(x, spatial_dim=time_dim, window_dim=win_dim, stride=3)
out.mark_as_default_output()
config_str = nn.get_returnn_config().get_config_raw_dict(nn.Module())
res = dummy_run_net_single_custom(config_str, default_out_dim_tag_order=[time_dim_, win_dim, nn.batch_dim, in_dim])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nn_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_chunking_conformer():
num_heads=2, num_layers=2,
)
window_dim = nn.SpatialDim("window", 50)
data, time_dim_ = nn.window(data, axis=time_dim, window_dim=window_dim, stride=25)
data, time_dim_ = nn.window(data, spatial_dim=time_dim, window_dim=window_dim, stride=25)
out, _ = conformer(data, in_spatial_dim=window_dim)
out.verify_out_shape({nn.batch_dim, time_dim_, window_dim, conformer.out_dim})

Expand Down

0 comments on commit 8a99df4

Please sign in to comment.