-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
114 lines (88 loc) · 4.44 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Where everything comes together
"""
import torch
from discriminators.isvoice_dtor import get_isvoice_discriminator
from discriminators.content_dtor import get_content_discriminator
from discriminators.identity_dtor import get_identity_discriminator
from transformer.mem_transformer import MemTransformer
from embedding import embeddings
import warnings
def get_transformer(config):
    """
    Build the style-transfer network.

    The returned module maps a style vector (N x S) together with a mel
    audio sample (N x 1 x T x M) to a new (N x 1 x T x M) mel tensor,
    where sample i has been re-rendered in the style stylevector[i].
    S is the style-vector dimensionality and M the number of
    mel-spectrogram channels.  The model is a torch.nn.Module, so it can
    be checkpointed like any other module.
    """
    transformer = MemTransformer(config)
    return transformer
def get_embedder_and_size(mel_size, path=None, cuda=False,
                          embedding_size=512):
    """
    Build (or load from a checkpoint) the speaker-style embedding network.

    The embedder maps a batch of speaker utterances
    (Batch Size x 1 x Frames x Features) to style vectors
    (Batch Size x embedding_size).

    :param mel_size: number of mel-spectrogram channels the embedder
        must accept
    :param path: optional checkpoint path; when given, the embedder is
        restored from it and its saved hyperparameters are validated
        against the requested sizes
    :param cuda: move the embedder to GPU when True
    :param embedding_size: dimensionality of the produced style vector
    :returns: tuple of (embedder, embedding_size)
    :raises RuntimeError: when the checkpoint's saved mel or style size
        disagrees with the requested mel_size / embedding_size
    """
    # Fix: compare against None with `is not`, never `!=` (PEP 8).
    if path is not None:
        loaded_style_size, num_classes, num_features, num_frames = \
            embeddings.parse_params(path)
        # Refuse to silently use a checkpoint trained with different sizes.
        if mel_size != num_features:
            raise RuntimeError(f"Loaded embedder yields mel size of {num_features}"
                               f" but mel size of {mel_size} requested")
        if embedding_size != loaded_style_size:
            raise RuntimeError(f"Loaded embedder yields style size of {loaded_style_size}"
                               f" but style size of {embedding_size} requested")
        embedder = embeddings.load_embedder(
            checkpoint_path=path,
            embedding_size=embedding_size,
            num_classes=num_classes,
            num_features=num_features,
            frame_dim=num_frames,
            cuda=cuda
        )
    else:
        print("No Embedder , initializing random weights")
        embedder = embeddings.load_embedder(embedding_size=embedding_size,
                                            num_features=mel_size,
                                            cuda=cuda)
    return (embedder, embedding_size)
class ProjectModel(torch.nn.Module):
    """
    End-to-end voice-conversion model: the transformer generator plus
    the discriminators used to score its output.
    """

    def __init__(self, config, embedder_path, mel_size, style_size, identity_mode, cuda):
        """
        :param config: transformer config dict; "d_model" and "d_style"
            are filled in here from mel_size and the embedder's actual
            style size before the transformer is built
        :param embedder_path: checkpoint path for the style embedder, or
            None for randomly initialized weights
        :param mel_size: number of frequency channels in the
            mel-cepstrogram (M)
        :param style_size: requested dimensionality of the style vector
            produced by the embedder (S)
        :param identity_mode: mode flag forwarded to the identity
            discriminator factory
        :param cuda: place the embedder on GPU when True
        """
        super().__init__()
        self.mel_size = mel_size
        self.embedder, self.style_size = get_embedder_and_size(
            mel_size=mel_size,
            path=embedder_path,
            embedding_size=style_size,
            cuda=cuda)
        # The transformer must agree with the mel channel count and with
        # the style size the embedder actually produces.
        config["d_model"] = self.mel_size
        config["d_style"] = self.style_size
        self.isvoice_dtor = get_isvoice_discriminator(self.mel_size)
        # NOTE(review): the content discriminator is currently disabled.
        # self.content_dtor = get_content_discriminator(self.mel_size)
        self.identity_dtor = get_identity_discriminator(
            self.style_size, identity_mode=identity_mode)
        self.transformer = get_transformer(config)

    def forward(self, source_mel, target_style):
        """
        Run one conversion pass.

        :param source_mel: an (N x 1 x T x M) mel tensor to convert
        :param target_style: an (N x S) tensor of target style vectors
        :returns: (transformed_mel, isvoice_prob, identity_prob) — the
            converted mel plus the discriminators' scores for it
        """
        converted = self.transformer(source_mel, target_style)
        # Re-embed the output so the identity discriminator can compare
        # the achieved style against the requested one.
        achieved_style = self.embedder(converted).reshape(target_style.shape)
        voice_prob = self.isvoice_dtor(converted, None)
        # NOTE(review): the content discriminator is currently disabled.
        # content_prob = self.content_dtor(torch.stack((source_mel,
        #                                               converted),
        #                                              dim=1), None)
        person_prob = self.identity_dtor(
            torch.cat((target_style, achieved_style), dim=1), None)
        return converted, voice_prob, person_prob