# -*- coding: utf-8 -*-
"""
Created on Tue Jan 12 14:53:48 2021

@author: xiaohuaile
"""
import tensorflow as tf
import tensorflow.keras as keras


class DprnnBlock(keras.layers.Layer):
    '''
    Dual-path RNN block.

    Applies a bidirectional intra-chunk LSTM along the frequency axis and a
    unidirectional inter-chunk LSTM along the time axis, each followed by a
    dense projection, layer normalization, and a residual connection.
    '''
    def __init__(self, numUnits, batch_size, L, width, channel, causal=True, **kwargs):
        '''
        numUnits   hidden layer size of the LSTMs; must equal `channel` so the
                   residual additions are shape-compatible
        batch_size batch size, needed for the static reshapes in call()
        L          number of frames, -1 for undefined length
        width      width (frequency size) of the encoder output
        channel    channel size of the encoder output
        causal     True: instant layer norm (per frame, over frequency and
                   channel); False: global layer norm over the whole signal
        '''
        super(DprnnBlock, self).__init__(**kwargs)
        self.numUnits = numUnits
        self.batch_size = batch_size
        self.causal = causal
        self.L = L
        self.width = width
        self.channel = channel

        # intra-chunk (frequency axis) bidirectional LSTM + projection
        self.intra_rnn = keras.layers.Bidirectional(
            keras.layers.LSTM(units=self.numUnits // 2, return_sequences=True,
                              implementation=2, recurrent_activation='hard_sigmoid'))
        self.intra_fc = keras.layers.Dense(units=self.numUnits)
        if self.causal:
            # instant layer norm over the frequency and channel axes
            self.intra_ln = keras.layers.LayerNormalization(center=True, scale=True, axis=[-1, -2])
        else:
            # global layer norm over the flattened signal
            self.intra_ln = keras.layers.LayerNormalization(center=False, scale=False)

        # inter-chunk (time axis) unidirectional LSTM + projection
        self.inter_rnn = keras.layers.LSTM(units=self.numUnits, return_sequences=True,
                                           implementation=2, recurrent_activation='hard_sigmoid')
        self.inter_fc = keras.layers.Dense(units=self.numUnits)
        if self.causal:
            self.inter_ln = keras.layers.LayerNormalization(center=True, scale=True, axis=[-1, -2])
        else:
            self.inter_ln = keras.layers.LayerNormalization(center=False, scale=False)
    def call(self, x):
        batch_size = self.batch_size
        L = self.L
        width = self.width
        channel = self.channel
        causal = self.causal
        intra_rnn = self.intra_rnn
        intra_fc = self.intra_fc
        intra_ln = self.intra_ln
        inter_rnn = self.inter_rnn
        inter_fc = self.inter_fc
        inter_ln = self.inter_ln
        # --- intra-chunk (frequency axis) processing ---
        # (bs,T,F,C) --> (bs*T,F,C)
        intra_LSTM_input = tf.reshape(x, [-1, width, channel])
        # (bs*T,F,C), bidirectional LSTM over the frequency axis
        intra_LSTM_out = intra_rnn(intra_LSTM_input)
        # (bs*T,F,C), dense projection over the channel axis
        intra_dense_out = intra_fc(intra_LSTM_out)
        if causal:
            # (bs*T,F,C) --> (bs,T,F,C), norm over frequency and channel
            intra_ln_input = tf.reshape(intra_dense_out, [batch_size, -1, width, channel])
            intra_out = intra_ln(intra_ln_input)
        else:
            # (bs*T,F,C) --> (bs,T*F*C), global norm
            intra_ln_input = tf.reshape(intra_dense_out, [batch_size, -1])
            intra_ln_out = intra_ln(intra_ln_input)
            intra_out = tf.reshape(intra_ln_out, [batch_size, L, width, channel])
        # residual connection, (bs,T,F,C)
        intra_out = keras.layers.Add()([x, intra_out])
        # --- inter-chunk (time axis) processing ---
        # (bs,T,F,C) --> (bs,F,T,C)
        inter_LSTM_input = tf.transpose(intra_out, [0, 2, 1, 3])
        # (bs,F,T,C) --> (bs*F,T,C)
        inter_LSTM_input = tf.reshape(inter_LSTM_input, [batch_size * width, L, channel])
        # (bs*F,T,C), unidirectional LSTM over the time axis
        inter_LSTM_out = inter_rnn(inter_LSTM_input)
        # (bs*F,T,C), dense projection over the channel axis
        inter_dense_out = inter_fc(inter_LSTM_out)
        # (bs*F,T,C) --> (bs,F,T,C)
        inter_dense_out = tf.reshape(inter_dense_out, [batch_size, width, L, channel])
        if causal:
            # (bs,F,T,C) --> (bs,T,F,C), norm over frequency and channel
            inter_ln_input = tf.transpose(inter_dense_out, [0, 2, 1, 3])
            inter_out = inter_ln(inter_ln_input)
        else:
            # (bs,F,T,C) --> (bs,F*T*C), global norm
            inter_ln_input = tf.reshape(inter_dense_out, [batch_size, -1])
            inter_ln_out = inter_ln(inter_ln_input)
            inter_out = tf.reshape(inter_ln_out, [batch_size, width, L, channel])
            # (bs,F,T,C) --> (bs,T,F,C)
            inter_out = tf.transpose(inter_out, [0, 2, 1, 3])
        # residual connection, (bs,T,F,C)
        inter_out = keras.layers.Add()([intra_out, inter_out])
        return inter_out
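

# A minimal usage sketch, not part of the original file: it wires the block
# into a functional model and checks that the output shape matches the input.
# The sizes below (bs=1, T=10, F=50, C=32) are illustrative assumptions; in
# practice they come from the encoder that feeds this block. Note that
# numUnits is set equal to the channel size so the residual additions inside
# the block are shape-compatible.
if __name__ == '__main__':
    bs, T, F, C = 1, 10, 50, 32  # hypothetical (batch, frames, freq, channels)
    inp = keras.Input(shape=(T, F, C), batch_size=bs)
    out = DprnnBlock(numUnits=C, batch_size=bs, L=T, width=F,
                     channel=C, causal=True)(inp)
    model = keras.Model(inp, out)
    model.summary()  # the block preserves the input shape (bs, T, F, C)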