Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Nov 19, 2024
1 parent 4b400e4 commit 5de9115
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 43 deletions.
159 changes: 131 additions & 28 deletions src/layer/inversespectrogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,51 @@ InverseSpectrogram::InverseSpectrogram()
int InverseSpectrogram::load_param(const ParamDict& pd)
{
n_fft = pd.get(0, 0);
power = pd.get(1, 0);
returns = pd.get(1, 0);
hoplen = pd.get(2, n_fft / 4);
winlen = pd.get(3, n_fft);
window_type = pd.get(4, 0);
center = pd.get(5, 1);
normalized = pd.get(7, 0);

// assert winlen <= n_fft
// generate window
window_data.create(n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}
if (window_type == 0)
{
// all ones
for (int i = 0; i < winlen; i++)
{
*p++ = 1.f;
}
}
if (window_type == 1)
{
// hann window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.5f * (1 - cos(2 * M_PI * i / winlen));
}
}
if (window_type == 2)
{
// hamming window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.54f - 0.46f * cos(2 * M_PI * i / winlen);
}
}
for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}
}

return 0;
}
Expand All @@ -37,71 +79,132 @@ int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Opt
// https://github.com/librosa/librosa/blob/main/librosa/core/spectrum.py#L630

// TODO custom window
// TODO padding for center=True
// TODO onesided=True

const int frames = bottom_blob.h;
// const int freqs = bottom_blob.c;
const int freqs = bottom_blob.c;
// assert freqs == n_fft or freqs == n_fft / 2 + 1

const int outsize = (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2); // center=1
// const int outsize = (frames - 1) * hoplen + n_fft; // center=0
const int onesided = freqs == n_fft / 2 + 1 ? 1 : 0;

const int outsize = center ? (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2) : (frames - 1) * hoplen + n_fft;

const size_t elemsize = bottom_blob.elemsize;

top_blob.create(outsize, elemsize, opt.blob_allocator);
if (returns == 0)
{
top_blob.create(2, outsize, elemsize, opt.blob_allocator);
}
else
{
top_blob.create(outsize, elemsize, opt.blob_allocator);
}
if (top_blob.empty())
return -100;

Mat top_blob_padded(outsize + n_fft, elemsize, opt.workspace_allocator);
Mat window_sumsquare(outsize + n_fft, elemsize, opt.workspace_allocator);
for (int i = 0; i < outsize; i++)
{
top_blob_padded[i] = 0.f;
window_sumsquare[i] = 0.f;
}
if (window_sumsquare.empty())
return -100;

top_blob.fill(0.f);
window_sumsquare.fill(0.f);

for (int j = 0; j < frames; j++)
{
// collect complex
Mat sp(2, n_fft);
for (int k = 0; k < n_fft; k++)
if (onesided == 1)
{
for (int k = 0; k < n_fft / 2 + 1; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
for (int k = n_fft / 2 + 1; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(n_fft - k).row(j)[0];
sp.row(k)[1] = -bottom_blob.channel(n_fft - k).row(j)[1];
}
}
else
{
for (int k = 0; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
}

if (normalized)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= sqrt(winlen);
}
}

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < n_fft; i++)
{
// inverse dft
float re = 0.f;
// float im = 0.f;
float im = 0.f;
for (int k = 0; k < n_fft; k++)
{
double angle = 2 * M_PI * i * k / n_fft;

re += sp.row(k)[0] * cos(angle) - sp.row(k)[1] * sin(angle);
// im += sp.row(k)[0] * sin(angle) + sp.row(k)[1] * cos(angle);
im += sp.row(k)[0] * sin(angle) + sp.row(k)[1] * cos(angle);
}

re /= n_fft;
// im /= n_fft;
im /= n_fft;

// apply hann window
re *= 0.5f * (1 - cos(2 * M_PI * i / n_fft));
// apply window
re *= window_data[i];
im *= window_data[i];

// square hann window
window_sumsquare[j * hoplen + i] += (0.5f * (1 - cos(2 * M_PI * i / n_fft))) * (0.5f * (1 - cos(2 * M_PI * i / n_fft)));

top_blob_padded[j * hoplen + i] += re;
int output_index = j * hoplen + i;
if (center == 1)
{
output_index -= n_fft / 2;
}
if (output_index >= 0 && output_index < outsize)
{
// square window
window_sumsquare[output_index] += window_data[i] * window_data[i];

if (returns == 0)
{
top_blob.row(output_index)[0] += re;
top_blob.row(output_index)[1] += im;
}
if (returns == 1)
{
top_blob[output_index] += re;
}
if (returns == 2)
{
top_blob[output_index] += im;
}
}
}
}

// cut padding
for (int i = 0; i < outsize; i++)
// square window norm
if (returns == 0)
{
top_blob[i] = top_blob_padded[n_fft / 2 + i] / window_sumsquare[n_fft / 2 + i];
for (int i = 0; i < outsize; i++)
{
top_blob.row(i)[0] /= window_sumsquare[i];
top_blob.row(i)[1] /= window_sumsquare[i];
}
}
else
{
for (int i = 0; i < outsize; i++)
{
top_blob[i] /= window_sumsquare[i];
}
}

return 0;
Expand Down
7 changes: 6 additions & 1 deletion src/layer/inversespectrogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ class InverseSpectrogram : public Layer

public:
int n_fft;
int power;
int returns; // 0=complex 1=real 2=imag
int hoplen;
int winlen;
int window_type; // 0=ones 1=hann 2=hamming
int center;
int normalized;

Mat window_data;
};

} // namespace ncnn
Expand Down
2 changes: 1 addition & 1 deletion src/layer/spectrogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op
{
float v = ptr[k];

// apply hann window
// apply window
v *= window_data[k];

// dft
Expand Down
126 changes: 121 additions & 5 deletions tools/pnnx/src/pass_ncnn/torch_istft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
// op->params["0"] = 1;
op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 1; // returns
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = 0; // all ones
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft, 20)

class torch_istft_1 : public GraphRewriterPass
class torch_istft_1 : public torch_istft
{
public:
const char* match_pattern_graph() const
Expand All @@ -65,6 +71,73 @@ pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
torch_istft::write(op, captured_params);

op->params["1"] = 0; // returns
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_1, 20)

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

static int detect_window_type(const std::vector<float>& window_data)
{
const int winlen = (int)window_data.size();

bool is_one = true;
bool is_hann = true;
bool is_hamming = true;
for (int i = 0; i < winlen; i++)
{
if (!NearlyEqual(window_data[i], 1.f, 0.001))
is_one = false;

if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001))
is_hann = false;

if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001))
is_hamming = false;
}

if (is_one)
return 0;
if (is_hann)
return 1;
if (is_hamming)
return 2;

return -1;
}

class torch_istft_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
pnnx.Attribute op_1 0 1 window @data
torch.istft op_2 2 1 a window out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InverseSpectrogram";
Expand All @@ -75,13 +148,56 @@ pnnx.Output output 1 0 out
return "istft";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data();
const int window_type = detect_window_type(window_data);
fprintf(stderr, "window_type = %d\n", window_type);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
// op->params["0"] = 1;
const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data();
const int window_type = detect_window_type(window_data);

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 1; // returns
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_1, 20)
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_2, 20)

class torch_istft_3 : public torch_istft_2
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
pnnx.Attribute op_1 0 1 window @data
torch.istft op_2 2 1 a window b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length
torch.view_as_real op_3 1 1 b out
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
torch_istft_2::write(op, captured_params, captured_attrs);

op->params["1"] = 0; // returns
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_3, 20)

} // namespace ncnn

Expand Down
Loading

0 comments on commit 5de9115

Please sign in to comment.