Skip to content

Commit

Permalink
Add a bear example to the just_dance project
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Jul 24, 2023
1 parent c94434b commit 97b8545
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 21 deletions.
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Import:
- configs/body_2d_keypoint/rtmpose/coco/rtmpose_coco.yml
- configs/body_2d_keypoint/rtmpose/crowdpose/rtmpose_crowdpose.yml
- configs/body_2d_keypoint/rtmpose/mpii/rtmpose_mpii.yml
- configs/body_2d_keypoint/rtmpose/humanart/rtmpose_humanart.yml
- configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
- configs/body_2d_keypoint/simcc/coco/resnet_coco.yml
- configs/body_2d_keypoint/simcc/coco/vipnas_coco.yml
Expand Down
12 changes: 10 additions & 2 deletions projects/just_dance/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def process_video(
os.system(
f'wget -O {project_path}/resources/tsinghua_30fps.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4' # noqa
)
os.system(
f'wget -O {project_path}/resources/student1.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/student1.mp4' # noqa
)
os.system(
f'wget -O {project_path}/resources/bear_teacher.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/bear_teacher.mp4' # noqa
)

with gr.Blocks() as demo:
with gr.Tab('Upload-Video'):
Expand All @@ -66,13 +72,15 @@ def process_video(
student_video = gr.Video(type='mp4')
gr.Examples([
os.path.join(project_path, 'resources/tom.mp4'),
os.path.join(project_path, 'resources/tsinghua_30fps.mp4')
os.path.join(project_path, 'resources/tsinghua_30fps.mp4'),
os.path.join(project_path, 'resources/student1.mp4')
], student_video)
with gr.Column():
gr.Markdown('Teacher Video')
teacher_video = gr.Video(type='mp4')
gr.Examples([
os.path.join(project_path, 'resources/idol_producer.mp4')
os.path.join(project_path, 'resources/idol_producer.mp4'),
os.path.join(project_path, 'resources/bear_teacher.mp4')
], teacher_video)

button = gr.Button('Grading', variant='primary')
Expand Down
68 changes: 49 additions & 19 deletions projects/just_dance/process_video.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from typing import Optional

import cv2
import mmcv
Expand All @@ -24,32 +25,50 @@
from utils import (blend_images, convert_video_fps, get_smoothed_kpt,
resize_image_to_fixed_height)

# Detector config/weights for the single-person RTMDet model.
# NOTE(review): the diff shows these two module-level constants being
# superseded by `model_cfg` below; they are kept here because the old
# `pose_estimator` body still references them.
det_config = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'configs/rtmdet-nano_one-person.py')
det_weights = 'https://download.openmmlab.com/mmpose/v1/projects/' \
    'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth'
# Per-category inferencer settings, keyed by `VideoProcessor.category`
# ('human' or 'bear').  Each entry is unpacked as keyword arguments to
# `Pose2DInferencer` when the estimator for that category is first built.
model_cfg = dict(
    human=dict(
        model='rtmpose-t_8xb256-420e_aic-coco-256x192',
        det_model=os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            'configs/rtmdet-nano_one-person.py'),
        det_weights='https://download.openmmlab.com/mmpose/v1/projects/'
        'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth',
    ),
    bear=dict(
        model='rtmpose-l_8xb256-420e_humanart-256x192',
        det_model='rtmdet-m',
        # presumably the detector class id for 'bear' — TODO confirm
        # against the label map of the 'rtmdet-m' model
        det_cat_ids=77,
    ),
)


class VideoProcessor:
"""A class to process videos for pose estimation and visualization."""

    def __init__(self):
        """Create a processor that defaults to the 'human' model config."""
        # `category` selects which entry of the module-level `model_cfg`
        # dict is used to build the pose estimator ('human' or 'bear').
        self.category = 'human'

def _set_category(self, category):
assert category in model_cfg
self.category = category

@property
def pose_estimator(self) -> Pose2DInferencer:
if not hasattr(self, '_pose_estimator'):
self._pose_estimator = Pose2DInferencer(
'rtmpose-t_8xb256-420e_aic-coco-256x192',
det_model=det_config,
det_weights=det_weights)
self._pose_estimator.model.test_cfg['flip_test'] = False
return self._pose_estimator
self._pose_estimator = dict()
if self.category not in self._pose_estimator:
self._pose_estimator[self.category] = Pose2DInferencer(
**(model_cfg[self.category]))
self._pose_estimator[
self.category].model.test_cfg['flip_test'] = False
return self._pose_estimator[self.category]

@property
def visualizer(self) -> PoseLocalVisualizer:
if hasattr(self, '_visualizer'):
return self._visualizer
elif hasattr(self, '_pose_estimator'):
return self._pose_estimator.visualizer
return self.pose_estimator.visualizer

# init visualizer
self._visualizer = PoseLocalVisualizer()
Expand Down Expand Up @@ -109,11 +128,16 @@ def get_keypoints_from_video(self, video: str) -> np.ndarray:

video_reader = mmcv.VideoReader(video)

if video_reader.fps != 30:
if abs(video_reader.fps - 30) > 0.1:
video_reader = mmcv.VideoReader(convert_video_fps(video))

assert video_reader.fps == 30, f'only support videos with 30 FPS, ' \
f'but the video {video_fname} has {video_reader.fps} fps'
assert abs(video_reader.fps - 30) < 0.1, f'only support videos with ' \
f'30 FPS, but the video {video_fname} has {video_reader.fps} fps'

if os.path.basename(video_fname).startswith('bear'):
self._set_category('bear')
else:
self._set_category('human')
keypoints_list = []
for i, frame in enumerate(video_reader):
keypoints = self.get_keypoints_from_frame(frame)
Expand All @@ -123,7 +147,10 @@ def get_keypoints_from_video(self, video: str) -> np.ndarray:
return keypoints

@torch.no_grad()
def run(self, tch_video: str, stu_video: str):
def run(self,
tch_video: str,
stu_video: str,
output_file: Optional[str] = None):
# extract human poses
tch_kpts = self.get_keypoints_from_video(tch_video)
stu_kpts = self.get_keypoints_from_video(stu_video)
Expand All @@ -137,8 +164,9 @@ def run(self, tch_video: str, stu_video: str):
# output
tch_name = os.path.basename(tch_video).rsplit('.', 1)[0]
stu_name = os.path.basename(stu_video).rsplit('.', 1)[0]
fname = f'{tch_name}-{stu_name}.mp4'
output_file = os.path.join(tempfile.mkdtemp(), fname)
if output_file is None:
fname = f'{tch_name}-{stu_name}.mp4'
output_file = os.path.join(tempfile.mkdtemp(), fname)
return self.generate_output_video(tch_video, stu_video, output_file,
tch_kpts, stu_kpts, piece_info)

Expand Down Expand Up @@ -223,7 +251,9 @@ def generate_output_video(self, tch_video: str, stu_video: str,
# Command-line entry point: grade a student dance video against a teacher
# video and write the blended comparison video.
parser = ArgumentParser()
parser.add_argument('teacher_video', help='Path to the Teacher Video')
parser.add_argument('student_video', help='Path to the Student Video')
parser.add_argument(
    '--output-file', help='Path to save the output Video', default=None)
args = parser.parse_args()

# When --output-file is omitted (None), `run` falls back to writing the
# result into a fresh temporary directory.
# NOTE(review): the captured diff left both the pre-change and post-change
# `processor.run(...)` calls in place, which would run the whole pipeline
# twice; only the post-change call is kept.
processor = VideoProcessor()
processor.run(args.teacher_video, args.student_video, args.output_file)

0 comments on commit 97b8545

Please sign in to comment.