forked from erew123/alltalk_tts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tts_server.py
1096 lines (985 loc) · 46.4 KB
/
tts_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import json
import time
import os
from pathlib import Path
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import io
import wave
##########################
#### Webserver Imports####
##########################
from fastapi import (
FastAPI,
Form,
Request,
Response,
Depends,
HTTPException,
)
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse, FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
###########################
#### STARTUP VARIABLES ####
###########################
# STARTUP VARIABLE - "this_dir" is the resolved directory that contains this script
this_dir = Path(__file__).parent.resolve()
# STARTUP VARIABLE - "device" is "cuda" when CUDA is available, otherwise "cpu"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# STARTUP VARIABLE - Import languages file for Gradio to be able to display them in the interface
with open(this_dir / "languages.json", encoding="utf8") as f:
    languages = json.load(f)
# Base setting for a possible FineTuned model existing and the loader being available
tts_method_xtts_ft = False
#################################################################
#### LOAD PARAMS FROM confignew.json - REQUIRED FOR BRANDING ####
#################################################################
# Load config file and get settings
def load_config(file_path):
    """Read a JSON configuration file and return its parsed contents.

    Args:
        file_path: Location of the JSON file (str or path-like).

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding (the languages.json read in this file pins UTF-8 as well).
    with open(file_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
# Define the path to the confignew.json file
configfile_path = this_dir / "confignew.json"
# Load confignew.json into the module-wide "params" settings dictionary
params = load_config(configfile_path)
# Check someone hasn't enabled low_vram on a system that is not CUDA enabled.
# The override must be the boolean False: the string "false" is truthy, so
# every `if params["low_vram"]` test treated LowVRAM as switched on.
if not torch.cuda.is_available():
    params["low_vram"] = False
# Define the path to the JSON file
config_file_path = this_dir / "modeldownload.json"
#############################################
#### LOAD PARAMS FROM MODELDOWNLOAD.JSON ####
############################################
# This is used only in the instance that someone has changed their model path
# Define the path to the JSON file
modeldownload_config_file_path = this_dir / "modeldownload.json"
# Fallback values so the names below always exist. Without them a missing
# modeldownload.json left both names undefined and the first use (model
# loaders, /settings page) failed with a NameError.
# NOTE(review): "models" selects the built-in models folder branch in the
# loaders; the empty model_path mirrors the .get() fallback below - confirm
# the intended default model folder.
modeldownload_base_path = Path("models")
modeldownload_model_path = Path("")
# Check if the JSON file exists
if modeldownload_config_file_path.exists():
    with open(modeldownload_config_file_path, "r", encoding="utf-8") as modeldownload_config_file:
        modeldownload_settings = json.load(modeldownload_config_file)
    # Extract settings from the loaded JSON
    modeldownload_base_path = Path(modeldownload_settings.get("base_path", ""))
    modeldownload_model_path = Path(modeldownload_settings.get("model_path", ""))
else:
    # The file is missing: warn and carry on with the fallback values above
    print(
        f"[{params['branding']}Startup] \033[91mWarning\033[0m modeldownload.json is missing so please re-download it and save it in the alltalk_tts main folder."
    )
########################
#### STARTUP CHECKS ####
########################
try:
    from TTS.api import TTS
    from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError:
    # Print the installation hints as one space-separated message, then re-raise
    # the original error so start-up stops.
    print(
        " ".join(
            (
                f"[{params['branding']}Startup] \033[91mWarning\033[0m Could not find the TTS module. Make sure to install the requirements for the alltalk_tts extension.",
                f"[{params['branding']}Startup] \033[91mWarning\033[0m Linux / Mac:\npip install -r extensions/alltalk_tts/requirements.txt\n",
                f"[{params['branding']}Startup] \033[91mWarning\033[0m Windows:\npip install -r extensions\\alltalk_tts\\requirements.txt\n",
                f"[{params['branding']}Startup] \033[91mWarning\033[0m If you used the one-click installer, paste the command above in the terminal window launched after running the cmd_ script. On Windows, that's cmd_windows.bat.",
            )
        )
    )
    raise
# DEEPSPEED Import - Record in "deepspeed_installed" whether DeepSpeed can be imported
try:
    import deepspeed
except ImportError:
    # DeepSpeed is optional; note its absence and point at the project page
    deepspeed_installed = False
    print(
        f"[{params['branding']}Startup] DeepSpeed \033[93mNot Detected\033[0m. See https://github.com/microsoft/DeepSpeed"
    )
else:
    deepspeed_installed = True
    print(f"[{params['branding']}Startup] DeepSpeed \033[93mDetected\033[0m")
    print(
        f"[{params['branding']}Startup] Activate DeepSpeed in {params['branding']} settings"
    )
@asynccontextmanager
async def startup_shutdown(no_actual_value_it_demanded_something_be_here):
    """Lifespan context handed to FastAPI: run setup() before yielding.

    The single argument is required by the caller but is not used here.
    """
    # Startup phase: load the configured model
    await setup()
    yield
    # Shutdown phase: nothing is released at present
# Create the FastAPI application; startup_shutdown runs the model setup on start
app = FastAPI(lifespan=startup_shutdown)
# CORS: every origin, method and header is accepted, credentials included
app.add_middleware(
    CORSMiddleware,
    **{
        "allow_origins": ["*"],  # Set this to the specific origins you want to allow
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    },
)
#####################################
#### MODEL LOADING AND UNLOADING ####
#####################################
# MODEL LOADERS Picker For API TTS, API Local, XTTSv2 Local, XTTSv2 FT
async def setup():
    """Load the TTS model selected in ``params`` and prepare the output folder.

    The loader is picked from the first truthy flag among
    ``params["tts_method_api_tts"]``, ``params["tts_method_api_local"]``,
    ``params["tts_method_xtts_local"]`` and the module-level
    ``tts_method_xtts_ft``. Each loader stores the model in the module-level
    ``model`` itself, so its return value is not kept here.

    Side effects:
        * ``params["tts_model_loaded"]`` becomes True only when a loader ran.
          When no method is selected a warning is printed and the flag is set
          to False, rather than reporting a load that never happened.
        * The ``params["output_folder_wav_standalone"]`` directory is created
          under ``this_dir`` if it does not exist.
    """
    # Record the start time so the load duration can be reported
    load_started = time.time()
    if params["tts_method_api_tts"]:
        label, source, loader = "API TTS", params["tts_model_name"], api_load_model
    elif params["tts_method_api_local"]:
        label, source, loader = "API Local", modeldownload_model_path, api_manual_load_model
    elif params["tts_method_xtts_local"]:
        label, source, loader = "XTTSv2 Local", modeldownload_model_path, xtts_manual_load_model
    elif tts_method_xtts_ft:
        # The fine-tuned loader reads from models/trainedmodel, so log that folder
        label, source, loader = "XTTSv2 FT", "/models/trainedmodel/model.pth", xtts_ft_manual_load_model
    else:
        loader = None
    if loader is None:
        print(
            f"[{params['branding']}Model] \033[91mWarning\033[0m No TTS method is enabled in the settings, so no model was loaded."
        )
        params["tts_model_loaded"] = False
    else:
        print(
            f"[{params['branding']}Model] \033[94m{label} Loading\033[0m {source} \033[94minto\033[93m",
            device,
            "\033[0m",
        )
        await loader()
        load_elapsed = time.time() - load_started
        print(
            f"[{params['branding']}Model] \033[94mModel Loaded in \033[93m{load_elapsed:.2f} seconds.\033[0m"
        )
        params["tts_model_loaded"] = True
    # Set the output path for wav files
    output_directory = this_dir / params["output_folder_wav_standalone"]
    output_directory.mkdir(parents=True, exist_ok=True)
# MODEL LOADER For "API TTS"
async def api_load_model():
    """Build a TTS object from params["tts_model_name"], move it to ``device``,
    store it in the module-level ``model`` and return it."""
    global model
    tts_instance = TTS(params["tts_model_name"])
    model = tts_instance.to(device)
    return model
# MODEL LOADER For "API Local"
async def api_manual_load_model():
    """Load a TTS model from a local folder and store it in the global ``model``.

    The folder is <this_dir>/models/<model_path> when the base path from
    modeldownload.json is exactly "models"; otherwise it is
    <base_path>/<model_path>, and an informational line is printed first.

    Returns:
        The loaded model, already moved to ``device``.
    """
    global model
    if str(modeldownload_base_path) == "models":
        model_dir = this_dir / "models" / modeldownload_model_path
    else:
        # A custom location was configured in modeldownload.json
        model_dir = modeldownload_base_path / modeldownload_model_path
        print(
            f"[{params['branding']}Model] \033[94mInfo\033[0m Loading your custom model set in \033[93mmodeldownload.json\033[0m:",
            model_dir,
        )
    loaded = TTS(model_path=model_dir, config_path=model_dir / "config.json")
    model = loaded.to(device)
    return model
# MODEL LOADER For "XTTSv2 Local"
async def xtts_manual_load_model():
    """Load an XTTS checkpoint from a local folder into the global ``model``.

    The checkpoint folder is <this_dir>/models/<model_path> when the base path
    from modeldownload.json is exactly "models"; otherwise it is
    <base_path>/<model_path>, and an informational line is printed first.
    config.json and vocab.json are read from that same folder, and DeepSpeed
    use follows params["deepspeed_activate"].

    Returns:
        The loaded model after it has been moved to ``device``.
    """
    global model
    config = XttsConfig()
    if str(modeldownload_base_path) != "models":
        # A custom location was configured in modeldownload.json
        print(
            f"[{params['branding']}Model] \033[94mInfo\033[0m Loading your custom model set in \033[93mmodeldownload.json\033[0m:",
            modeldownload_base_path / modeldownload_model_path,
        )
        checkpoint_dir = modeldownload_base_path / modeldownload_model_path
    else:
        checkpoint_dir = this_dir / "models" / modeldownload_model_path
    config_path = checkpoint_dir / "config.json"
    vocab_path_dir = checkpoint_dir / "vocab.json"
    config.load_json(str(config_path))
    model = Xtts.init_from_config(config)
    checkpoint_options = {
        "checkpoint_dir": str(checkpoint_dir),
        "vocab_path": str(vocab_path_dir),
        "use_deepspeed": params["deepspeed_activate"],
    }
    model.load_checkpoint(config, **checkpoint_options)
    model.to(device)
    return model
# MODEL LOADER For "XTTSv2 FT"
async def xtts_ft_manual_load_model():
    """Load the fine-tuned XTTS checkpoint into the global ``model``.

    Everything is read from <this_dir>/models/trainedmodel (config.json,
    vocab.json and the checkpoint itself); DeepSpeed use follows
    params["deepspeed_activate"].

    Returns:
        The loaded model after it has been moved to ``device``.
    """
    global model
    config = XttsConfig()
    checkpoint_dir = this_dir / "models" / "trainedmodel"
    config_path = checkpoint_dir / "config.json"
    vocab_path_dir = checkpoint_dir / "vocab.json"
    config.load_json(str(config_path))
    model = Xtts.init_from_config(config)
    checkpoint_options = {
        "checkpoint_dir": str(checkpoint_dir),
        "vocab_path": str(vocab_path_dir),
        "use_deepspeed": params["deepspeed_activate"],
    }
    model.load_checkpoint(config, **checkpoint_options)
    model.to(device)
    return model
# MODEL UNLOADER
async def unload_model(model):
    """Drop the given model reference and mark the model as not loaded.

    Only the local reference is deleted here; callers rebind their own variable
    to the returned None (``model = await unload_model(model)``).

    Returns:
        None, always.
    """
    print(f"[{params['branding']}Model] \033[94mUnloading model \033[0m")
    del model
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Hand cached CUDA memory back to the driver
        torch.cuda.empty_cache()
    params["tts_model_loaded"] = False
# MODEL - Swap model based on Gradio selection API TTS, API Local, XTTSv2 Local, XTTSv2 FT
async def handle_tts_method_change(tts_method):
    """Switch the active TTS method, then unload and reload the model.

    Args:
        tts_method: One of "API TTS", "API Local", "XTTSv2 Local" or
            "XTTSv2 FT". Any other value leaves the method flags untouched,
            but the model is still unloaded and reloaded.
    """
    global model
    global tts_method_xtts_ft
    print(
        f"[{params['branding']}Model] \033[94mChanging model \033[92m(Please wait 15 seconds)\033[0m"
    )
    # method -> (api_tts, api_local, xtts_local, fine_tuned, force DeepSpeed off)
    flag_table = {
        "API TTS": (True, False, False, False, True),
        "API Local": (False, True, False, False, True),
        "XTTSv2 Local": (False, False, True, False, False),
        "XTTSv2 FT": (False, False, False, True, False),
    }
    selected = flag_table.get(tts_method)
    if selected is not None:
        use_api_tts, use_api_local, use_xtts_local, use_fine_tuned, deepspeed_off = selected
        params["tts_method_api_tts"] = use_api_tts
        params["tts_method_api_local"] = use_api_local
        params["tts_method_xtts_local"] = use_xtts_local
        if deepspeed_off:
            # Only the two API methods switch DeepSpeed off
            params["deepspeed_activate"] = False
        tts_method_xtts_ft = use_fine_tuned
    # Unload the current model, then load the one matching the updated flags
    model = await unload_model(model)
    await setup()
# MODEL WEBSERVER- API Swap Between Models
@app.route("/api/reload", methods=["POST"])
async def reload(request: Request):
    """Switch the loaded model to the method named in the query string.

    Query parameters:
        tts_method: "API TTS", "API Local", "XTTSv2 Local" or "XTTSv2 FT".

    Returns:
        A JSON response: {"status": "model-success"} after the swap, or an
        error body with HTTP status 400 when ``tts_method`` is missing or not
        one of the accepted values.
    """
    tts_method = request.query_params.get("tts_method")
    if tts_method not in ("API TTS", "API Local", "XTTSv2 Local", "XTTSv2 FT"):
        # Handlers registered through app.route must return a Response object;
        # returning a bare dict here turned every invalid request into a
        # server error instead of this message.
        return Response(
            content=json.dumps(
                {"status": "error", "message": "Invalid TTS method specified"}
            ),
            media_type="application/json",
            status_code=400,
        )
    await handle_tts_method_change(tts_method)
    return Response(
        content=json.dumps({"status": "model-success"}), media_type="application/json"
    )
##################
#### LOW VRAM ####
##################
# LOW VRAM - MODEL MOVER VRAM(cuda)<>System RAM(cpu) for Low VRAM setting
async def switch_device():
    """Move the global model between "cuda" and "cpu" and update ``device``.

    Does nothing when CUDA is not available. When the model leaves the GPU the
    CUDA cache is emptied as well.
    """
    global model, device
    # GPU-related operations only make sense when CUDA is present
    if not torch.cuda.is_available():
        return
    if device == "cuda":
        device = "cpu"
        model.to(device)
        torch.cuda.empty_cache()
    else:
        # device is "cpu" here (the stray no-op `device == "cpu"` expression
        # that used to sit on this line has been removed)
        device = "cuda"
        model.to(device)
@app.post("/api/lowvramsetting")
async def set_low_vram(request: Request, new_low_vram_value: bool):
    """Enable or disable LowVRAM mode and reload the model on the right device.

    Args:
        request: The incoming request (unused).
        new_low_vram_value: Desired LowVRAM state.

    Returns:
        A Response whose body is JSON: a "success" status when the setting
        already has the requested value, "lowvram-success" after handling the
        change, or an "error" status with the exception text.

    CUDA availability is checked before the model is touched. Without CUDA the
    setting is forced to False and the loaded model is left alone; unloading
    first left the server flagged as having no model. The global ``model`` is
    also rebound to the unloader's result so the old model can be released
    before the reload.
    """
    global model, device
    try:
        if new_low_vram_value is None:
            raise ValueError("Missing 'low_vram' parameter")
        if params["low_vram"] == new_low_vram_value:
            return Response(
                content=json.dumps(
                    {
                        "status": "success",
                        "message": f"[{params['branding']}Model] LowVRAM is already {'enabled' if new_low_vram_value else 'disabled'}.",
                    }
                )
            )
        if not torch.cuda.is_available():
            # Handle the case where CUDA is not available
            print(
                f"[{params['branding']}Model] \033[91mError:\033[0m Nvidia CUDA is not available on this system. Unable to use LowVRAM mode."
            )
            params["low_vram"] = False
            # Status kept as before for existing clients
            return Response(content=json.dumps({"status": "lowvram-success"}))
        params["low_vram"] = new_low_vram_value
        model = await unload_model(model)
        if new_low_vram_value:
            # LowVRAM on: the model rests on the CPU between generations
            device = "cpu"
            state_message = f"[{params['branding']}Model] \033[94mLowVRAM Enabled.\033[0m Model will move between \033[93mVRAM(cuda) <> System RAM(cpu)\033[0m"
        else:
            device = "cuda"
            state_message = f"[{params['branding']}Model] \033[94mLowVRAM Disabled.\033[0m Model will stay in \033[93mVRAM(cuda)\033[0m"
        print(
            f"[{params['branding']}Model] \033[94mChanging model \033[92m(Please wait 15 seconds)\033[0m"
        )
        print(state_message)
        await setup()
        return Response(content=json.dumps({"status": "lowvram-success"}))
    except Exception as e:
        return Response(content=json.dumps({"status": "error", "message": str(e)}))
###################
#### DeepSpeed ####
###################
# DEEPSPEED - Reload the model when DeepSpeed checkbox is enabled/disabled
async def handle_deepspeed_change(value):
    """Reload the model with DeepSpeed switched on or off.

    Args:
        value: Truthy to activate DeepSpeed (this also selects the XTTSv2 Local
            method), falsy to deactivate it.

    Returns:
        ``value`` unchanged, i.e. the new checkbox value.
    """
    global model
    if not value:
        # DeepSpeed disabled
        print(f"[{params['branding']}Model] \033[93mDeepSpeed De-Activating\033[0m")
        print(
            f"[{params['branding']}Model] \033[94mChanging model \033[92m(Please wait 15 seconds)\033[0m"
        )
        params["deepspeed_activate"] = False
        model = await unload_model(model)
        await setup()
        return value
    # DeepSpeed enabled
    activation_notes = (
        f"[{params['branding']}Model] \033[93mDeepSpeed Activating\033[0m",
        f"[{params['branding']}Model] \033[94mChanging model \033[92m(DeepSpeed can take 30 seconds to activate)\033[0m",
        f"[{params['branding']}Model] \033[91mInformation\033[0m If you have not set CUDA_HOME path, DeepSpeed may fail to load/activate",
        f"[{params['branding']}Model] \033[91mInformation\033[0m DeepSpeed needs to find nvcc from the CUDA Toolkit. Please check your CUDA_HOME path is",
        f"[{params['branding']}Model] \033[91mInformation\033[0m pointing to the correct location and use 'set CUDA_HOME=putyoutpathhere' (Windows) or",
        f"[{params['branding']}Model] \033[91mInformation\033[0m 'export CUDA_HOME=putyoutpathhere' (Linux) within your Python Environment",
    )
    for note in activation_notes:
        print(note)
    model = await unload_model(model)
    params["tts_method_api_tts"] = False
    params["tts_method_api_local"] = False
    params["tts_method_xtts_local"] = True
    params["deepspeed_activate"] = True
    await setup()
    return value  # Return new checkbox value
# DEEPSPEED WEBSERVER- API Enable/Disable DeepSpeed
# NOTE(review): this function name rebinds the module-level name "deepspeed",
# which the startup check binds to the imported DeepSpeed module when it is
# installed - confirm nothing later in the file needs the module itself.
@app.post("/api/deepspeed")
async def deepspeed(request: Request, new_deepspeed_value: bool):
    """Enable or disable DeepSpeed and reload the model if the value changed.

    Returns:
        A Response whose body is JSON: a "success" status when nothing had to
        change, "deepspeed-success" after a reload, or an "error" status with
        the exception text.
    """
    try:
        if new_deepspeed_value is None:
            raise ValueError("Missing 'deepspeed' parameter")
        unchanged = params["deepspeed_activate"] == new_deepspeed_value
        if unchanged:
            payload = {
                "status": "success",
                "message": f"DeepSpeed is already {'enabled' if new_deepspeed_value else 'disabled'}.",
            }
        else:
            params["deepspeed_activate"] = new_deepspeed_value
            await handle_deepspeed_change(params["deepspeed_activate"])
            payload = {"status": "deepspeed-success"}
        return Response(content=json.dumps(payload))
    except Exception as e:
        return Response(content=json.dumps({"status": "error", "message": str(e)}))
########################
#### TTS GENERATION ####
########################
# TTS VOICE GENERATION METHODS (called from voice_preview and output_modifier)
async def generate_audio(text, voice, language, output_file, streaming=False):
    """Run a TTS generation, either streamed or written to ``output_file``.

    Returns:
        The async generator of audio byte chunks when ``streaming`` is truthy;
        otherwise None, after the generator has been run to completion.
    """
    audio_stream = generate_audio_internal(text, voice, language, output_file, streaming)
    if not streaming:
        # Nothing consumes the generator in file mode, so drain it here
        async for _chunk in audio_stream:
            continue
        return None
    return audio_stream
async def generate_audio_internal(text, voice, language, output_file, streaming):
    """Async generator that synthesises ``text`` with the currently loaded model.

    Args:
        text: Text to speak.
        voice: File name of the reference sample inside <this_dir>/voices.
        language: Language code passed through to the model.
        output_file: Destination wav path (used in non-streaming mode).
        streaming: When truthy, yield a WAV header followed by 16-bit PCM
            chunks instead of writing ``output_file``. Only the XTTSv2
            Local / FT methods support this.

    Yields:
        bytes: Audio data, in streaming mode only. In non-streaming mode
        nothing is yielded and the audio is written to ``output_file``.

    Raises:
        ValueError: If streaming is requested with an API method.

    In LowVRAM mode the model is moved to the GPU before generation and back to
    the CPU afterwards.
    """
    global model
    if params["low_vram"] and device == "cpu":
        await switch_device()
    generate_start_time = time.time()  # Record the start time of generating TTS
    # XTTSv2 LOCAL & Xttsv2 FT Method
    if params["tts_method_xtts_local"] or tts_method_xtts_ft:
        print(f"[{params['branding']}TTSGen] {text}")
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
            audio_path=[f"{this_dir}/voices/{voice}"],
            gpt_cond_len=model.config.gpt_cond_len,
            max_ref_length=model.config.max_ref_len,
            sound_norm_refs=model.config.sound_norm_refs,
        )
        # Common arguments for both inference functions
        common_args = {
            "text": text,
            "language": language,
            "gpt_cond_latent": gpt_cond_latent,
            "speaker_embedding": speaker_embedding,
            "temperature": float(params["local_temperature"]),
            "length_penalty": float(model.config.length_penalty),
            "repetition_penalty": float(params["local_repetition_penalty"]),
            "top_k": int(model.config.top_k),
            "top_p": float(model.config.top_p),
            "enable_text_splitting": True,
        }
        # Determine the correct inference function and add streaming specific argument if needed
        inference_func = model.inference_stream if streaming else model.inference
        if streaming:
            common_args["stream_chunk_size"] = 20
        # Call the appropriate function
        output = inference_func(**common_args)
        # Process the output based on streaming or non-streaming
        if streaming:
            # Send a WAV header first (mono, 16-bit, 24 kHz, no frames yet)
            wav_buf = io.BytesIO()
            with wave.open(wav_buf, "wb") as vfout:
                vfout.setnchannels(1)
                vfout.setsampwidth(2)
                vfout.setframerate(24000)
                vfout.writeframes(b"")
            wav_buf.seek(0)
            yield wav_buf.read()
            # Chunks are converted and sent one at a time. They are not kept:
            # the old unused "file_chunks" list held every tensor until the
            # stream finished.
            for chunk in output:
                if isinstance(chunk, list):
                    chunk = torch.cat(chunk, dim=0)
                chunk = chunk.clone().detach().cpu().numpy()
                chunk = chunk[None, : int(chunk.shape[0])]
                # Clamp to [-1, 1] and scale to signed 16-bit samples
                # ("np" is imported further down this module)
                chunk = np.clip(chunk, -1, 1)
                chunk = (chunk * 32767).astype(np.int16)
                yield chunk.tobytes()
        else:
            # Non-streaming-specific operation
            torchaudio.save(output_file, torch.tensor(output["wav"]).unsqueeze(0), 24000)
    # API LOCAL Methods
    elif params["tts_method_api_local"]:
        # Streaming only allowed for XTTSv2 local
        if streaming:
            raise ValueError("Streaming is only supported in XTTSv2 local")
        print(f"[{params['branding']}TTSGen] Using API Local")
        model.tts_to_file(
            text=text,
            file_path=output_file,
            speaker_wav=[f"{this_dir}/voices/{voice}"],
            language=language,
            temperature=float(params["local_temperature"]),
            length_penalty=model.config.length_penalty,
            repetition_penalty=float(params["local_repetition_penalty"]),
            top_k=model.config.top_k,
            top_p=model.config.top_p,
        )
    # API TTS
    elif params["tts_method_api_tts"]:
        # Streaming only allowed for XTTSv2 local
        if streaming:
            raise ValueError("Streaming is only supported in XTTSv2 local")
        print(f"[{params['branding']}TTSGen] Using API TTS")
        model.tts_to_file(
            text=text,
            file_path=output_file,
            speaker_wav=[f"{this_dir}/voices/{voice}"],
            language=language,
        )
    # Print Generation time and settings
    generate_end_time = time.time()  # Record the end time to generate TTS
    generate_elapsed_time = generate_end_time - generate_start_time
    print(
        f"[{params['branding']}TTSGen] \033[93m{generate_elapsed_time:.2f} seconds. \033[94mLowVRAM: \033[33m{params['low_vram']} \033[94mDeepSpeed: \033[33m{params['deepspeed_activate']}\033[0m"
    )
    # Move model back to cpu system ram if needed.
    if params["low_vram"] and device == "cuda":
        await switch_device()
# TTS VOICE GENERATION METHODS - generate TTS API
@app.route("/api/generate", methods=["POST"])
async def generate(request: Request):
    """Generate a wav file from the JSON body of a POST request.

    The body must hold "text", "voice", "language" and "output_file".

    Returns:
        A JSONResponse with status "generate-success" and the output path, or
        status "error" with the exception text if anything fails.
    """
    try:
        # Pull the four required fields out of the JSON body, in this order
        body = await request.json()
        text, voice, language, output_file = (
            body[key] for key in ("text", "voice", "language", "output_file")
        )
        streaming = False
        # Generation logic
        response = await generate_audio(text, voice, language, output_file, streaming)
        if not streaming:
            return JSONResponse(
                content={"status": "generate-success", "data": {"audio_path": output_file}}
            )
        return StreamingResponse(response, media_type="audio/wav")
    except Exception as e:
        return JSONResponse(content={"status": "error", "message": str(e)})
###################################################
#### POPULATE FILES LIST FROM VOICES DIRECTORY ####
###################################################
# List the .wav files found directly inside a directory (used for "voices")
def list_files(directory):
    """Return the names of regular files in ``directory`` ending in ".wav".

    Sub-directories are ignored and the suffix match is case-sensitive. Names
    come back in the order os.listdir reports them.
    """
    wav_names = []
    for entry in os.listdir(directory):
        if not entry.endswith(".wav"):
            continue
        if os.path.isfile(os.path.join(directory, entry)):
            wav_names.append(entry)
    return wav_names
#############################
#### JSON CONFIG UPDATER ####
#############################
# Jinja2 template renderer for the HTML pages; templates are read from <this_dir>/templates
templates = Jinja2Templates(directory=this_dir / "templates")
# Dependency that returns the current contents of confignew.json
def get_json_data():
    """Read <this_dir>/confignew.json from disk and return the parsed data.

    The file is re-read on every call, so callers always see the saved values.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding (the languages.json read in this file pins UTF-8 as well).
    with open(this_dir / "confignew.json", "r", encoding="utf-8") as json_file:
        return json.load(json_file)
# Settings page endpoint
@app.get("/settings")
async def get_settings(request: Request):
    """Render generate_form.html with the saved settings and the voice list."""
    voice_samples = list_files(this_dir / "voices")
    # Everything the template needs: request, saved JSON settings, model path, wav names
    page_context = {
        "request": request,
        "data": get_json_data(),
        "modeldownload_model_path": modeldownload_model_path,
        "wav_files": voice_samples,
    }
    return templates.TemplateResponse("generate_form.html", page_context)
# Serve the files in <this_dir>/templates under the /static URL prefix
app.mount("/static", StaticFiles(directory=str(this_dir / "templates")), name="static")
@app.post("/update-settings")
async def update_settings(
    request: Request,
    activate: bool = Form(...),
    autoplay: bool = Form(...),
    deepspeed_activate: bool = Form(...),
    delete_output_wavs: str = Form(...),
    ip_address: str = Form(...),
    language: str = Form(...),
    local_temperature: str = Form(...),
    local_repetition_penalty: str = Form(...),
    low_vram: bool = Form(...),
    tts_model_loaded: bool = Form(...),
    tts_model_name: str = Form(...),
    narrator_enabled: bool = Form(...),
    narrator_voice: str = Form(...),
    output_folder_wav: str = Form(...),
    port_number: str = Form(...),
    remove_trailing_dots: bool = Form(...),
    show_text: bool = Form(...),
    tts_method: str = Form(...),
    voice: str = Form(...),
    data: dict = Depends(get_json_data),
):
    """Save the submitted settings form into confignew.json.

    ``data`` is the current file content (supplied by get_json_data); the form
    values overwrite the matching keys and any other key is kept as it was.
    ``tts_method`` is stored as three booleans, one per known method.

    Returns:
        A 303 redirect back to /settings.
    """
    # Overlay the form values on the existing settings
    submitted = {
        "activate": activate,
        "autoplay": autoplay,
        "deepspeed_activate": deepspeed_activate,
        "delete_output_wavs": delete_output_wavs,
        "ip_address": ip_address,
        "language": language,
        "local_temperature": local_temperature,
        "local_repetition_penalty": local_repetition_penalty,
        "low_vram": low_vram,
        "tts_model_loaded": tts_model_loaded,
        "tts_model_name": tts_model_name,
        "narrator_enabled": narrator_enabled,
        "narrator_voice": narrator_voice,
        "output_folder_wav": output_folder_wav,
        "port_number": port_number,
        "remove_trailing_dots": remove_trailing_dots,
        "show_text": show_text,
        "tts_method_api_local": tts_method == "api_local",
        "tts_method_api_tts": tts_method == "api_tts",
        "tts_method_xtts_local": tts_method == "xtts_local",
        "voice": voice,
    }
    data.update(submitted)
    # Persist the merged settings
    with open(this_dir / "confignew.json", "w") as json_file:
        json.dump(data, json_file)
    # Show the settings page again with the saved values
    return RedirectResponse(url="/settings", status_code=303)
##################################
#### SETTINGS PAGE DEMO VOICE ####
##################################
@app.get("/tts-demo-request", response_class=StreamingResponse)
async def tts_demo_request_streaming(text: str, voice: str, language: str, output_file: str):
    """Stream generated speech for the settings-page demo.

    Returns:
        A StreamingResponse of audio/wav data, or a JSON error body with HTTP
        status 500 if starting the generation fails.
    """
    try:
        target_path = (this_dir / "outputs").joinpath(output_file)
        audio_chunks = await generate_audio(
            text, voice, language, target_path, streaming=True
        )
        return StreamingResponse(audio_chunks, media_type="audio/wav")
    except Exception as e:
        print(f"An error occurred: {e}")
        return JSONResponse(content={"error": "An error occurred"}, status_code=500)
@app.post("/tts-demo-request", response_class=JSONResponse)
async def tts_demo_request(request: Request, text: str = Form(...), voice: str = Form(...), language: str = Form(...), output_file: str = Form(...)):
    """Generate a demo wav into <this_dir>/outputs from the submitted form.

    Returns:
        JSON holding the submitted ``output_file`` name on success, or a JSON
        error body with HTTP status 500 on failure.
    """
    try:
        target_path = (this_dir / "outputs").joinpath(output_file)
        await generate_audio(text, voice, language, target_path, streaming=False)
        result = {"output_file_path": str(output_file)}
        return JSONResponse(content=result, status_code=200)
    except Exception as e:
        print(f"An error occurred: {e}")
        return JSONResponse(content={"error": "An error occurred"}, status_code=500)
# Gives web access to the output files
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a generated file from <this_dir>/outputs.

    Raises:
        HTTPException: 404 when the requested file does not exist.
    """
    audio_path = this_dir / "outputs" / filename
    # Answer a missing file with a clean 404, as the /audiocache route does,
    # rather than letting FileResponse fail while sending.
    if not audio_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(audio_path)
@app.get("/audiocache/{filename}")
async def get_audio(filename: str):
    """Serve a generated wav from <this_dir>/outputs with caching headers.

    Raises:
        HTTPException: 404 when the requested file does not exist.

    NOTE(review): this definition reuses the name ``get_audio`` of the /audio
    handler defined just before it; both routes are registered by their
    decorators, but the module-level name ends up bound to this function.
    """
    # Resolve against the script directory, like every other "outputs" access
    # in this file. A bare relative path depended on the working directory the
    # server was started from.
    audio_path = this_dir / "outputs" / filename
    if not audio_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    response = FileResponse(
        path=audio_path,
        media_type='audio/wav',
        filename=filename
    )
    # Set caching headers
    response.headers["Cache-Control"] = "public, max-age=604800"  # Cache for one week
    response.headers["ETag"] = str(audio_path.stat().st_mtime)  # Use the file's last modified time as a simple ETag
    return response
#########################
#### VOICES LIST API ####
#########################
# Voices list endpoint
@app.get("/api/voices")
async def get_voices():
    """Return the .wav file names found in <this_dir>/voices as {"voices": [...]}."""
    voices_dir = this_dir / "voices"
    return {"voices": list_files(voices_dir)}
###########################
#### PREVIEW VOICE API ####
###########################
@app.post("/api/previewvoice/", response_class=JSONResponse)
async def preview_voice(request: Request, voice: str = Form(...)):
    """Generate a short English sample spoken with the given voice file.

    The sample is always written to <this_dir>/outputs/api_preview_voice.wav.

    Returns:
        JSON with status "generate-success", the local file path and a URL
        built from params["ip_address"] / params["port_number"]; or a JSON
        error body with HTTP status 500 on failure.
    """
    try:
        # Hardcoded settings
        language = "en"
        output_file_name = "api_preview_voice"
        # Turn the file name into speakable words: spaces become underscores,
        # a trailing ".wav" is dropped, then every non-alphanumeric character
        # becomes a space.
        spoken_name = voice.replace(' ', '_')
        spoken_name = re.sub(r'\.wav$', '', spoken_name)
        spoken_name = re.sub(r'[^a-zA-Z0-9]', ' ', spoken_name)
        text = f"Hello, this is a preview of voice {spoken_name}."
        # Generate the audio
        output_file_path = this_dir / "outputs" / f"{output_file_name}.wav"
        await generate_audio(text, voice, language, output_file_path, streaming=False)
        # Generate the URL
        output_file_url = f'http://{params["ip_address"]}:{params["port_number"]}/audio/{output_file_name}.wav'
        # Report both the local file path and the URL
        payload = {
            "status": "generate-success",
            "output_file_path": str(output_file_path),
            "output_file_url": str(output_file_url),
        }
        return JSONResponse(content=payload, status_code=200)
    except Exception as e:
        print(f"An error occurred: {e}")
        return JSONResponse(content={"error": "An error occurred"}, status_code=500)
########################
#### GENERATION API ####
########################
import html
import re
import uuid
import numpy as np
import soundfile as sf
import sys
# Check for the PortAudio library: importing sounddevice raises OSError without it
try:
    import sounddevice as sd
except OSError:
    sounddevice_installed = False
    for info_line in (
        f"[{params['branding']}Startup] \033[91mInfo\033[0m PortAudio library not found. If you wish to play TTS in standalone mode through the API suite",
        f"[{params['branding']}Startup] \033[91mInfo\033[0m please install PortAudio. This will not affect any other features or use of Alltalk.",
        f"[{params['branding']}Startup] \033[91mInfo\033[0m If you don't know what the API suite is, then this message is nothing to worry about.",
    ):
        print(info_line)
    if sys.platform.startswith('linux'):
        # Extra installation hint for Linux users
        print(f"[{params['branding']}Startup] \033[91mInfo\033[0m On Linux, you can use the following command to install PortAudio:")
        print(f"[{params['branding']}Startup] \033[91mInfo\033[0m sudo apt-get install portaudio19-dev")
else:
    sounddevice_installed = True
from typing import Union, Dict
from pydantic import BaseModel, ValidationError, Field
def play_audio(file_path, volume):
    """Read the audio file at file_path and play it scaled by volume.

    The samples are multiplied by volume before being handed to
    sounddevice, and sd.wait() is called before returning.
    """
    samples, sample_rate = sf.read(file_path)
    scaled_samples = volume * samples
    sd.play(scaled_samples, sample_rate)
    sd.wait()
class Request(BaseModel):
    # Placeholder pydantic model: it declares no fields of its own.
    # Define the structure of the 'Request' class here if it is ever needed.
    # NOTE(review): this definition rebinds the module-level name `Request`; if the
    # file also imports a `Request` from elsewhere (e.g. a web framework), that
    # import is shadowed from this point on - confirm this is intended.
    pass
class JSONInput(BaseModel):
    # Validation schema for a TTS generation request. Every field is required
    # (`...`), and each description doubles as the human readable hint for the
    # constraint it carries.
    #
    # All regex patterns are raw strings so that `\.` reaches the regex engine
    # without relying on Python's deprecated handling of unknown escape
    # sequences in ordinary string literals.
    text_input: str = Field(..., max_length=1000, description="text_input needs to be 1000 characters or less.")
    text_filtering: str = Field(..., pattern=r"^(none|standard|html)$", description="text_filtering needs to be 'none', 'standard' or 'html'.")
    character_voice_gen: str = Field(..., pattern=r"^.*\.wav$", description="character_voice_gen needs to be the name of a wav file e.g. mysample.wav.")
    narrator_enabled: bool = Field(..., description="narrator_enabled needs to be true or false.")
    narrator_voice_gen: str = Field(..., pattern=r"^.*\.wav$", description="narrator_voice_gen needs to be the name of a wav file e.g. mysample.wav.")
    text_not_inside: str = Field(..., pattern=r"^(character|narrator)$", description="text_not_inside needs to be 'character' or 'narrator'.")
    language: str = Field(..., pattern=r"^(ar|zh-cn|cs|nl|en|fr|de|hu|it|ja|ko|pl|pt|ru|es|tr)$", description="language needs to be one of the following ar|zh-cn|cs|nl|en|fr|de|hu|it|ja|ko|pl|pt|ru|es|tr.")
    output_file_name: str = Field(..., pattern=r"^[a-zA-Z0-9_]+$", description="output_file_name needs to be the name without any special characters or file extension e.g. 'filename'")
    output_file_timestamp: bool = Field(..., description="output_file_timestamp needs to be true or false.")
    autoplay: bool = Field(..., description="autoplay needs to be a true or false value.")
    autoplay_volume: float = Field(..., ge=0.1, le=1.0, description="autoplay_volume needs to be from 0.1 to 1.0")

    @classmethod
    def validate_autoplay_volume(cls, value):
        # Manual range check for an autoplay volume. It carries no pydantic
        # validator decorator, so it only runs when called explicitly; during
        # model validation the ge/le bounds on the field enforce the range.
        # Returns the value unchanged, or raises ValueError when out of range.
        if not (0.1 <= value <= 1.0):
            raise ValueError("Autoplay volume must be between 0.1 and 1.0")
        return value
class TTSGenerator:
    """Helpers for validating TTS generation requests."""

    @staticmethod
    def validate_json_input(json_data: Union[Dict, str]) -> Union[None, str]:
        """Validate request data against the JSONInput model.

        json_data may be a dict of request fields or a JSON string encoding
        such a dict.

        Returns None when the data is valid, otherwise a string describing
        the problem. Malformed JSON text and payloads that are not JSON
        objects are reported through the returned string as well, rather
        than escaping as exceptions.
        """
        if isinstance(json_data, str):
            try:
                json_data = json.loads(json_data)
            except json.JSONDecodeError as e:
                return f"Invalid JSON: {e}"
        # Only a mapping can be expanded into the model's keyword arguments
        if not isinstance(json_data, dict):
            return "Invalid JSON: expected an object containing the request fields."
        try:
            JSONInput(**json_data)
        except ValidationError as e:
            return str(e)
        return None  # JSON is valid
def process_text(text: str) -> list:
    """Split text into an ordered list of (part_type, part_text) tuples.

    HTML entities are unescaped and runs of three or more dots collapse to
    a single dot before splitting. Text wrapped in asterisks becomes a
    'narrator' part, text wrapped in double quotes becomes a 'character'
    part, and any non-blank text between or around those spans becomes an
    'ambiguous' part. Delimiters and surrounding whitespace are removed from
    every part.
    """
    # Normalize HTML encoded quotes
    text = html.unescape(text)
    # Replace ellipsis with a single dot
    text = re.sub(r'\.{3,}', '.', text)
    # A span is either *narration* or "speech"; neither may contain the
    # other's delimiter, so every match starts and ends with the same one
    combined_pattern = r'(\*[^*"]+\*|"[^"*]+")'
    ordered_parts = []
    # Track the start of the next unclassified segment
    start = 0
    for match in re.finditer(combined_pattern, text):
        # Text between the previous span and this one is ambiguous
        ambiguous_text = text[start:match.start()].strip()
        if ambiguous_text:
            ordered_parts.append(('ambiguous', ambiguous_text))
        matched_text = match.group(0)
        # The opening delimiter alone identifies the span type
        if matched_text.startswith('*'):
            ordered_parts.append(('narrator', matched_text.strip('*').strip()))
        else:
            ordered_parts.append(('character', matched_text.strip('"').strip()))
        start = match.end()
    # Anything after the last span is ambiguous as well
    ambiguous_text = text[start:].strip()
    if ambiguous_text:
        ordered_parts.append(('ambiguous', ambiguous_text))
    return ordered_parts
def standard_filtering(text_input):
    """Apply the 'standard' text clean-up and return the filtered text.

    Removes every asterisk, replaces each blank-line break (two consecutive
    newlines) with a single newline, and decodes HTML-encoded apostrophes
    (&#x27; and &#39;) into a plain apostrophe.
    """
    text_output = (text_input
        # Removing single asterisks also covers the "**" and "***" markers
        .replace("*", "")
        .replace("\n\n", "\n")
        # Decode HTML-encoded apostrophes
        .replace("&#x27;", "'")
        .replace("&#39;", "'")
    )
    return text_output
def combine(output_file_timestamp, output_file_name, audio_files):
    """Join several audio files into one wav file in the outputs folder.

    output_file_timestamp: when true, the current Unix time is inserted into
        the combined file name.
    output_file_name: base name (no extension) of the combined file.
    audio_files: paths of the audio files to join, in playback order. They
        are deleted once the combined file has been written.

    Returns (output_file_path, output_file_url, output_cache_url). On any
    failure - no input files, unreadable input, inconsistent sample rates,
    or a failed write/clean-up - all three values are None, so callers can
    always unpack three items.
    """
    if not audio_files:
        # Nothing to combine
        return None, None, None
    audio_chunks = []
    sample_rate = None
    try:
        for audio_file in audio_files:
            audio_data, current_sample_rate = sf.read(audio_file)
            if sample_rate is None:
                # The first file fixes the sample rate for the whole output
                sample_rate = current_sample_rate
            elif sample_rate != current_sample_rate:
                raise ValueError("Sample rates of input files are not consistent.")
            audio_chunks.append(audio_data)
        # Join once at the end instead of re-copying the array per file
        audio = np.concatenate(audio_chunks)
    except Exception as e:
        # e.g. file not found, invalid audio format, mismatched channel layout
        print(f"[{params['branding']}TTSGen] \033[91mError\033[0m Unable to read audio files for combining: {e}")
        return None, None, None
    if output_file_timestamp:
        timestamp = int(time.time())
        combined_file_name = f'{output_file_name}_{timestamp}_combined.wav'
    else:
        combined_file_name = f'{output_file_name}_combined.wav'
    # The URLs use the same file name as the file written to disk
    output_file_path = os.path.join(this_dir / "outputs" / combined_file_name)
    output_file_url = f'http://{params["ip_address"]}:{params["port_number"]}/audio/{combined_file_name}'
    output_cache_url = f'http://{params["ip_address"]}:{params["port_number"]}/audiocache/{combined_file_name}'
    try:
        sf.write(output_file_path, audio, samplerate=sample_rate)
        # Clean up unnecessary files
        for audio_file in audio_files:
            os.remove(audio_file)
    except Exception as e:
        # e.g. failed to write the output file or remove an input file
        print(f"[{params['branding']}TTSGen] \033[91mError\033[0m Unable to write combined audio file: {e}")
        return None, None, None
    return output_file_path, output_file_url, output_cache_url
# Generation API (separate from text-generation-webui)
@app.post("/api/tts-generate", response_class=JSONResponse)
async def tts_generate(
text_input: str = Form(...),
text_filtering: str = Form(...),
character_voice_gen: str = Form(...),
narrator_enabled: bool = Form(...),
narrator_voice_gen: str = Form(...),
text_not_inside: str = Form(...),
language: str = Form(...),
output_file_name: str = Form(...),
output_file_timestamp: bool = Form(...),
autoplay: bool = Form(...),
autoplay_volume: float = Form(...),
streaming: bool = Form(False),
):
try:
json_input_data = {
"text_input": text_input,
"text_filtering": text_filtering,
"character_voice_gen": character_voice_gen,
"narrator_enabled": narrator_enabled,
"narrator_voice_gen": narrator_voice_gen,
"text_not_inside": text_not_inside,
"language": language,
"output_file_name": output_file_name,
"output_file_timestamp": output_file_timestamp,
"autoplay": autoplay,
"autoplay_volume": autoplay_volume,
"streaming": streaming,
}
JSONresult = TTSGenerator.validate_json_input(json_input_data)
if JSONresult is None:
pass
else:
return JSONResponse(content={"error": JSONresult}, status_code=400)
if narrator_enabled:
processed_parts = process_text(text_input)
audio_files_all_paragraphs = []
for part_type, part in processed_parts:
# Skip parts that are too short
if len(part.strip()) <= 3:
continue
# Determine the voice to use based on the part type
if part_type == 'narrator':
voice_to_use = narrator_voice_gen
print(f"[{params['branding']}TTSGen] \033[92mNarrator\033[0m") # Green
elif part_type == 'character':
voice_to_use = character_voice_gen
print(f"[{params['branding']}TTSGen] \033[36mCharacter\033[0m") # Yellow