-
Notifications
You must be signed in to change notification settings - Fork 201
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add stgcn
- Loading branch information
Showing
22 changed files
with
1,676 additions
and
23 deletions.
There are no files selected for viewing
120 changes: 120 additions & 0 deletions
120
configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
_base_ = 'configs/base.py' | ||
|
||
CLASSES = [ | ||
'drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop', | ||
'pickup', 'throw', 'sitting down', 'standing up (from sitting position)', | ||
'clapping', 'reading', 'writing', 'tear up paper', 'wear jacket', | ||
'take off jacket', 'wear a shoe', 'take off a shoe', 'wear on glasses', | ||
'take off glasses', 'put on a hat/cap', 'take off a hat/cap', 'cheer up', | ||
'hand waving', 'kicking something', 'reach into pocket', | ||
'hopping (one foot jumping)', 'jump up', 'make a phone call/answer phone', | ||
'playing with phone/tablet', 'typing on a keyboard', | ||
'pointing to something with finger', 'taking a selfie', | ||
'check time (from watch)', 'rub two hands together', 'nod head/bow', | ||
'shake head', 'wipe face', 'salute', 'put the palms together', | ||
'cross hands in front (say stop)', 'sneeze/cough', 'staggering', 'falling', | ||
'touch head (headache)', 'touch chest (stomachache/heart pain)', | ||
'touch back (backache)', 'touch neck (neckache)', | ||
'nausea or vomiting condition', | ||
'use a fan (with hand or paper)/feeling warm', | ||
'punching/slapping other person', 'kicking other person', | ||
'pushing other person', 'pat on back of other person', | ||
'point finger at the other person', 'hugging other person', | ||
'giving something to other person', "touch other person's pocket", | ||
'handshaking', 'walking towards each other', | ||
'walking apart from each other' | ||
] | ||
|
||
model = dict( | ||
type='SkeletonGCN', | ||
backbone=dict( | ||
type='STGCN', | ||
in_channels=3, | ||
edge_importance_weighting=True, | ||
graph_cfg=dict(layout='coco', strategy='spatial')), | ||
cls_head=dict( | ||
type='STGCNHead', | ||
num_classes=60, | ||
in_channels=256, | ||
loss_cls=dict(type='CrossEntropyLoss')), | ||
train_cfg=None, | ||
test_cfg=None) | ||
|
||
dataset_type = 'VideoDataset' | ||
ann_file_train = 'data/posec3d/ntu60_xsub_train.pkl' | ||
ann_file_val = 'data/posec3d/ntu60_xsub_val.pkl' | ||
|
||
train_pipeline = [ | ||
dict(type='PaddingWithLoop', clip_len=300), | ||
dict(type='PoseDecode'), | ||
dict(type='FormatGCNInput', input_format='NCTVM'), | ||
dict(type='PoseNormalize'), | ||
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]), | ||
dict(type='VideoToTensor', keys=['keypoint']) | ||
] | ||
val_pipeline = [ | ||
dict(type='PaddingWithLoop', clip_len=300), | ||
dict(type='PoseDecode'), | ||
dict(type='FormatGCNInput', input_format='NCTVM'), | ||
dict(type='PoseNormalize'), | ||
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]), | ||
dict(type='VideoToTensor', keys=['keypoint']) | ||
] | ||
test_pipeline = [ | ||
dict(type='PaddingWithLoop', clip_len=300), | ||
dict(type='PoseDecode'), | ||
dict(type='FormatGCNInput', input_format='NCTVM'), | ||
dict(type='PoseNormalize'), | ||
dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]), | ||
dict(type='VideoToTensor', keys=['keypoint']) | ||
] | ||
data = dict( | ||
imgs_per_gpu=16, | ||
workers_per_gpu=2, | ||
train=dict( | ||
type=dataset_type, | ||
data_source=dict( | ||
type='PoseDataSourceForVideoRec', | ||
ann_file=ann_file_train, | ||
data_prefix='', | ||
), | ||
pipeline=train_pipeline), | ||
val=dict( | ||
type=dataset_type, | ||
imgs_per_gpu=1, | ||
data_source=dict( | ||
type='PoseDataSourceForVideoRec', | ||
ann_file=ann_file_val, | ||
data_prefix='', | ||
), | ||
pipeline=val_pipeline), | ||
test=dict( | ||
type=dataset_type, | ||
data_source=dict( | ||
type='PoseDataSourceForVideoRec', | ||
ann_file=ann_file_val, | ||
data_prefix='', | ||
), | ||
pipeline=test_pipeline)) | ||
|
||
# optimizer | ||
optimizer = dict( | ||
type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001, nesterov=True) | ||
optimizer_config = dict(grad_clip=None) | ||
# learning policy | ||
lr_config = dict(policy='step', step=[10, 50]) | ||
total_epochs = 80 | ||
|
||
# eval | ||
eval_config = dict(initial=False, interval=1, gpu_collect=True) | ||
eval_pipelines = [ | ||
dict( | ||
mode='test', | ||
data=data['val'], | ||
dist_eval=True, | ||
evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], | ||
) | ||
] | ||
|
||
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) | ||
checkpoint_config = dict(interval=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import argparse | ||
import os | ||
import os.path as osp | ||
import shutil | ||
|
||
import cv2 | ||
import mmcv | ||
import numpy as np | ||
import torch | ||
|
||
from easycv.file.utils import is_url_path | ||
from easycv.predictors.pose_predictor import PoseTopDownPredictor | ||
from easycv.predictors.video_classifier import STGCNPredictor | ||
|
||
try: | ||
import moviepy.editor as mpy | ||
except ImportError: | ||
raise ImportError('Please install moviepy to enable output file') | ||
|
||
FONTFACE = cv2.FONT_HERSHEY_DUPLEX | ||
FONTSCALE = 0.75 | ||
FONTCOLOR = (255, 255, 255) # BGR, white | ||
THICKNESS = 1 | ||
LINETYPE = 1 | ||
TMP_DIR = './tmp' | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Video classification demo based skeleton.') | ||
parser.add_argument( | ||
'--video', | ||
default= | ||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/demos/videos/ntu_sample.avi', | ||
help='video file/url') | ||
parser.add_argument( | ||
'--out_file', | ||
default=f'{TMP_DIR}/demo_show.mp4', | ||
help='output filename') | ||
parser.add_argument( | ||
'--config', | ||
default=( | ||
'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py' | ||
), | ||
help='skeleton model config file path') | ||
parser.add_argument( | ||
'--checkpoint', | ||
default= | ||
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/video/skeleton_based/stgcn/stgcn_80e_ntu60_xsub.pth' | ||
), | ||
help='skeleton model checkpoint file/url') | ||
parser.add_argument( | ||
'--det-config', | ||
default='configs/detection/yolox/yolox_s_8xb16_300e_coco.py', | ||
help='human detection config file path') | ||
parser.add_argument( | ||
'--det-checkpoint', | ||
default= | ||
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pt' | ||
), | ||
help='human detection checkpoint file/url') | ||
parser.add_argument( | ||
'--det-predictor-type', | ||
default='YoloXPredictor', | ||
help='detection predictor type') | ||
parser.add_argument( | ||
'--pose-config', | ||
default='configs/pose/hrnet_w48_coco_256x192_udp.py', | ||
help='human pose estimation config file path') | ||
parser.add_argument( | ||
'--pose-checkpoint', | ||
default= | ||
('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/pose/top_down_hrnet/pose_hrnet_epoch_210_export.pt' | ||
), | ||
help='human pose estimation checkpoint file/url') | ||
parser.add_argument( | ||
'--bbox-thr', | ||
type=float, | ||
default=0.5, | ||
help='the threshold of human detection score') | ||
parser.add_argument( | ||
'--device', type=str, default='cuda:0', help='CPU/CUDA device option') | ||
parser.add_argument( | ||
'--short-side', | ||
type=int, | ||
default=480, | ||
help='specify the short-side length of the image') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def frame_extraction(video_path, short_side): | ||
"""Extract frames given video_path. | ||
Args: | ||
video_path (str): The video_path. | ||
""" | ||
if is_url_path(video_path): | ||
from torch.hub import download_url_to_file | ||
cache_video_path = os.path.join(TMP_DIR, os.path.basename(video_path)) | ||
print( | ||
'Download video file from remote to local path "{cache_video_path}"...' | ||
) | ||
download_url_to_file(video_path, cache_video_path) | ||
video_path = cache_video_path | ||
|
||
# Load the video, extract frames into ./tmp/video_name | ||
target_dir = osp.join(TMP_DIR, osp.basename(osp.splitext(video_path)[0])) | ||
os.makedirs(target_dir, exist_ok=True) | ||
# Should be able to handle videos up to several hours | ||
frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg') | ||
vid = cv2.VideoCapture(video_path) | ||
frames = [] | ||
frame_paths = [] | ||
flag, frame = vid.read() | ||
cnt = 0 | ||
new_h, new_w = None, None | ||
while flag: | ||
if new_h is None: | ||
h, w, _ = frame.shape | ||
new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf)) | ||
frame = mmcv.imresize(frame, (new_w, new_h)) | ||
frames.append(frame) | ||
frame_path = frame_tmpl.format(cnt + 1) | ||
frame_paths.append(frame_path) | ||
|
||
cv2.imwrite(frame_path, frame) | ||
cnt += 1 | ||
flag, frame = vid.read() | ||
|
||
return frame_paths, frames | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
if not osp.exists(TMP_DIR): | ||
os.makedirs(TMP_DIR) | ||
|
||
frame_paths, original_frames = frame_extraction(args.video, | ||
args.short_side) | ||
num_frame = len(frame_paths) | ||
h, w, _ = original_frames[0].shape | ||
|
||
# Get Human detection results | ||
pose_predictor = PoseTopDownPredictor( | ||
model_path=args.pose_checkpoint, | ||
config_file=args.pose_config, | ||
detection_predictor_config=dict( | ||
type=args.det_predictor_type, | ||
model_path=args.det_checkpoint, | ||
config_file=args.det_config, | ||
), | ||
bbox_thr=args.bbox_thr, | ||
cat_id=0, # person category id | ||
) | ||
|
||
video_cls_predictor = STGCNPredictor( | ||
model_path=args.checkpoint, | ||
config_file=args.config, | ||
ori_image_size=(w, h), | ||
label_map=None) | ||
|
||
pose_results = pose_predictor(original_frames) | ||
|
||
torch.cuda.empty_cache() | ||
|
||
fake_anno = dict( | ||
frame_dir='', | ||
label=-1, | ||
img_shape=(h, w), | ||
original_shape=(h, w), | ||
start_index=0, | ||
modality='Pose', | ||
total_frames=num_frame) | ||
num_person = max([len(x) for x in pose_results]) | ||
|
||
num_keypoint = 17 | ||
keypoints = np.zeros((num_person, num_frame, num_keypoint, 2), | ||
dtype=np.float16) | ||
keypoints_score = np.zeros((num_person, num_frame, num_keypoint), | ||
dtype=np.float16) | ||
for i, poses in enumerate(pose_results): | ||
if len(poses) < 1: | ||
continue | ||
_keypoint = poses['keypoints'] # shape = (num_person, num_keypoint, 3) | ||
for j, pose in enumerate(_keypoint): | ||
keypoints[j, i] = pose[:, :2] | ||
keypoints_score[j, i] = pose[:, 2] | ||
|
||
fake_anno['keypoint'] = keypoints | ||
fake_anno['keypoint_score'] = keypoints_score | ||
|
||
results = video_cls_predictor([fake_anno]) | ||
|
||
action_label = results[0]['class_name'][0] | ||
print(f'action label: {action_label}') | ||
|
||
vis_frames = [ | ||
pose_predictor.show_result(original_frames[i], pose_results[i]) | ||
if len(pose_results[i]) > 0 else original_frames[i] | ||
for i in range(num_frame) | ||
] | ||
for frame in vis_frames: | ||
cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE, | ||
FONTCOLOR, THICKNESS, LINETYPE) | ||
|
||
vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24) | ||
vid.write_videofile(args.out_file, remove_temp=True) | ||
print(f'Write video to {args.out_file} successfully!') | ||
|
||
tmp_frame_dir = osp.dirname(frame_paths[0]) | ||
shutil.rmtree(tmp_frame_dir) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
from .pose_datasource import PoseDataSourceForVideoRec | ||
from .video_datasource import VideoDatasource | ||
from .video_text_datasource import VideoTextDatasource |
Oops, something went wrong.