-
Notifications
You must be signed in to change notification settings - Fork 2
/
pose-fastapi.py
109 lines (89 loc) · 3.5 KB
/
pose-fastapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import io
import cv2
import uvicorn
from PIL import Image
# from typing import Optional
from ultralytics import YOLO
from fastapi import FastAPI, UploadFile, File
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=[
"*"], allow_methods=["*"], allow_headers=["*"])
@app.get("/")
def read_root():
return {"Welcome": "FastAPI API for Pose object detection"}
@app.post("/process_image")
async def process_files(
source: UploadFile = File(...),
model_path: str = "./models/yolov8n-pose.pt",
conf: float = 0.3
):
# Save the uploaded file
file_extension = os.path.splitext(source.filename)[1]
temp_file_path = "temp_file" + file_extension
with open(temp_file_path, "wb") as f:
f.write(source.file.read())
# Load the YOLO model
model = YOLO(model_path)
# Process the file
results = []
if file_extension.lower() in (".jpg", ".jpeg", ".png"):
results = model(source=temp_file_path, conf=conf)
else:
return {"error": "Unsupported file format"}
processed_images = []
for idx, r in enumerate(results):
image_array = r.plot(conf=True, boxes=True)
# Save the image as a file
cv2.imwrite("processed_image.png", image_array)
with open("processed_image.png", "rb") as image_file:
image_bytes = image_file.read()
processed_images.append(image_bytes)
return StreamingResponse(io.BytesIO(processed_images[0]), media_type="image/png")
@app.post("/process_video")
async def process_video(
source: UploadFile = File(...),
model_path: str = "./models/yolov8n-pose.pt",
conf: float = 0.3
):
# Save the uploaded video
file_extension = os.path.splitext(source.filename)[1]
temp_file_path = "temp_video" + file_extension
with open(temp_file_path, "wb") as f:
f.write(source.file.read())
# Load the YOLO model
model = YOLO(model_path)
# Process the video
if file_extension.lower() in (".mp4", ".avi", ".mov", ".mkv"):
results = model(source=temp_file_path, show=True, conf=conf)
else:
return {"error": "Unsupported file format"}
# Save processed video frames
processed_frames = []
for idx, r in enumerate(results):
image_array = r.plot(conf=True, boxes=True)
processed_frames.append(image_array)
# Define codec and create VideoWriter
fourcc = cv2.VideoWriter_fourcc(*"XVID")
processed_video_path = "./sample/processed_video.avi"
out = cv2.VideoWriter(processed_video_path, fourcc, 30,
(processed_frames[0].shape[1], processed_frames[0].shape[0]))
# Write processed frames to the video
for frame in processed_frames:
out.write(frame)
out.release()
# Read the processed video as bytes
with open(processed_video_path, "rb") as video_file:
video_bytes = video_file.read()
# Adjust media_type as needed
# return processed_video_path(processed_video_path=processed_video_path)
# for future read and stream in api or save in s3 bucket and fetch
# # Read the processed video as bytes
# with open(processed_video_path, "rb") as video_file:
# video_bytes = video_file.read()
# return StreamingResponse(io.BytesIO(video_bytes), media_type="video/mp4")
return {"result": f"Succesfully saved in '{str(processed_video_path)}'"}
if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8000)