-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
472 lines (362 loc) · 17.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
"""
To run:
uvicorn main:app --host 0.0.0.0 --port 8080 --reload
FastAPI produces documentation automagically! (see https://fastapi.tiangolo.com/tutorial/first-steps/ and https://fastapi.tiangolo.com/tutorial/metadata/)
Docs:
Swagger: localhost:8080/docs
Redoc: localhost:8080/redoc
Requirements
python 3.8
pip install fastapi[all]
aiofiles
pyaudio
Recording:
sudo apt install portaudio19-dev
TODO: Migrate to Python 3.9 (Debian has issues running Python versions older than its default; do the newer Raspberry Pi images ship 3.9?)
TODO: Unlinked files persist. Clean them up on server shutdown, or move unlinked files to a separate folder.
TODO: The front end expects a JSON message in response to POST requests (e.g. when adding a session). Use HTTP status codes instead (at least partially)?
TODO: Unlink images/audio/motions
"""
import subprocess
from tempfile import TemporaryFile
from typing import Union
from uuid import UUID, uuid4
from zipfile import BadZipFile, ZipFile

from fastapi import Body, FastAPI, Form, HTTPException, Path, Request, WebSocket
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, validator

# NOTE: the relative order of the local (and star) imports below is significant, because a
# later star import can shadow names bound by an earlier one.
from config import *
from recordingManager import RecordingManager
from data_handlers.audio import AudioShortcutsHandler
from data_handlers.motion import MotionsHandler
from data_handlers.action import ActionsHandler, ActionShortcutsHandler, MultiAction, UtteranceItem
from data_handlers.session import SessionsHandler, Session
from pepperConnectionManager import PepperConnectionManager
from recordingForwardingManager import RecordingForwardingManager
from addressForwardingManager import AddressForwarder
from data_handlers.file_operations import *
# Save file paths. All are relative to the server's working directory; persistent state lives under data/.
_DATA_DIR = "data"
SESSIONS_FILE = f"{_DATA_DIR}/sessions.json"
AUDIO_SHORTCUTS_FILE = f"{_DATA_DIR}/audio_shortcuts.json"
ACTION_SHORTCUTS_FILE = f"{_DATA_DIR}/action_shortcuts.json"
MOTIONS_FILE = f"{_DATA_DIR}/motions.json"
# NOTE: "ADDITINAL" is a historical misspelling, kept because other code in this module uses this name.
ADDITINAL_MOTIONS_FOLDER = f"{_DATA_DIR}/additional_motions"
# Create missing files/folders.
# os.makedirs(..., exist_ok=True) creates 'data' together with each subdirectory. It also tolerates
# directories that already exist, which avoids the race between an isdir() check and mkdir().
for subdir in ['additional_motions', 'recordings', 'uploads', 'compressed_sessions']:
    os.makedirs(os.path.join('data', subdir), exist_ok=True)
# Seed each missing store with an empty list keyed by the file's stem,
# e.g. data/sessions.json -> {"sessions": []}.
for memory_file in [SESSIONS_FILE, AUDIO_SHORTCUTS_FILE, ACTION_SHORTCUTS_FILE, MOTIONS_FILE]:
    if not os.path.isfile(memory_file):
        with open(memory_file, "w", encoding="utf-8") as f:
            json.dump({os.path.basename(memory_file).rsplit(".", 1)[0]: []}, f)
# FastAPI config: OpenAPI tag groups, listed as (name, description) pairs and shown in /docs and /redoc.
_TAG_DESCRIPTIONS = (
    ("Pepper", "Endpoints for communicating with Pepper"),
    ("Sessions", "Session manipulation"),
    ("General audio", "General audio queries"),
    ("Actions", "Action shortcut manipulation"),
    ("Audio", "Audio shortcut manipulation"),
    ("Motions", "Movements manipulation"),
    ("Uploads", "Session uploads"),
    ("Synthesis", "Calls to synthesize and save speech files via Neurokõne"),
    ("Recording", "Calls to start/end audio recording."),
    ("Maintenance", "Calls related to updating the front-end and back-end servers."),
)
tags_metadata = [
    {"name": tag_name, "description": tag_description}
    for tag_name, tag_description in _TAG_DESCRIPTIONS
]
# The description reads, in English: "Documentation of the support server for the Pepper robot system
# used with speech-impaired children at SA Tartu University Hospital (Tartu Ülikooli Kliinikum)."
app = FastAPI(
    title="Pepper backend",
    version="0.10.1",
    description="SA Tartu Ülikooli Kliinikumi kõnehäiretega laste robot Pepperi süsteemi toesserveri dokumentatsioon.",
    contact={"name": "Rauno Jaaska", "email": "rauno.jaaska@ut.ee"},
    openapi_tags=tags_metadata,
)
# Expose the whole ./data tree (uploads, recordings, ...) as static files under /data.
app.mount("/data", StaticFiles(directory="data"), name="data")
# Allowed origins (see https://fastapi.tiangolo.com/tutorial/cors/).
# NOTE(review): a "*" origin combined with allow_credentials=True effectively permits credentialed
# cross-origin requests from any site. Confirm this is intended for the deployment network.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=['*'],
    allow_methods=['*'],
)
# TODO: FIX ALSA ERROR OUTPUTS
# Workaround: ALSA errors are only displayed on the first instantiation, so we get it out of the way on boot
# Various methods of silencing stdout/stderr did not have ANY effect (before deadline), more work needed
# p = pyaudio.PyAudio()
# del p
# Helper objects. Construction order matters: later handlers receive the earlier ones.
actions_handler = ActionsHandler()
motions_handler = MotionsHandler(MOTIONS_FILE, ADDITINAL_MOTIONS_FOLDER, actions_handler)
sessions_handler = SessionsHandler(SESSIONS_FILE, actions_handler, motions_handler)
audio_shortcuts_handler = AudioShortcutsHandler(AUDIO_SHORTCUTS_FILE, actions_handler)
action_shortcuts_handler = ActionShortcutsHandler(ACTION_SHORTCUTS_FILE, actions_handler, motions_handler)
# CLOUDFRONT_SERVER deployments use the forwarding manager; otherwise audio is recorded locally
# (presumably the forwarder relays recordings elsewhere -- confirm in recordingForwardingManager).
recording_manager = RecordingForwardingManager() if CLOUDFRONT_SERVER else RecordingManager()
pepper_connection_manager = PepperConnectionManager(motions_handler, actions_handler, recording_manager)
# NOTE(review): the meaning of the argument 10 is defined by AddressForwarder; confirm and name it.
address_forwarder = AddressForwarder(10)
# Verbose 422 logging (see https://fastapi.tiangolo.com/tutorial/handling-errors/#use-the-requestvalidationerror-body)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log the rejected request body and its validation errors, then send FastAPI's standard 422.

    An exception handler must return the response to send. The previous version only printed and
    returned None, so the client never received the intended 422. Delegating to FastAPI's built-in
    handler keeps the usual ``{"detail": [...]}`` payload.
    """
    print(exc.body, exc.errors())
    return await request_validation_exception_handler(request, exc)
# Pepper
@app.websocket("/api/pepper/initiate")
async def pepper_connect(websocket: WebSocket):
    """Hand an incoming websocket (presumably opened by a robot) to the Pepper connection manager."""
    manager = pepper_connection_manager
    await manager.connect(websocket)
@app.get("/api/pepper/connect", tags=['Pepper'],
         summary="Connect the client to a robot.")
async def connect_pepper(conn: str):
    """Link the client to the robot identified by `conn` and return the manager's result."""
    link_result = await pepper_connection_manager.link(conn)
    return link_result
@app.get("/api/pepper/disconnect", tags=['Pepper'],
         summary="Disconnect the client from a robot.")
async def disconnect_pepper(conn: str):
    """Unlink the client from the robot identified by `conn` and return the manager's result."""
    unlink_result = await pepper_connection_manager.unlink(conn)
    return unlink_result
@app.get("/api/pepper/status", tags=['Pepper'],
         summary="Check Pepper connection and recording storage status.")
def check_pepper(conn: Union[str, None] = None):
    """Report recording-storage fill (local recording servers only) and, when `conn` is given, its robot status."""
    # Cloudfront servers keep no local recordings, so they report no storage fill.
    report = {} if CLOUDFRONT_SERVER else {"rec_fill": recording_manager.update_recordings_size()}
    if conn:
        report['status'] = pepper_connection_manager.get_status(conn)
    return report
@app.post("/api/pepper/send_command",
          tags=['Pepper'], summary="Send Pepper a command to fulfill.")
async def command_pepper(conn: str, item_json: dict = Body(...)):
    """Send the item whose UUID is given in the body's `item_id` field to the robot linked to `conn`.

    A missing or malformed `item_id` is a client error and is answered with HTTP 422.
    """
    # Validate the client-supplied id up front. Previously a missing key or a bad UUID string
    # escaped as KeyError/ValueError and surfaced as an HTTP 500.
    raw_item_id = item_json.get('item_id')
    if not isinstance(raw_item_id, str):
        raise HTTPException(status_code=422, detail="Request body must contain a string 'item_id'.")
    try:
        item_id = UUID(raw_item_id)
    except ValueError:
        raise HTTPException(status_code=422, detail=f"Invalid item_id: {raw_item_id!r}") from None
    return await pepper_connection_manager.send_command(item_id, conn)
@app.get("/api/pepper/stop_video", tags=['Pepper'],
         summary="Stop video playback.")
async def stop_video(conn: str):
    """Stop video playback for `conn` by clearing the robot's current fragment."""
    cleared = await pepper_connection_manager.clear_fragment(conn)
    return cleared
# Sessions
@app.get("/api/sessions/", tags=['Sessions'],
         summary="Get all sessions.")
def get_sessions():
    """Return all stored sessions in the handler's sorted order."""
    ordered_sessions = sessions_handler.get_sorted_sessions()
    return ordered_sessions
# NOTE(review): the session-upload handler further down reuses the name `post_session`. Both routes
# are registered by their decorators, but the module attribute ends up bound to the later function.
@app.post("/api/sessions/", tags=['Sessions'],
          summary="Add a session.")
def post_session(session: Session):
    """Store a new session sent as the JSON request body."""
    sessions_handler.add_session(session)
    confirmation = {"message": "Session saved!"}
    return confirmation
@app.get("/api/sessions/{session_id}", tags=['Sessions'],
         summary="Get a specific session.")
def get_session(session_id: UUID = Path(...)):
    """Return the session with the given UUID."""
    found = sessions_handler.get_session(session_id)
    return found
@app.put("/api/sessions/{session_id}", tags=['Sessions'],
         summary="Update an existing session.")
def update_session(session: Session, session_id: UUID = Path(...)):
    """Replace the stored session `session_id` with the session in the request body."""
    # NOTE(review): the body's own ID is not compared with `session_id` here.
    updated = sessions_handler.update_session(session_id, session)
    return updated
# TODO: make all deletions non-destructive (?)
@app.delete("/api/sessions/{session_id}", tags=['Sessions'],
            summary="Delete an existing session.")
def remove_session(session_id: UUID = Path(...)):
    """Delete the session with the given UUID."""
    sessions_handler.remove_session(session_id)
    confirmation = {'message': 'Session removed!'}
    return confirmation
@app.get("/api/session_items/{session_item_id}", tags=['Sessions'],
         summary="Get a specific question (SessionItem).")
def get_session_item(session_item_id: UUID = Path(...)):
    """Return the session item (question) with the given UUID."""
    item = sessions_handler.get_session_item(session_item_id)
    return item
@app.get("/api/export_session/{session_id}", tags=['Sessions'],
         summary="Export a session.")
def get_exported_session(session_id: UUID = Path(...)):
    """Compress the given session for export and return compress_session()'s result."""
    target_session = sessions_handler.get_session(session_id)
    return compress_session(target_session)
@app.delete("/api/instruction/{action_id}", tags=['Sessions'],
            summary="Remove an action from a question (SessionItem).")
def delete_session_action(action_id: UUID = Path(...)):
    """Remove the action with the given UUID from whichever session item holds it."""
    outcome = sessions_handler.remove_action(action_id)
    return outcome
# Action shortcuts
@app.get("/api/actions/", tags=['Actions'],
         summary="Get all action shortcuts.")
def get_action_shortcuts():
    """Return every stored action shortcut."""
    shortcuts = action_shortcuts_handler.get_actions()
    return shortcuts
@app.post("/api/actions/", tags=['Actions'],
          summary="Add an action shortcut.")
def post_action_shortcut(action: MultiAction):
    """Store a new action shortcut sent as the JSON request body."""
    outcome = action_shortcuts_handler.add_action(action)
    return outcome
# NOTE(review): the `{action_id}` path segment is not declared as a parameter, so FastAPI neither
# validates it nor passes it in. The handler presumably relies on the ID inside `action`; confirm.
@app.post("/api/actions/{action_id}", tags=['Actions'],
          summary="Update an action shortcut")
def update_action_shortcut(action: MultiAction):
    """Replace a stored action shortcut with the one in the request body."""
    outcome = action_shortcuts_handler.update_action(action)
    return outcome
@app.delete("/api/actions/{action_id}", tags=['Actions'],
            summary="Delete an action shortcut.")
def delete_action_shortcut(action_id: UUID = Path(...)):
    """Delete the action shortcut with the given UUID."""
    outcome = action_shortcuts_handler.remove_action(action_id)
    return outcome
# Quick audio
@app.get("/api/audio/", tags=['Audio'],
         summary="Get metadata of all quick audio files.")
def get_audio_shortcuts():
    """Return metadata for every audio shortcut."""
    metadata = audio_shortcuts_handler.get_audio_metadata()
    return metadata
@app.post("/api/audio/",
          tags=['Audio'], summary="Add a new audio shortcut")
async def post_audio_shortcut(file_content: UploadFile, phrase: str = Form(...), group: str = Form("Default")):
    """Save an uploaded WAV for `phrase` and register it as an audio shortcut in `group`.

    The file is stored as data/uploads/<hash of phrase>.wav. If that file already exists it is
    reused, and the new upload's bytes are not written.
    """
    chunk_size = 64 * 1024  # read the upload in 64 KiB chunks rather than 1 KiB
    phrase_hash = hash_phrase_to_filename(phrase)
    save_path = os.path.join("data", "uploads", f"{phrase_hash}.wav")
    if not os.path.isfile(save_path):
        try:
            async with async_open(save_path, "wb") as save_file:
                while content := await file_content.read(chunk_size):
                    await save_file.write(content)
        except BaseException:
            # A truncated file would pass the isfile() check on every later upload and become
            # permanent. Remove it (best effort) before propagating the error or cancellation.
            try:
                os.remove(save_path)
            except OSError:
                pass
            raise
    audio_shortcuts_handler.add_audio(UtteranceItem.parse_obj({"ID": uuid4(),
                                                               "Group": group,
                                                               "Delay": 0,
                                                               "Phrase": phrase,
                                                               "FilePath": save_path
                                                               }))
    return {"message": "Audio shortcut created."}
@app.get("/api/audio/{audio_id}", tags=['Audio'],
         summary="Get metadata of a specific audio shortcut")
def get_audio_shortcut(audio_id: UUID = Path(...)):
    """Return metadata for the audio shortcut with the given UUID."""
    metadata = audio_shortcuts_handler.get_single_audio_metadata(audio_id)
    return metadata
@app.delete("/api/audio/{audio_id}", tags=['Audio'],
            summary="Remove an audio shortcut")
def remove_audio_shortcut(audio_id: UUID = Path(...)):
    """Delete the audio shortcut with the given UUID."""
    audio_shortcuts_handler.remove_audio(audio_id)
    confirmation = {"message": "Audio shortcut removed"}
    return confirmation
# Motions
@app.get("/api/motions/", tags=['Motions'],
         summary="Get metadata of all movements.")
def get_moves():
    """Return metadata for every known motion."""
    motions = motions_handler.get_motions()
    return motions
@app.get("/api/motions/{move_id}", tags=['Motions'],
         summary="Get metadata of a specific movement")
def get_move(move_id: UUID = Path(...)):
    """Return metadata for the motion with the given UUID."""
    motion = motions_handler.get_motion_by_id(move_id)
    return motion
# Speech synthesis
@app.get("/api/voices", tags=['Synthesis'],
         summary="List available voices")
def get_voices():
    """Return the configured synthesis voices under the `voices` key."""
    return dict(voices=SPEAKERS)
class SynthesisRequest(BaseModel):
    """Request body for single-phrase speech synthesis."""

    phrase: str  # text to synthesize
    voice: str  # voice name passed to synthesize(); presumably one of SPEAKERS (not validated here)
    speed: float  # speed passed to synthesize(); must lie within [0.5, 2]

    @validator('speed')
    def speed_validator(cls, spd):
        """Reject speeds outside the closed range [0.5, 2].

        The chained comparison is negated as a whole, so NaN fails it and is rejected. With
        `spd < 0.5 or spd > 2`, both comparisons are False for NaN and it was accepted.
        """
        if not 0.5 <= spd <= 2:
            raise ValueError("Speed must be a value between 0.5 and 2!")
        return spd
@app.post("/api/synthesis", tags=['Synthesis'],
          summary="Synthesize speech using the given phrase. Returns the path to the resulting file.")
def post_synthesize(sr: SynthesisRequest):
    """Synthesize `sr.phrase` with the requested voice and speed, always regenerating the file."""
    print(sr)
    audio_path = synthesize(sr.phrase, sr.voice, sr.speed, force=True)
    return {'message': 'Audio synthesized!', 'filepath': audio_path}
@app.post("/api/synthesis/batch", tags=['Synthesis'],
          summary="Synthesize all speech for the given session.")
def post_synthesize_batch(voice: str, session: Session):
    """Synthesize every utterance of `session` with `voice`, then save the updated session.

    The text spoken is the item's Pronunciation when that is set and differs from Phrase.
    Otherwise the Phrase is spoken and the redundant Pronunciation is reset to "".
    """
    utterances = (
        action.UtteranceItem
        for session_item in session.Items
        for action in session_item.Actions
        if action.UtteranceItem and action.UtteranceItem.Phrase
    )
    for utterance in utterances:
        if utterance.Pronunciation and utterance.Pronunciation != utterance.Phrase:
            spoken_text = utterance.Pronunciation
        else:
            spoken_text = utterance.Phrase
            utterance.Pronunciation = ""
        utterance.FilePath = synthesize(spoken_text, voice, utterance.Speed, force=True)
    return sessions_handler.update_session(session.ID, session)
# Uploads
@app.post("/api/upload/audio", tags=['Uploads'],
          summary="Upload session audio")
async def post_audio(file_content: UploadFile):
    """Save an uploaded session audio file under its content-derived name."""
    saved = await hash_and_save_file(file_content, "Audio file")
    return saved
@app.post("/api/upload/image", tags=["Uploads"],
          summary="Upload session image")
async def post_image(file_content: UploadFile):
    """Save an uploaded session image under its content-derived name."""
    saved = await hash_and_save_file(file_content, "Image")
    return saved
def _extract_session_uploads(session_zip):
    """Copy every `uploads/...` member of `session_zip` into data/uploads under its base name."""
    upload_dir = os.path.join('data', 'uploads')
    for member in session_zip.namelist():
        if not member.startswith('uploads/'):
            continue
        base_name = os.path.basename(member)
        # Directory entries (e.g. "uploads/") have no file name. Opening them for writing would
        # raise IsADirectoryError.
        if base_name in ('', '.', '..'):
            continue
        # Extract manually instead of ZipFile.extract() so archive paths are flattened into
        # data/uploads rather than recreating the archive's directory tree.
        with open(os.path.join(upload_dir, base_name), 'wb') as out_file:
            out_file.write(session_zip.read(member))


# NOTE(review): this handler reuses the Python name `post_session` of the POST /api/sessions/
# handler above. Both routes are registered, but the module attribute refers to this function.
@app.post("/api/upload/session",
          tags=['Uploads'], summary="Upload a session")
async def post_session(file_content: UploadFile):
    """Import a session from an uploaded ZIP archive, or update the existing session with that name.

    The archive must contain `session.json`, and its file name must be `<session Name>.zip`.
    Members under `uploads/` are copied into data/uploads. All checks run before any file is
    written. Failures are reported as `{'error': <message>}`.
    """
    archive_name = file_content.filename or ""
    if archive_name.endswith(".zip"):
        archive_name = archive_name[:-len(".zip")]

    # Both the temporary copy and the archive are closed on every exit path.
    with TemporaryFile() as temp_file:
        temp_file.write(file_content.file.read())
        try:
            session_zip = ZipFile(temp_file)
        except BadZipFile:
            return {'error': 'Uploaded file is not a valid ZIP archive!'}
        with session_zip:
            if 'session.json' not in session_zip.namelist():
                return {'error': 'Session file missing from archive!'}
            with session_zip.open('session.json') as sess:
                session = json.loads(sess.read())
            if archive_name != session.get('Name'):
                return {'error': 'Import failed: archive and session name do not match!'}
            _extract_session_uploads(session_zip)

    # If an existing session shares the name with the posted session, the client-side check has passed and
    # the existing session must be updated instead.
    for old_session in sessions_handler.sessions:
        if session['Name'] == old_session.Name:
            await sessions_handler.dict_to_session_rename(session)
            sessions_handler.update_session(old_session.ID, Session.parse_obj(session))
            return {'message': "Session update dummy msg", 'session_index': sessions_handler.get_session_index(session['ID'])}
    await sessions_handler.import_session(session)
    return {'message': 'Session imported!', 'session_index': sessions_handler.get_session_index(session['ID'])}
# Recording
@app.get("/api/recording/start", tags=['Recording'],
         summary="Begin recording audio and session progress.",
         description="Audio is taken from the server's default audio input and is saved in WAV format. Recordings can be found in data/recordings.")
def start_recording(conn: str):
    """Start a recording tied to the connection `conn`."""
    started = recording_manager.start_recording(conn)
    return started
@app.get("/api/recording/stop", tags=['Recording'],
         summary="Stop recording.")
def stop_recording(conn: str):
    """Stop the recording tied to the connection `conn`."""
    stopped = recording_manager.stop_recording(conn)
    return stopped
@app.get("/api/recording/export", tags=['Recording'],
         summary="Export recording data.")
def export_recordings():
    """Compress the stored recordings for export, unless a recording is in progress."""
    if not recording_manager.recording_connection:
        return compress_recordings()
    return {"error": "The server is currently recording. Finish recording to export recording data."}
@app.get("/api/recording/clear_archives", tags=['Recording'],
         summary="Delete stored archives.")
def clear_archives():
    """Delete every `.zip` archive stored directly in data/recordings."""
    recordings_dir = os.path.join("data", "recordings")
    archives = [name for name in os.listdir(recordings_dir) if name.endswith(".zip")]
    for archive in archives:
        os.remove(os.path.join(recordings_dir, archive))
    return {"message": "Stored archives deleted."}
# Server maintenance
# NOTE(review): this and the other maintenance endpoints are unauthenticated GETs with side effects.
@app.get("/api/rebuild", tags=['Maintenance'],
         summary="Rebuild the site via NPM.")
def get_rebuild():
    """Launch ./rebuild.sh in its own process group and return without waiting for it."""
    # os.setpgrp runs in the child, detaching it from the server's process group.
    subprocess.Popen('./rebuild.sh', preexec_fn=os.setpgrp, shell=True)
    return {'message': "Started rebuild.sh"}
@app.get("/api/check_update",
         tags=['Maintenance'], summary="Check for update availability.")
def get_update_status():
    """Fetch the back-end and web-client repositories and report whether either is behind its upstream."""
    # git runs with cwd= instead of os.chdir(). chdir changes the working directory of the whole
    # process, which races with concurrent requests that use relative "data/..." paths. A failing
    # git command between the two chdir calls also left the server stuck in ../web-client.
    def _is_behind(repo_dir):
        # Fetch first so `git status -sb` compares against up-to-date remote-tracking refs.
        subprocess.run(['git', 'fetch'], cwd=repo_dir)
        status = subprocess.check_output(['git', 'status', '-sb'], cwd=repo_dir)
        return b"[behind " in status

    backend_update = _is_behind(None)  # the server's working directory (the back-end checkout)
    frontend_update = _is_behind("../web-client")
    return {"update_available": backend_update or frontend_update}
@app.get("/api/update", tags=['Maintenance'],
         summary="Update the servers")
def get_update():
    """Launch ./update.sh in its own process group and return without waiting for it."""
    # os.setpgrp runs in the child, detaching it from the server's process group.
    subprocess.Popen('./update.sh', preexec_fn=os.setpgrp, shell=True)
    return {'message': "Started update.sh"}
# NOTE(review): any client that can reach this unauthenticated GET can power off the host.
@app.get("/api/shutdown", tags=['Maintenance'],
         summary="Shut the server down.")
def get_shutdown():
    """Power off the host machine immediately."""
    power_off_command = "shutdown -P now"
    os.system(power_off_command)
@app.on_event("shutdown")
def shutdown_event():
    """Persist motions and sessions, then stop the address forwarder, when the application stops."""
    teardown_steps = (
        motions_handler.save_motions,
        sessions_handler.save_sessions,
        address_forwarder.stop,
    )
    for step in teardown_steps:
        step()