Skip to content

Commit

Permalink
[audio]fix audio get_window security error (#47386)
Browse files Browse the repository at this point in the history
* fix window security error

* format
  • Loading branch information
SmileGoat authored Oct 28, 2022
1 parent 0f649b3 commit 26c419c
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions python/paddle/audio/functional/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,39 @@
from paddle import Tensor


class WindowFunctionRegister(object):
def __init__(self):
self._functions_dict = dict()

def register(self, func=None):
def add_subfunction(func):
name = func.__name__
self._functions_dict[name] = func
return func

return add_subfunction

def get(self, name):
return self._functions_dict[name]


window_function_register = WindowFunctionRegister()


@window_function_register.register()
def _cat(x: List[Tensor], data_type: str) -> Tensor:
l = [paddle.to_tensor(_, data_type) for _ in x]
return paddle.concat(l)


@window_function_register.register()
def _acosh(x: Union[Tensor, float]) -> Tensor:
if isinstance(x, float):
return math.log(x + math.sqrt(x**2 - 1))
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))


@window_function_register.register()
def _extend(M: int, sym: bool) -> bool:
"""Extend window by 1 sample if needed for DFT-even symmetry."""
if not sym:
Expand All @@ -38,6 +60,7 @@ def _extend(M: int, sym: bool) -> bool:
return M, False


@window_function_register.register()
def _len_guards(M: int) -> bool:
"""Handle small or incorrect window lengths."""
if int(M) != M or M < 0:
Expand All @@ -46,6 +69,7 @@ def _len_guards(M: int) -> bool:
return M <= 1


@window_function_register.register()
def _truncate(w: Tensor, needed: bool) -> Tensor:
"""Truncate window by 1 sample if needed for DFT-even symmetry."""
if needed:
Expand All @@ -54,6 +78,7 @@ def _truncate(w: Tensor, needed: bool) -> Tensor:
return w


@window_function_register.register()
def _general_gaussian(
M: int, p, sig, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -70,6 +95,7 @@ def _general_gaussian(
return _truncate(w, needs_trunc)


@window_function_register.register()
def _general_cosine(
M: int, a: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -86,6 +112,7 @@ def _general_cosine(
return _truncate(w, needs_trunc)


@window_function_register.register()
def _general_hamming(
M: int, alpha: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -95,6 +122,7 @@ def _general_hamming(
return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)


@window_function_register.register()
def _taylor(
M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand Down Expand Up @@ -151,6 +179,7 @@ def W(n):
return _truncate(w, needs_trunc)


@window_function_register.register()
def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with
Expand All @@ -159,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.54, sym, dtype=dtype)


@window_function_register.register()
def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hann window.
The Hann window is a taper formed by using a raised cosine or sine-squared
Expand All @@ -167,6 +197,7 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.5, sym, dtype=dtype)


@window_function_register.register()
def _tukey(
M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand Down Expand Up @@ -200,6 +231,7 @@ def _tukey(
return _truncate(w, needs_trunc)


@window_function_register.register()
def _kaiser(
M: int, beta: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -209,6 +241,7 @@ def _kaiser(
raise NotImplementedError()


@window_function_register.register()
def _gaussian(
M: int, std: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -226,6 +259,7 @@ def _gaussian(
return _truncate(w, needs_trunc)


@window_function_register.register()
def _exponential(
M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
Expand All @@ -245,6 +279,7 @@ def _exponential(
return _truncate(w, needs_trunc)


@window_function_register.register()
def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a triangular window."""
if _len_guards(M):
Expand All @@ -262,6 +297,7 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc)


@window_function_register.register()
def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Bohman window.
The Bohman window is the autocorrelation of a cosine window.
Expand All @@ -279,6 +315,7 @@ def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc)


@window_function_register.register()
def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Blackman window.
The Blackman window is a taper formed by using the first three terms of
Expand All @@ -289,6 +326,7 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)


@window_function_register.register()
def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a window with a simple cosine shape."""
if _len_guards(M):
Expand All @@ -308,7 +346,7 @@ def get_window(
"""Return a window of a given length and type.
Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'.
Expand Down Expand Up @@ -348,8 +386,8 @@ def get_window(
)

try:
winfunc = eval('_' + winstr)
except NameError as e:
winfunc = window_function_register.get('_' + winstr)
except KeyError as e:
raise ValueError("Unknown window type.") from e

params = (win_length,) + args
Expand Down

0 comments on commit 26c419c

Please sign in to comment.