-
Notifications
You must be signed in to change notification settings - Fork 22
/
train.py
51 lines (43 loc) · 1.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
import bentoml
from transformers import (
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan,
WhisperForConditionalGeneration,
WhisperProcessor,
)
# Despite the file name, nothing is trained here: the script downloads
# pretrained SpeechT5 (text-to-speech) and Whisper (speech-to-text) artifacts
# from the Hugging Face Hub and registers each one in the local BentoML
# model store.

# Configure root logging at WARNING so INFO/DEBUG records are not emitted by
# default. (logging.WARNING is the canonical name; logging.WARN is a legacy
# alias with the same value.)
logging.basicConfig(level=logging.WARNING)


def _save_model(name, model, **kwargs):
    """Save *model* to the BentoML model store under *name* and print the result.

    Wraps ``bentoml.transformers.save_model`` so every artifact is saved and
    reported the same way.

    Args:
        name: Name to register the artifact under in the BentoML model store.
        model: The transformers object (model, processor, or vocoder) to save.
        **kwargs: Forwarded unchanged to ``bentoml.transformers.save_model``
            (for example ``signatures=``).

    Returns:
        The value returned by ``bentoml.transformers.save_model``.
    """
    saved = bentoml.transformers.save_model(name, model, **kwargs)
    print(f"Saved: {saved}")
    return saved


if __name__ == "__main__":
    # --- Load all pretrained artifacts first -------------------------------
    # Text-to-speech stack: processor, acoustic model, and HiFi-GAN vocoder.
    t5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    t5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    t5_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # Speech-to-text stack.
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    whisper_model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny"
    )
    # Drop the preset forced decoder prompt tokens so generation is not pinned
    # to them. Presumably this lets Whisper pick language/task itself;
    # NOTE(review): confirm against the serving code.
    whisper_model.config.forced_decoder_ids = None

    # --- Register each artifact in the BentoML model store -----------------
    saved_t5_processor = _save_model("speecht5_tts_processor", t5_processor)
    saved_t5_model = _save_model(
        "speecht5_tts_model",
        t5_model,
        # Expose generate_speech as a runner signature, without batching.
        signatures={"generate_speech": {"batchable": False}},
    )
    saved_t5_vocoder = _save_model("speecht5_tts_vocoder", t5_vocoder)
    saved_whisper_processor = _save_model("whisper_processor", whisper_processor)
    saved_whisper_model = _save_model("whisper_model", whisper_model)