Support stgcn (#293)
* add stgcn
Cathy0908 authored Mar 2, 2023
1 parent c73edee commit 4cf6f79
Showing 22 changed files with 1,676 additions and 23 deletions.
120 changes: 120 additions & 0 deletions configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py
@@ -0,0 +1,120 @@
_base_ = 'configs/base.py'

CLASSES = [
    'drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop',
    'pickup', 'throw', 'sitting down', 'standing up (from sitting position)',
    'clapping', 'reading', 'writing', 'tear up paper', 'wear jacket',
    'take off jacket', 'wear a shoe', 'take off a shoe', 'wear on glasses',
    'take off glasses', 'put on a hat/cap', 'take off a hat/cap', 'cheer up',
    'hand waving', 'kicking something', 'reach into pocket',
    'hopping (one foot jumping)', 'jump up', 'make a phone call/answer phone',
    'playing with phone/tablet', 'typing on a keyboard',
    'pointing to something with finger', 'taking a selfie',
    'check time (from watch)', 'rub two hands together', 'nod head/bow',
    'shake head', 'wipe face', 'salute', 'put the palms together',
    'cross hands in front (say stop)', 'sneeze/cough', 'staggering', 'falling',
    'touch head (headache)', 'touch chest (stomachache/heart pain)',
    'touch back (backache)', 'touch neck (neckache)',
    'nausea or vomiting condition',
    'use a fan (with hand or paper)/feeling warm',
    'punching/slapping other person', 'kicking other person',
    'pushing other person', 'pat on back of other person',
    'point finger at the other person', 'hugging other person',
    'giving something to other person', "touch other person's pocket",
    'handshaking', 'walking towards each other',
    'walking apart from each other'
]
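# The 60 labels above follow the NTU RGB+D 60 (xsub split) label order and
# match num_classes=60 in cls_head below.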

model = dict(
    type='SkeletonGCN',
    backbone=dict(
        type='STGCN',
        in_channels=3,
        edge_importance_weighting=True,
        graph_cfg=dict(layout='coco', strategy='spatial')),
    cls_head=dict(
        type='STGCNHead',
        num_classes=60,
        in_channels=256,
        loss_cls=dict(type='CrossEntropyLoss')),
    train_cfg=None,
    test_cfg=None)

dataset_type = 'VideoDataset'
ann_file_train = 'data/posec3d/ntu60_xsub_train.pkl'
ann_file_val = 'data/posec3d/ntu60_xsub_val.pkl'

train_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
val_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
test_pipeline = [
    dict(type='PaddingWithLoop', clip_len=300),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', input_format='NCTVM'),
    dict(type='PoseNormalize'),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='VideoToTensor', keys=['keypoint'])
]
data = dict(
    imgs_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_train,
            data_prefix='',
        ),
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        imgs_per_gpu=1,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_val,
            data_prefix='',
        ),
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        data_source=dict(
            type='PoseDataSourceForVideoRec',
            ann_file=ann_file_val,
            data_prefix='',
        ),
        pipeline=test_pipeline))

# optimizer
optimizer = dict(
    type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[10, 50])
total_epochs = 80

# eval
eval_config = dict(initial=False, interval=1, gpu_collect=True)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
    )
]

log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
checkpoint_config = dict(interval=1)
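For reference, the FormatGCNInput step in the pipelines above packs keypoints into the NCTVM layout the STGCN backbone consumes: N clips, C channels (x, y, score, matching in_channels=3), T frames (clip_len=300), V joints (17 for the 'coco' layout) and M persons. A minimal shape sketch of that packing, assuming two tracked persons (the permutation below is an illustration, not the exact EasyCV implementation):

import numpy as np

# Keypoints as a pose estimator emits them: (M, T, V, C)
# = (persons, frames, joints, channels).
M, T, V, C = 2, 300, 17, 3
keypoint = np.zeros((M, T, V, C), dtype=np.float32)

# NCTVM reordering for a single clip: (1, C, T, V, M) = (1, 3, 300, 17, 2)
gcn_input = keypoint.transpose(3, 1, 2, 0)[None]
assert gcn_input.shape == (1, C, T, V, M)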
217 changes: 217 additions & 0 deletions demos/video_recognition/skeleton_based_demo.py
@@ -0,0 +1,217 @@
import argparse
import os
import os.path as osp
import shutil

import cv2
import mmcv
import numpy as np
import torch

from easycv.file.utils import is_url_path
from easycv.predictors.pose_predictor import PoseTopDownPredictor
from easycv.predictors.video_classifier import STGCNPredictor

try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable video output.')

FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255) # BGR, white
THICKNESS = 1
LINETYPE = 1
TMP_DIR = './tmp'


def parse_args():
    parser = argparse.ArgumentParser(
        description='Skeleton-based video classification demo.')
    parser.add_argument(
        '--video',
        default=
        'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/demos/videos/ntu_sample.avi',
        help='video file/url')
    parser.add_argument(
        '--out_file',
        default=f'{TMP_DIR}/demo_show.mp4',
        help='output filename')
    parser.add_argument(
        '--config',
        default=(
            'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
        ),
        help='skeleton model config file path')
    parser.add_argument(
        '--checkpoint',
        default=(
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/video/skeleton_based/stgcn/stgcn_80e_ntu60_xsub.pth'
        ),
        help='skeleton model checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='configs/detection/yolox/yolox_s_8xb16_300e_coco.py',
        help='human detection config file path')
    parser.add_argument(
        '--det-checkpoint',
        default=(
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pt'
        ),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--det-predictor-type',
        default='YoloXPredictor',
        help='detection predictor type')
    parser.add_argument(
        '--pose-config',
        default='configs/pose/hrnet_w48_coco_256x192_udp.py',
        help='human pose estimation config file path')
    parser.add_argument(
        '--pose-checkpoint',
        default=(
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/pose/top_down_hrnet/pose_hrnet_epoch_210_export.pt'
        ),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.5,
        help='the threshold of human detection score')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    args = parser.parse_args()
    return args


def frame_extraction(video_path, short_side):
    """Extract frames from the given video.

    Args:
        video_path (str): Path or url of the video file.
        short_side (int): Target length of the shorter image side.
    """
    if is_url_path(video_path):
        from torch.hub import download_url_to_file
        cache_video_path = os.path.join(TMP_DIR, os.path.basename(video_path))
        print(f'Download video file from remote to local path '
              f'"{cache_video_path}"...')
        download_url_to_file(video_path, cache_video_path)
        video_path = cache_video_path

    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join(TMP_DIR, osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    flag, frame = vid.read()
    cnt = 0
    new_h, new_w = None, None
    while flag:
        if new_h is None:
            h, w, _ = frame.shape
            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf))
        frame = mmcv.imresize(frame, (new_w, new_h))
        frames.append(frame)
        frame_path = frame_tmpl.format(cnt + 1)
        frame_paths.append(frame_path)

        cv2.imwrite(frame_path, frame)
        cnt += 1
        flag, frame = vid.read()

    return frame_paths, frames


def main():
    args = parse_args()

    if not osp.exists(TMP_DIR):
        os.makedirs(TMP_DIR)

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    # Build the top-down pose predictor (human detection + pose estimation)
    pose_predictor = PoseTopDownPredictor(
        model_path=args.pose_checkpoint,
        config_file=args.pose_config,
        detection_predictor_config=dict(
            type=args.det_predictor_type,
            model_path=args.det_checkpoint,
            config_file=args.det_config,
        ),
        bbox_thr=args.bbox_thr,
        cat_id=0,  # person category id
    )

    video_cls_predictor = STGCNPredictor(
        model_path=args.checkpoint,
        config_file=args.config,
        ori_image_size=(w, h),
        label_map=None)

    pose_results = pose_predictor(original_frames)

    torch.cuda.empty_cache()

    fake_anno = dict(
        frame_dir='',
        label=-1,
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame)
    num_person = max([len(x) for x in pose_results])

    num_keypoint = 17
    keypoints = np.zeros((num_person, num_frame, num_keypoint, 2),
                         dtype=np.float16)
    keypoints_score = np.zeros((num_person, num_frame, num_keypoint),
                               dtype=np.float16)
    for i, poses in enumerate(pose_results):
        if len(poses) < 1:
            continue
        _keypoint = poses['keypoints']  # shape = (num_person, num_keypoint, 3)
        for j, pose in enumerate(_keypoint):
            keypoints[j, i] = pose[:, :2]
            keypoints_score[j, i] = pose[:, 2]

    fake_anno['keypoint'] = keypoints
    fake_anno['keypoint_score'] = keypoints_score

    results = video_cls_predictor([fake_anno])

    action_label = results[0]['class_name'][0]
    print(f'action label: {action_label}')

    vis_frames = [
        pose_predictor.show_result(original_frames[i], pose_results[i])
        if len(pose_results[i]) > 0 else original_frames[i]
        for i in range(num_frame)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)

    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_file, remove_temp=True)
    print(f'Write video to {args.out_file} successfully!')

    tmp_frame_dir = osp.dirname(frame_paths[0])
    shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()
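The demo can then be run end to end; the flags mirror parse_args above, and the local video path below is a placeholder:

python demos/video_recognition/skeleton_based_demo.py \
    --video path/to/your_clip.mp4 \
    --out_file ./tmp/demo_show.mp4 \
    --device cuda:0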
8 changes: 7 additions & 1 deletion easycv/core/evaluation/classification_eval.py
@@ -50,11 +50,17 @@ def _evaluate_impl(self, predictions, gt_labels):
             a dict, each key is metric_name, value is metric value
         '''
         eval_res = OrderedDict()
+
+        if isinstance(gt_labels, dict):
+            assert len(gt_labels) == 1
+            gt_labels = list(gt_labels.values())[0]
+
         target = gt_labels.long()

         # if self.neck_num is not None:
         if self.neck_num is None:
-            predictions = {'neck': predictions['neck']}
+            if len(predictions) > 1:
+                predictions = {'neck': predictions['neck']}
         else:
             predictions = {
                 'neck_%d_0' % self.neck_num:
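The first hunk lets the evaluator accept ground-truth labels wrapped in a single-entry dict (as the new video-recognition data source emits them); the second keeps only the 'neck' output when a predictor returns multiple outputs, leaving single-output predictors untouched. A minimal sketch of the unwrapping, with hypothetical label values:

import torch

# Hypothetical wrapped labels, as a single-entry dict
gt_labels = {'label': torch.tensor([3, 7, 1])}
if isinstance(gt_labels, dict):
    assert len(gt_labels) == 1
    gt_labels = list(gt_labels.values())[0]
target = gt_labels.long()  # tensor([3, 7, 1])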
1 change: 1 addition & 0 deletions easycv/datasets/video_recognition/data_sources/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from .pose_datasource import PoseDataSourceForVideoRec
 from .video_datasource import VideoDatasource
 from .video_text_datasource import VideoTextDatasource