Skip to content

Commit

Permalink
【Hackathon + No.163】基于PaddleDetection PP-TinyPose,新增手势关键点检测模型 (#8066)
Browse files Browse the repository at this point in the history
* support COCO Whole Bady Hand

* update transforms

* disable `AugmentationbyInformantionDropping`

* fix infer bug

* fix getImgIds
  • Loading branch information
flytocc authored May 22, 2023
1 parent eeebef9 commit a694be1
Show file tree
Hide file tree
Showing 7 changed files with 720 additions and 6 deletions.
145 changes: 145 additions & 0 deletions configs/keypoint/tiny_pose/tinypose_256x256_hand.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/tinypose_256x256_hand/model_final
epoch: 210
num_joints: &num_joints 21
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOWholeBadyHandEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 256
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [64, 64]
flip_perm: &flip_perm []


#####model
architecture: TopDownHRNet

TopDownHRNet:
backbone: LiteHRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 40
loss: KeyPointMSELoss
use_dark: true

LiteHRNet:
network_type: wider_naive
freeze_at: -1
freeze_norm: false
return_idx: [0]

KeyPointMSELoss:
use_target_weight: true
loss_scale: 1.0


#####optimizer
LearningRate:
base_lr: 0.002
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 500

OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2


#####data
TrainDataset:
!KeypointTopDownCocoWholeBodyHandDataset
image_dir: train2017
anno_path: annotations/coco_wholebody_train_v1.0.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std

EvalDataset:
!KeypointTopDownCocoWholeBodyHandDataset
image_dir: val2017
anno_path: annotations/coco_wholebody_val_v1.0.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std

TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt

worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- TopDownRandomShiftBboxCenter:
shift_prob: 0.3
shift_factor: 0.16
- TopDownRandomFlip:
flip_prob: 0.5
flip_perm: *flip_perm
- TopDownGetRandomScaleRotation:
rot_prob: 0.6
rot_factor: 90
scale_factor: 0.3
# - AugmentationbyInformantionDropping:
# prob_cutout: 0.5
# offset_factor: 0.05
# num_patch: 1
# trainsize: *trainsize
- TopDownAffine:
trainsize: *trainsize
use_udp: true
- ToHeatmapsTopDown_DARK:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128
shuffle: true
drop_last: false

EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
use_udp: true
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128

TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false
6 changes: 4 additions & 2 deletions ppdet/data/source/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ def get_categories(metric_type, anno_file=None, arch=None):
elif metric_type.lower() == 'widerface':
return _widerface_category()

elif metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower(
) == 'keypointtopdownmpiieval':
elif metric_type.lower() in [
'keypointtopdowncocoeval', 'keypointtopdownmpiieval',
'keypointtopdowncocowholebadyhandeval'
]:
return (None, {'id': 'keypoint'})

elif metric_type.lower() == 'pose3deval':
Expand Down
116 changes: 116 additions & 0 deletions ppdet/data/source/keypoint_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,122 @@ def _load_coco_person_detection_results(self):
return kpt_db


@register
@serializable
class KeypointTopDownCocoWholeBodyHandDataset(KeypointTopDownBaseDataset):
"""CocoWholeBody dataset for top-down hand pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
COCO-WholeBody Hand keypoint indexes:
0: 'wrist',
1: 'thumb1',
2: 'thumb2',
3: 'thumb3',
4: 'thumb4',
5: 'forefinger1',
6: 'forefinger2',
7: 'forefinger3',
8: 'forefinger4',
9: 'middle_finger1',
10: 'middle_finger2',
11: 'middle_finger3',
12: 'middle_finger4',
13: 'ring_finger1',
14: 'ring_finger2',
15: 'ring_finger3',
16: 'ring_finger4',
17: 'pinky_finger1',
18: 'pinky_finger2',
19: 'pinky_finger3',
20: 'pinky_finger4'
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
anno_path (str): Relative path to the annotation file.
num_joints (int): Keypoint numbers
trainsize (list):[w, h] Image target size
transform (composed(operators)): A sequence of data transforms.
pixel_std (int): The pixel std of the scale
Default: 200.
"""

def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
trainsize,
transform=[],
pixel_std=200):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform)

self.trainsize = trainsize
self.pixel_std = pixel_std
self.dataset_name = 'coco_wholebady_hand'

def _box2cs(self, box):
x, y, w, h = box[:4]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]

if w > aspect_ratio * h:
h = w * 1.0 / aspect_ratio
elif w < aspect_ratio * h:
w = h * aspect_ratio
scale = np.array(
[w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25

return center, scale

def parse_dataset(self):
gt_db = []
num_joints = self.ann_info['num_joints']
coco = COCO(self.get_anno())
img_ids = list(coco.imgs.keys())
for img_id in img_ids:
im_ann = coco.loadImgs(img_id)[0]
image_file = os.path.join(self.img_prefix, im_ann['file_name'])
im_id = int(im_ann["id"])

ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
objs = coco.loadAnns(ann_ids)

for obj in objs:
for type in ['left', 'right']:
if (obj[f'{type}hand_valid'] and
max(obj[f'{type}hand_kpts']) > 0):

joints = np.zeros((num_joints, 3), dtype=np.float32)
joints_vis = np.zeros((num_joints, 3), dtype=np.float32)

keypoints = np.array(obj[f'{type}hand_kpts'])
keypoints = keypoints.reshape(-1, 3)
joints[:, :2] = keypoints[:, :2]
joints_vis[:, :2] = np.minimum(1, keypoints[:, 2:3])

center, scale = self._box2cs(obj[f'{type}hand_box'][:4])
gt_db.append({
'image_file': image_file,
'center': center,
'scale': scale,
'gt_joints': joints,
'joints_vis': joints_vis,
'im_id': im_id,
})

self.db = gt_db


@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
Expand Down
Loading

0 comments on commit a694be1

Please sign in to comment.