diff --git a/configs/_base_/datasets/ubody3d.py b/configs/_base_/datasets/ubody3d.py
new file mode 100644
index 0000000000..9242559ea1
--- /dev/null
+++ b/configs/_base_/datasets/ubody3d.py
@@ -0,0 +1,958 @@
+dataset_info = dict(
+ dataset_name='ubody3d',
+ paper_info=dict(
+ author='Jing Lin, Ailing Zeng, Haoqian Wang, Lei Zhang, Yu Li',
+ title='One-Stage 3D Whole-Body Mesh Recovery with Component Aware'
+ 'Transformer',
+ container='IEEE Computer Society Conference on Computer Vision and '
+ 'Pattern Recognition (CVPR)',
+ year='2023',
+ homepage='https://github.com/IDEA-Research/OSX',
+ ),
+ keypoint_info={
+ 0:
+ dict(name='Pelvis', id=0, color=[0, 255, 0], type='', swap=''),
+ 1:
+ dict(
+ name='L_Hip', id=1, color=[0, 255, 0], type='lower', swap='R_Hip'),
+ 2:
+ dict(
+ name='R_Hip', id=2, color=[0, 255, 0], type='lower', swap='L_Hip'),
+ 3:
+ dict(
+ name='L_Knee',
+ id=3,
+ color=[0, 255, 0],
+ type='lower',
+ swap='R_Knee'),
+ 4:
+ dict(
+ name='R_Knee',
+ id=4,
+ color=[0, 255, 0],
+ type='lower',
+ swap='L_Knee'),
+ 5:
+ dict(
+ name='L_Ankle',
+ id=5,
+ color=[0, 255, 0],
+ type='lower',
+ swap='R_Ankle'),
+ 6:
+ dict(
+ name='R_Ankle',
+ id=6,
+ color=[0, 255, 0],
+ type='lower',
+ swap='L_Ankle'),
+ 7:
+ dict(name='Neck', id=7, color=[0, 255, 0], type='upper', swap=''),
+ 8:
+ dict(
+ name='L_Shoulder',
+ id=8,
+ color=[0, 255, 0],
+ type='upper',
+ swap='R_Shoulder'),
+ 9:
+ dict(
+ name='R_Shoulder',
+ id=9,
+ color=[0, 255, 0],
+ type='upper',
+ swap='L_Shoulder'),
+ 10:
+ dict(
+ name='L_Elbow',
+ id=10,
+ color=[0, 255, 0],
+ type='upper',
+ swap='R_Elbow'),
+ 11:
+ dict(
+ name='R_Elbow',
+ id=11,
+ color=[0, 255, 0],
+ type='upper',
+ swap='L_Elbow'),
+ 12:
+ dict(
+ name='L_Wrist',
+ id=12,
+ color=[0, 255, 0],
+ type='upper',
+ swap='R_Wrist'),
+ 13:
+ dict(
+ name='R_Wrist',
+ id=13,
+ color=[0, 255, 0],
+ type='upper',
+ swap='L_Wrist'),
+ 14:
+ dict(
+ name='L_Big_toe',
+ id=14,
+ color=[0, 255, 0],
+ type='lower',
+ swap='R_Big_toe'),
+ 15:
+ dict(
+ name='L_Small_toe',
+ id=15,
+ color=[0, 255, 0],
+ type='lower',
+ swap='R_Small_toe'),
+ 16:
+ dict(
+ name='L_Heel',
+ id=16,
+ color=[0, 255, 0],
+ type='lower',
+ swap='R_Heel'),
+ 17:
+ dict(
+ name='R_Big_toe',
+ id=17,
+ color=[0, 255, 0],
+ type='lower',
+ swap='L_Big_toe'),
+ 18:
+ dict(
+ name='R_Small_toe',
+ id=18,
+ color=[0, 255, 0],
+ type='lower',
+ swap='L_Small_toe'),
+ 19:
+ dict(
+ name='R_Heel',
+ id=19,
+ color=[0, 255, 0],
+ type='lower',
+ swap='L_Heel'),
+ 20:
+ dict(
+ name='L_Ear', id=20, color=[0, 255, 0], type='upper',
+ swap='R_Ear'),
+ 21:
+ dict(
+ name='R_Ear', id=21, color=[0, 255, 0], type='upper',
+ swap='L_Ear'),
+ 22:
+ dict(name='L_Eye', id=22, color=[0, 255, 0], type='', swap='R_Eye'),
+ 23:
+ dict(name='R_Eye', id=23, color=[0, 255, 0], type='', swap='L_Eye'),
+ 24:
+ dict(name='Nose', id=24, color=[0, 255, 0], type='upper', swap=''),
+ 25:
+ dict(
+ name='L_Thumb_1',
+ id=25,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Thumb_1'),
+ 26:
+ dict(
+ name='L_Thumb_2',
+ id=26,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Thumb_2'),
+ 27:
+ dict(
+ name='L_Thumb_3',
+ id=27,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Thumb_3'),
+ 28:
+ dict(
+ name='L_Thumb_4',
+ id=28,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Thumb_4'),
+ 29:
+ dict(
+ name='L_Index_1',
+ id=29,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Index_1'),
+ 30:
+ dict(
+ name='L_Index_2',
+ id=30,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Index_2'),
+ 31:
+ dict(
+ name='L_Index_3',
+ id=31,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Index_3'),
+ 32:
+ dict(
+ name='L_Index_4',
+ id=32,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Index_4'),
+ 33:
+ dict(
+ name='L_Middle_1',
+ id=33,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Middle_1'),
+ 34:
+ dict(
+ name='L_Middle_2',
+ id=34,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Middle_2'),
+ 35:
+ dict(
+ name='L_Middle_3',
+ id=35,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Middle_3'),
+ 36:
+ dict(
+ name='L_Middle_4',
+ id=36,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Middle_4'),
+ 37:
+ dict(
+ name='L_Ring_1',
+ id=37,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Ring_1'),
+ 38:
+ dict(
+ name='L_Ring_2',
+ id=38,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Ring_2'),
+ 39:
+ dict(
+ name='L_Ring_3',
+ id=39,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Ring_3'),
+ 40:
+ dict(
+ name='L_Ring_4',
+ id=40,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Ring_4'),
+ 41:
+ dict(
+ name='L_Pinky_1',
+ id=41,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Pinky_1'),
+ 42:
+ dict(
+ name='L_Pinky_2',
+ id=42,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Pinky_2'),
+ 43:
+ dict(
+ name='L_Pinky_3',
+ id=43,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Pinky_3'),
+ 44:
+ dict(
+ name='L_Pinky_4',
+ id=44,
+ color=[255, 128, 0],
+ type='',
+ swap='R_Pinky_4'),
+ 45:
+ dict(
+ name='R_Thumb_1',
+ id=45,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Thumb_1'),
+ 46:
+ dict(
+ name='R_Thumb_2',
+ id=46,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Thumb_2'),
+ 47:
+ dict(
+ name='R_Thumb_3',
+ id=47,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Thumb_3'),
+ 48:
+ dict(
+ name='R_Thumb_4',
+ id=48,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Thumb_4'),
+ 49:
+ dict(
+ name='R_Index_1',
+ id=49,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Index_1'),
+ 50:
+ dict(
+ name='R_Index_2',
+ id=50,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Index_2'),
+ 51:
+ dict(
+ name='R_Index_3',
+ id=51,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Index_3'),
+ 52:
+ dict(
+ name='R_Index_4',
+ id=52,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Index_4'),
+ 53:
+ dict(
+ name='R_Middle_1',
+ id=53,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Middle_1'),
+ 54:
+ dict(
+ name='R_Middle_2',
+ id=54,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Middle_2'),
+ 55:
+ dict(
+ name='R_Middle_3',
+ id=55,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Middle_3'),
+ 56:
+ dict(
+ name='R_Middle_4',
+ id=56,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Middle_4'),
+ 57:
+ dict(
+ name='R_Ring_1',
+ id=57,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Ring_1'),
+ 58:
+ dict(
+ name='R_Ring_2',
+ id=58,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Ring_2'),
+ 59:
+ dict(
+ name='R_Ring_3',
+ id=59,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Ring_3'),
+ 60:
+ dict(
+ name='R_Ring_4',
+ id=60,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Ring_4'),
+ 61:
+ dict(
+ name='R_Pinky_1',
+ id=61,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Pinky_1'),
+ 62:
+ dict(
+ name='R_Pinky_2',
+ id=62,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Pinky_2'),
+ 63:
+ dict(
+ name='R_Pinky_3',
+ id=63,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Pinky_3'),
+ 64:
+ dict(
+ name='R_Pinky_4',
+ id=64,
+ color=[255, 128, 0],
+ type='',
+ swap='L_Pinky_4'),
+ 65:
+ dict(name='Face_1', id=65, color=[255, 255, 255], type='', swap=''),
+ 66:
+ dict(name='Face_2', id=66, color=[255, 255, 255], type='', swap=''),
+ 67:
+ dict(
+ name='Face_3',
+ id=67,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_4'),
+ 68:
+ dict(
+ name='Face_4',
+ id=68,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_3'),
+ 69:
+ dict(
+ name='Face_5',
+ id=69,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_14'),
+ 70:
+ dict(
+ name='Face_6',
+ id=70,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_13'),
+ 71:
+ dict(
+ name='Face_7',
+ id=71,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_12'),
+ 72:
+ dict(
+ name='Face_8',
+ id=72,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_11'),
+ 73:
+ dict(
+ name='Face_9',
+ id=73,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_10'),
+ 74:
+ dict(
+ name='Face_10',
+ id=74,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_9'),
+ 75:
+ dict(
+ name='Face_11',
+ id=75,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_8'),
+ 76:
+ dict(
+ name='Face_12',
+ id=76,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_7'),
+ 77:
+ dict(
+ name='Face_13',
+ id=77,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_6'),
+ 78:
+ dict(
+ name='Face_14',
+ id=78,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_5'),
+ 79:
+ dict(name='Face_15', id=79, color=[255, 255, 255], type='', swap=''),
+ 80:
+ dict(name='Face_16', id=80, color=[255, 255, 255], type='', swap=''),
+ 81:
+ dict(name='Face_17', id=81, color=[255, 255, 255], type='', swap=''),
+ 82:
+ dict(name='Face_18', id=82, color=[255, 255, 255], type='', swap=''),
+ 83:
+ dict(
+ name='Face_19',
+ id=83,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_23'),
+ 84:
+ dict(
+ name='Face_20',
+ id=84,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_22'),
+ 85:
+ dict(name='Face_21', id=85, color=[255, 255, 255], type='', swap=''),
+ 86:
+ dict(
+ name='Face_22',
+ id=86,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_20'),
+ 87:
+ dict(
+ name='Face_23',
+ id=87,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_19'),
+ 88:
+ dict(
+ name='Face_24',
+ id=88,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_33'),
+ 89:
+ dict(
+ name='Face_25',
+ id=89,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_32'),
+ 90:
+ dict(
+ name='Face_26',
+ id=90,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_31'),
+ 91:
+ dict(
+ name='Face_27',
+ id=91,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_30'),
+ 92:
+ dict(
+ name='Face_28',
+ id=92,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_35'),
+ 93:
+ dict(
+ name='Face_29',
+ id=93,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_34'),
+ 94:
+ dict(
+ name='Face_30',
+ id=94,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_27'),
+ 95:
+ dict(
+ name='Face_31',
+ id=95,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_26'),
+ 96:
+ dict(
+ name='Face_32',
+ id=96,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_25'),
+ 97:
+ dict(
+ name='Face_33',
+ id=97,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_24'),
+ 98:
+ dict(
+ name='Face_34',
+ id=98,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_29'),
+ 99:
+ dict(
+ name='Face_35',
+ id=99,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_28'),
+ 100:
+ dict(
+ name='Face_36',
+ id=100,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_42'),
+ 101:
+ dict(
+ name='Face_37',
+ id=101,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_41'),
+ 102:
+ dict(
+ name='Face_38',
+ id=102,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_40'),
+ 103:
+ dict(name='Face_39', id=103, color=[255, 255, 255], type='', swap=''),
+ 104:
+ dict(
+ name='Face_40',
+ id=104,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_38'),
+ 105:
+ dict(
+ name='Face_41',
+ id=105,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_37'),
+ 106:
+ dict(
+ name='Face_42',
+ id=106,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_36'),
+ 107:
+ dict(
+ name='Face_43',
+ id=107,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_47'),
+ 108:
+ dict(
+ name='Face_44',
+ id=108,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_46'),
+ 109:
+ dict(name='Face_45', id=109, color=[255, 255, 255], type='', swap=''),
+ 110:
+ dict(
+ name='Face_46',
+ id=110,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_44'),
+ 111:
+ dict(
+ name='Face_47',
+ id=111,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_43'),
+ 112:
+ dict(
+ name='Face_48',
+ id=112,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_52'),
+ 113:
+ dict(
+ name='Face_49',
+ id=113,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_51'),
+ 114:
+ dict(name='Face_50', id=114, color=[255, 255, 255], type='', swap=''),
+ 115:
+ dict(
+ name='Face_51',
+ id=115,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_49'),
+ 116:
+ dict(
+ name='Face_52',
+ id=116,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_48'),
+ 117:
+ dict(
+ name='Face_53',
+ id=117,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_55'),
+ 118:
+ dict(name='Face_54', id=118, color=[255, 255, 255], type='', swap=''),
+ 119:
+ dict(
+ name='Face_55',
+ id=119,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_53'),
+ 120:
+ dict(
+ name='Face_56',
+ id=120,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_72'),
+ 121:
+ dict(
+ name='Face_57',
+ id=121,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_71'),
+ 122:
+ dict(
+ name='Face_58',
+ id=122,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_70'),
+ 123:
+ dict(
+ name='Face_59',
+ id=123,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_69'),
+ 124:
+ dict(
+ name='Face_60',
+ id=124,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_68'),
+ 125:
+ dict(
+ name='Face_61',
+ id=125,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_67'),
+ 126:
+ dict(
+ name='Face_62',
+ id=126,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_66'),
+ 127:
+ dict(
+ name='Face_63',
+ id=127,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_65'),
+ 128:
+ dict(name='Face_64', id=128, color=[255, 255, 255], type='', swap=''),
+ 129:
+ dict(
+ name='Face_65',
+ id=129,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_63'),
+ 130:
+ dict(
+ name='Face_66',
+ id=130,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_62'),
+ 131:
+ dict(
+ name='Face_67',
+ id=131,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_61'),
+ 132:
+ dict(
+ name='Face_68',
+ id=132,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_60'),
+ 133:
+ dict(
+ name='Face_69',
+ id=133,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_59'),
+ 134:
+ dict(
+ name='Face_70',
+ id=134,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_58'),
+ 135:
+ dict(
+ name='Face_71',
+ id=135,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_57'),
+ 136:
+ dict(
+ name='Face_72',
+ id=136,
+ color=[255, 255, 255],
+ type='',
+ swap='Face_56'),
+ },
+ skeleton_info={
+ 0: dict(link=('L_Ankle', 'L_Knee'), id=0, color=[0, 255, 0]),
+ 1: dict(link=('L_Knee', 'L_Hip'), id=1, color=[0, 255, 0]),
+ 2: dict(link=('R_Ankle', 'R_Knee'), id=2, color=[0, 255, 0]),
+ 3: dict(link=('R_Knee', 'R_Hip'), id=3, color=[0, 255, 0]),
+ 4: dict(link=('L_Hip', 'R_Hip'), id=4, color=[0, 255, 0]),
+ 5: dict(link=('L_Shoulder', 'L_Hip'), id=5, color=[0, 255, 0]),
+ 6: dict(link=('R_Shoulder', 'R_Hip'), id=6, color=[0, 255, 0]),
+ 7: dict(link=('L_Shoulder', 'R_Shoulder'), id=7, color=[0, 255, 0]),
+ 8: dict(link=('L_Shoulder', 'L_Elbow'), id=8, color=[0, 255, 0]),
+ 9: dict(link=('R_Shoulder', 'R_Elbow'), id=9, color=[0, 255, 0]),
+ 10: dict(link=('L_Elbow', 'L_Wrist'), id=10, color=[0, 255, 0]),
+ 11: dict(link=('R_Elbow', 'R_Wrist'), id=11, color=[255, 128, 0]),
+ 12: dict(link=('L_Eye', 'R_Eye'), id=12, color=[255, 128, 0]),
+ 13: dict(link=('Nose', 'L_Eye'), id=13, color=[255, 128, 0]),
+ 14: dict(link=('Nose', 'R_Eye'), id=14, color=[255, 128, 0]),
+ 15: dict(link=('L_Eye', 'L_Ear'), id=15, color=[255, 128, 0]),
+ 16: dict(link=('R_Eye', 'R_Ear'), id=16, color=[255, 128, 0]),
+ 17: dict(link=('L_Ear', 'L_Shoulder'), id=17, color=[255, 128, 0]),
+ 18: dict(link=('R_Ear', 'R_Shoulder'), id=18, color=[255, 128, 0]),
+ 19: dict(link=('L_Ankle', 'L_Big_toe'), id=19, color=[255, 128, 0]),
+ 20: dict(link=('L_Ankle', 'L_Small_toe'), id=20, color=[255, 128, 0]),
+ 21: dict(link=('L_Ankle', 'L_Heel'), id=21, color=[255, 128, 0]),
+ 22: dict(link=('R_Ankle', 'R_Big_toe'), id=22, color=[255, 128, 0]),
+ 23: dict(link=('R_Ankle', 'R_Small_toe'), id=23, color=[255, 128, 0]),
+ 24: dict(link=('R_Ankle', 'R_Heel'), id=24, color=[255, 128, 0]),
+ 25: dict(link=('L_Wrist', 'L_Thumb_1'), id=25, color=[255, 128, 0]),
+ 26: dict(link=('L_Thumb_1', 'L_Thumb_2'), id=26, color=[255, 128, 0]),
+ 27: dict(link=('L_Thumb_2', 'L_Thumb_3'), id=27, color=[255, 128, 0]),
+ 28: dict(link=('L_Thumb_3', 'L_Thumb_4'), id=28, color=[255, 128, 0]),
+ 29: dict(link=('L_Wrist', 'L_Index_1'), id=29, color=[255, 128, 0]),
+ 30: dict(link=('L_Index_1', 'L_Index_2'), id=30, color=[255, 128, 0]),
+ 31:
+ dict(link=('L_Index_2', 'L_Index_3'), id=31, color=[255, 255, 255]),
+ 32:
+ dict(link=('L_Index_3', 'L_Index_4'), id=32, color=[255, 255, 255]),
+ 33: dict(link=('L_Wrist', 'L_Middle_1'), id=33, color=[255, 255, 255]),
+ 34:
+ dict(link=('L_Middle_1', 'L_Middle_2'), id=34, color=[255, 255, 255]),
+ 35:
+ dict(link=('L_Middle_2', 'L_Middle_3'), id=35, color=[255, 255, 255]),
+ 36:
+ dict(link=('L_Middle_3', 'L_Middle_4'), id=36, color=[255, 255, 255]),
+ 37: dict(link=('L_Wrist', 'L_Ring_1'), id=37, color=[255, 255, 255]),
+ 38: dict(link=('L_Ring_1', 'L_Ring_2'), id=38, color=[255, 255, 255]),
+ 39: dict(link=('L_Ring_2', 'L_Ring_3'), id=39, color=[255, 255, 255]),
+ 40: dict(link=('L_Ring_3', 'L_Ring_4'), id=40, color=[255, 255, 255]),
+ 41: dict(link=('L_Wrist', 'L_Pinky_1'), id=41, color=[255, 255, 255]),
+ 42:
+ dict(link=('L_Pinky_1', 'L_Pinky_2'), id=42, color=[255, 255, 255]),
+ 43:
+ dict(link=('L_Pinky_2', 'L_Pinky_3'), id=43, color=[255, 255, 255]),
+ 44:
+ dict(link=('L_Pinky_3', 'L_Pinky_4'), id=44, color=[255, 255, 255]),
+ 45: dict(link=('R_Wrist', 'R_Thumb_1'), id=45, color=[255, 255, 255]),
+ 46:
+ dict(link=('R_Thumb_1', 'R_Thumb_2'), id=46, color=[255, 255, 255]),
+ 47:
+ dict(link=('R_Thumb_2', 'R_Thumb_3'), id=47, color=[255, 255, 255]),
+ 48:
+ dict(link=('R_Thumb_3', 'R_Thumb_4'), id=48, color=[255, 255, 255]),
+ 49: dict(link=('R_Wrist', 'R_Index_1'), id=49, color=[255, 255, 255]),
+ 50:
+ dict(link=('R_Index_1', 'R_Index_2'), id=50, color=[255, 255, 255]),
+ 51:
+ dict(link=('R_Index_2', 'R_Index_3'), id=51, color=[255, 255, 255]),
+ 52:
+ dict(link=('R_Index_3', 'R_Index_4'), id=52, color=[255, 255, 255]),
+ 53: dict(link=('R_Wrist', 'R_Middle_1'), id=53, color=[255, 255, 255]),
+ 54:
+ dict(link=('R_Middle_1', 'R_Middle_2'), id=54, color=[255, 255, 255]),
+ 55:
+ dict(link=('R_Middle_2', 'R_Middle_3'), id=55, color=[255, 255, 255]),
+ 56:
+ dict(link=('R_Middle_3', 'R_Middle_4'), id=56, color=[255, 255, 255]),
+ 57: dict(link=('R_Wrist', 'R_Pinky_1'), id=57, color=[255, 255, 255]),
+ 58:
+ dict(link=('R_Pinky_1', 'R_Pinky_2'), id=58, color=[255, 255, 255]),
+ 59:
+ dict(link=('R_Pinky_2', 'R_Pinky_3'), id=59, color=[255, 255, 255]),
+ 60:
+ dict(link=('R_Pinky_3', 'R_Pinky_4'), id=60, color=[255, 255, 255]),
+ },
+ joint_weights=[1.] * 137,
+ sigmas=[])
diff --git a/docs/en/dataset_zoo/3d_body_keypoint.md b/docs/en/dataset_zoo/3d_body_keypoint.md
index 82e21010fc..3a35e2443b 100644
--- a/docs/en/dataset_zoo/3d_body_keypoint.md
+++ b/docs/en/dataset_zoo/3d_body_keypoint.md
@@ -8,6 +8,7 @@ MMPose supported datasets:
- [Human3.6M](#human36m) \[ [Homepage](http://vision.imar.ro/human3.6m/description.php) \]
- [CMU Panoptic](#cmu-panoptic) \[ [Homepage](http://domedb.perception.cs.cmu.edu/) \]
- [Campus/Shelf](#campus-and-shelf) \[ [Homepage](http://campar.in.tum.de/Chair/MultiHumanPose) \]
+- [UBody](#ubody3d) \[ [Homepage](https://osx-ubody.github.io/) \]
## Human3.6M
@@ -197,3 +198,100 @@ mmpose
| ├── pred_shelf_maskrcnn_hrnet_coco.pkl
| ├── actorsGT.mat
```
+
+## UBody3d
+
+
+UBody (CVPR'2023)
+
+```bibtex
+@inproceedings{lin2023one,
+ title={One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer},
+ author={Lin, Jing and Zeng, Ailing and Wang, Haoqian and Zhang, Lei and Li, Yu},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ year={2023},
+}
+```
+
+
+
+
+
+
+
+For [UBody](https://github.com/IDEA-Research/OSX) dataset, videos and annotations can be downloaded from the [OSX homepage](https://github.com/IDEA-Research/OSX).
+
+Download and extract them under $MMPOSE/data, and make them look like this:
+
+```text
+mmpose
+├── mmpose
+├── docs
+├── tests
+├── tools
+├── configs
+`── data
+ │── UBody
+ ├── annotations
+ │ ├── ConductMusic
+ │ ├── Entertainment
+ │ ├── Fitness
+ │ ├── Interview
+ │ ├── LiveVlog
+ │ ├── Magic_show
+ │ ├── Movie
+ │ ├── Olympic
+ │ ├── Online_class
+ │ ├── SignLanguage
+ │ ├── Singing
+ │ ├── Speech
+ │ ├── TVShow
+ │ ├── TalkShow
+ │ └── VideoConference
+ ├── splits
+ │ ├── inter_scene_test_list.npy
+ │ └── intra_scene_test_list.npy
+ ├── videos
+ │ ├── ConductMusic
+ │ ├── Entertainment
+ │ ├── Fitness
+ │ ├── Interview
+ │ ├── LiveVlog
+ │ ├── Magic_show
+ │ ├── Movie
+ │ ├── Olympic
+ │ ├── Online_class
+ │ ├── SignLanguage
+ │ ├── Singing
+ │ ├── Speech
+ │ ├── TVShow
+ │ ├── TalkShow
+ │ └── VideoConference
+```
+
+Convert videos to images then split them into train/val set:
+
+```shell
+python tools/dataset_converters/ubody_kpts_to_coco.py
+```
+
+Before generating 3D keypoints, you need to install SMPLX tools and download human models, please refer to [Github](https://github.com/vchoutas/smplx#installation) and [SMPLX](https://smpl-x.is.tue.mpg.de/download.php).
+
+```shell
+pip install smplx
+```
+
+The directory tree of human models should be like this:
+
+```text
+human_model_path
+|── smplx
+ ├── SMPLX_NEUTRAL.npz
+ ├── SMPLX_NEUTRAL.pkl
+```
+
+After the above preparations are finished, execute the following script:
+
+```shell
+python tools/dataset_converters/ubody_smplx_to_coco.py --data-root {$MMPOSE/data/UBody} --human-model-path {$MMPOSE/data/human_model_path/}
+```
diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py
index 290edafed0..b06d934ac5 100644
--- a/mmpose/datasets/datasets/base/base_mocap_dataset.py
+++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py
@@ -96,8 +96,7 @@ def __init__(self,
assert exists(_ann_file), (
f'Annotation file `{_ann_file}` does not exist.')
- with get_local_path(_ann_file) as local_path:
- self.ann_data = np.load(local_path)
+ self._load_ann_file(_ann_file)
self.camera_param_file = camera_param_file
if self.camera_param_file:
@@ -137,6 +136,19 @@ def __init__(self,
lazy_init=lazy_init,
max_refetch=max_refetch)
+ def _load_ann_file(self, ann_file: str) -> dict:
+ """Load annotation file to get image information.
+
+ Args:
+ ann_file (str): Annotation file path.
+
+        Note:
+            The loaded data is stored in ``self.ann_data``; returns ``None``.
+ """
+
+ with get_local_path(ann_file) as local_path:
+ self.ann_data = np.load(local_path)
+
@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
diff --git a/mmpose/datasets/datasets/body3d/__init__.py b/mmpose/datasets/datasets/body3d/__init__.py
index d5afeca578..2b52caeadd 100644
--- a/mmpose/datasets/datasets/body3d/__init__.py
+++ b/mmpose/datasets/datasets/body3d/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .h36m_dataset import Human36mDataset
+from .ubody3d_dataset import UBody3dDataset
-__all__ = ['Human36mDataset']
+__all__ = ['Human36mDataset', 'UBody3dDataset']
diff --git a/mmpose/datasets/datasets/body3d/ubody3d_dataset.py b/mmpose/datasets/datasets/body3d/ubody3d_dataset.py
new file mode 100644
index 0000000000..85b8d893e7
--- /dev/null
+++ b/mmpose/datasets/datasets/body3d/ubody3d_dataset.py
@@ -0,0 +1,247 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import defaultdict
+from typing import List, Tuple
+
+import numpy as np
+from mmengine.fileio import get_local_path
+from xtcocotools.coco import COCO
+
+from mmpose.datasets.datasets import BaseMocapDataset
+from mmpose.registry import DATASETS
+
+
+@DATASETS.register_module()
+class UBody3dDataset(BaseMocapDataset):
+ """Ubody3d dataset for 3D human pose estimation.
+
+ "One-Stage 3D Whole-Body Mesh Recovery with Component Aware Transformer",
+ CVPR'2023. More details can be found in the `paper
+    <https://arxiv.org/abs/2303.16160>`__ .
+
+ Ubody3D keypoints::
+
+ 0-24: 25 body keypoints,
+ 25-64: 40 hand keypoints,
+ 65-136: 72 face keypoints,
+
+ In total, we have 137 keypoints for wholebody 3D pose estimation.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ seq_len (int): Number of frames in a sequence. Default: 1.
+ multiple_target (int): If larger than 0, merge every
+ ``multiple_target`` sequence together. Default: 0.
+ causal (bool): If set to ``True``, the rightmost input frame will be
+ the target frame. Otherwise, the middle input frame will be the
+ target frame. Default: ``True``.
+ subset_frac (float): The fraction to reduce dataset size. If set to 1,
+ the dataset size is not reduced. Default: 1.
+ camera_param_file (str): Cameras' parameters file. Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+            contains all instances in an image. Default: ``'topdown'``
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data.
+ Default: ``dict(img='')``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, which is not necessary to
+ load annotation file. ``Basedataset`` can skip load annotations to
+            save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
+ None img. The maximum extra number of cycles to get a valid
+ image. Default: 1000.
+ """
+
+ def __init__(self,
+ multiple_target: int = 0,
+ multiple_target_step: int = 0,
+ seq_step: int = 1,
+ pad_video_seq: bool = False,
+ **kwargs):
+ self.seq_step = seq_step
+ self.pad_video_seq = pad_video_seq
+
+ if multiple_target > 0 and multiple_target_step == 0:
+ multiple_target_step = multiple_target
+ self.multiple_target_step = multiple_target_step
+
+ super().__init__(multiple_target=multiple_target, **kwargs)
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/ubody3d.py')
+
+ def _load_ann_file(self, ann_file: str) -> dict:
+ """Load annotation file."""
+ with get_local_path(ann_file) as local_path:
+ self.ann_data = COCO(local_path)
+
+ def get_sequence_indices(self) -> List[List[int]]:
+ video_frames = defaultdict(list)
+ img_ids = self.ann_data.getImgIds()
+ for img_id in img_ids:
+ img_info = self.ann_data.loadImgs(img_id)[0]
+ subj, _, _ = self._parse_image_name(img_info['file_name'])
+ video_frames[subj].append(img_id)
+
+ sequence_indices = []
+ _len = (self.seq_len - 1) * self.seq_step + 1
+ _step = self.seq_step
+
+ if self.multiple_target:
+ for _, _img_ids in sorted(video_frames.items()):
+ n_frame = len(_img_ids)
+ _ann_ids = self.ann_data.getAnnIds(imgIds=_img_ids)
+ seqs_from_video = [
+ _ann_ids[i:(i + self.multiple_target):_step]
+ for i in range(0, n_frame, self.multiple_target_step)
+ ][:(n_frame + self.multiple_target_step -
+ self.multiple_target) // self.multiple_target_step]
+ sequence_indices.extend(seqs_from_video)
+ else:
+ for _, _img_ids in sorted(video_frames.items()):
+ n_frame = len(_img_ids)
+ _ann_ids = self.ann_data.getAnnIds(imgIds=_img_ids)
+ if self.pad_video_seq:
+ # Pad the sequence so that every frame in the sequence will
+ # be predicted.
+ if self.causal:
+ frames_left = self.seq_len - 1
+ frames_right = 0
+ else:
+ frames_left = (self.seq_len - 1) // 2
+ frames_right = frames_left
+ for i in range(n_frame):
+ pad_left = max(0, frames_left - i // _step)
+ pad_right = max(
+ 0, frames_right - (n_frame - 1 - i) // _step)
+ start = max(i % _step, i - frames_left * _step)
+ end = min(n_frame - (n_frame - 1 - i) % _step,
+ i + frames_right * _step + 1)
+ sequence_indices.append([_ann_ids[0]] * pad_left +
+ _ann_ids[start:end:_step] +
+ [_ann_ids[-1]] * pad_right)
+ else:
+ seqs_from_video = [
+ _ann_ids[i:(i + _len):_step]
+ for i in range(0, n_frame - _len + 1, _step)
+ ]
+ sequence_indices.extend(seqs_from_video)
+
+ # reduce dataset size if needed
+ subset_size = int(len(sequence_indices) * self.subset_frac)
+ start = np.random.randint(0, len(sequence_indices) - subset_size + 1)
+ end = start + subset_size
+
+ sequence_indices = sequence_indices[start:end]
+
+ return sequence_indices
+
+ def _parse_image_name(self, image_path: str) -> Tuple[str, int]:
+ """Parse image name to get video name and frame index.
+
+ Args:
+            image_path (str): Image path of the frame.
+
+ Returns:
+            tuple[str, str, str]: Video (trim) name, frame id and file suffix.
+ """
+ trim, file_name = image_path.split('/')[-2:]
+ frame_id, suffix = file_name.split('.')
+ return trim, frame_id, suffix
+
+ def _load_annotations(self):
+ """Load data from annotations in COCO format."""
+ num_keypoints = self.metainfo['num_keypoints']
+ self._metainfo['CLASSES'] = self.ann_data.loadCats(
+ self.ann_data.getCatIds())
+
+ instance_list = []
+ image_list = []
+
+ for i, _ann_ids in enumerate(self.sequence_indices):
+ expected_num_frames = self.seq_len
+ if self.multiple_target:
+ expected_num_frames = self.multiple_target
+
+ assert len(_ann_ids) == (expected_num_frames), (
+ f'Expected `frame_ids` == {expected_num_frames}, but '
+ f'got {len(_ann_ids)} ')
+
+ anns = self.ann_data.loadAnns(_ann_ids)
+ img_ids = []
+ kpts = np.zeros((len(anns), num_keypoints, 2), dtype=np.float32)
+ kpts_3d = np.zeros((len(anns), num_keypoints, 3), dtype=np.float32)
+ keypoints_visible = np.zeros((len(anns), num_keypoints, 1),
+ dtype=np.float32)
+ for j, ann in enumerate(anns):
+ img_ids.append(ann['image_id'])
+ kpts[j] = np.array(ann['keypoints'], dtype=np.float32)
+ kpts_3d[j] = np.array(ann['keypoints_3d'], dtype=np.float32)
+ keypoints_visible[j] = np.array(
+ ann['keypoints_valid'], dtype=np.float32)
+ imgs = self.ann_data.loadImgs(img_ids)
+ keypoints_visible = keypoints_visible.squeeze(-1)
+
+ scales = np.zeros(len(imgs), dtype=np.float32)
+ centers = np.zeros((len(imgs), 2), dtype=np.float32)
+ img_paths = np.array([img['file_name'] for img in imgs])
+ factors = np.zeros((kpts_3d.shape[0], ), dtype=np.float32)
+
+ target_idx = [-1] if self.causal else [int(self.seq_len // 2)]
+ if self.multiple_target:
+ target_idx = list(range(self.multiple_target))
+
+ cam_param = anns[-1]['camera_param']
+ if 'w' not in cam_param or 'h' not in cam_param:
+ cam_param['w'] = 1000
+ cam_param['h'] = 1000
+
+ instance_info = {
+ 'num_keypoints': num_keypoints,
+ 'keypoints': kpts,
+ 'keypoints_3d': kpts_3d,
+ 'keypoints_visible': keypoints_visible,
+ 'scale': scales,
+ 'center': centers,
+ 'id': i,
+ 'category_id': 1,
+ 'iscrowd': 0,
+ 'img_paths': list(img_paths),
+ 'img_ids': [img['id'] for img in imgs],
+ 'lifting_target': kpts_3d[target_idx],
+ 'lifting_target_visible': keypoints_visible[target_idx],
+ 'target_img_paths': img_paths[target_idx],
+ 'camera_param': cam_param,
+ 'factor': factors,
+ 'target_idx': target_idx,
+ }
+
+ instance_list.append(instance_info)
+
+ for img_id in self.ann_data.getImgIds():
+ img = self.ann_data.loadImgs(img_id)[0]
+ img.update({
+ 'img_id':
+ img_id,
+ 'img_path':
+ osp.join(self.data_prefix['img'], img['file_name']),
+ })
+ image_list.append(img)
+
+ return instance_list, image_list
diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py
index c8204ac7ef..1906f16972 100644
--- a/mmpose/datasets/transforms/converting.py
+++ b/mmpose/datasets/transforms/converting.py
@@ -91,8 +91,12 @@ def transform(self, results: dict) -> dict:
num_instances = results['keypoints'].shape[0]
# Initialize output arrays
- keypoints = np.zeros((num_instances, self.num_keypoints, 2))
+ keypoints = np.zeros((num_instances, self.num_keypoints, 3))
keypoints_visible = np.zeros((num_instances, self.num_keypoints))
+ key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
+ c = results[key].shape[-1]
+
+ flip_indices = results.get('flip_indices', None)
# Create a mask to weight visibility loss
keypoints_visible_weights = keypoints_visible.copy()
@@ -100,26 +104,38 @@ def transform(self, results: dict) -> dict:
# Interpolate keypoints if pairs of source indexes provided
if self.interpolation:
- keypoints[:, self.target_index] = 0.5 * (
- results['keypoints'][:, self.source_index] +
- results['keypoints'][:, self.source_index2])
-
+ keypoints[:, self.target_index, :c] = 0.5 * (
+ results[key][:, self.source_index] +
+ results[key][:, self.source_index2])
keypoints_visible[:, self.target_index] = results[
- 'keypoints_visible'][:, self.source_index] * \
- results['keypoints_visible'][:, self.source_index2]
-
+ 'keypoints_visible'][:, self.source_index] * results[
+ 'keypoints_visible'][:, self.source_index2]
+ # Flip keypoints if flip_indices provided
+ if flip_indices is not None:
+ for i, (x1, x2) in enumerate(
+ zip(self.source_index, self.source_index2)):
+ idx = flip_indices[x1] if x1 == x2 else i
+ flip_indices[i] = idx if idx < self.num_keypoints else i
+ flip_indices = flip_indices[:len(self.source_index)]
# Otherwise just copy from the source index
else:
keypoints[:,
- self.target_index] = results['keypoints'][:, self.
- source_index]
+ self.target_index, :c] = results[key][:,
+ self.source_index]
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index]
# Update the results dict
- results['keypoints'] = keypoints
+ results['keypoints'] = keypoints[..., :2]
results['keypoints_visible'] = np.stack(
[keypoints_visible, keypoints_visible_weights], axis=2)
+ if 'keypoints_3d' in results:
+ results['keypoints_3d'] = keypoints
+ results['lifting_target'] = keypoints[results['target_idx']]
+ results['lifting_target_visible'] = keypoints_visible[
+ results['target_idx']]
+ results['flip_indices'] = flip_indices
+
return results
def transform_sigmas(self, sigmas: Union[List, np.ndarray]):
diff --git a/mmpose/evaluation/metrics/__init__.py b/mmpose/evaluation/metrics/__init__.py
index 7090f0226a..9e82356a49 100644
--- a/mmpose/evaluation/metrics/__init__.py
+++ b/mmpose/evaluation/metrics/__init__.py
@@ -7,9 +7,10 @@
from .keypoint_3d_metrics import MPJPE
from .keypoint_partition_metric import KeypointPartitionMetric
from .posetrack18_metric import PoseTrack18Metric
+from .simple_keypoint_3d_metrics import SimpleMPJPE
__all__ = [
'CocoMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'AUC',
'EPE', 'NME', 'PoseTrack18Metric', 'CocoWholeBodyMetric',
- 'KeypointPartitionMetric', 'MPJPE', 'InterHandMetric'
+ 'KeypointPartitionMetric', 'MPJPE', 'InterHandMetric', 'SimpleMPJPE'
]
diff --git a/mmpose/evaluation/metrics/simple_keypoint_3d_metrics.py b/mmpose/evaluation/metrics/simple_keypoint_3d_metrics.py
new file mode 100644
index 0000000000..dc0065d5b9
--- /dev/null
+++ b/mmpose/evaluation/metrics/simple_keypoint_3d_metrics.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpose.registry import METRICS
+from ..functional import keypoint_mpjpe
+
+
@METRICS.register_module()
class SimpleMPJPE(BaseMetric):
    """MPJPE evaluation metric.

    Calculate the mean per-joint position error (MPJPE) of keypoints.

    Note:
        - length of dataset: N
        - num_keypoints: K
        - number of keypoint dimensions: D (typically D = 3)

    Args:
        mode (str): Method to align the prediction with the
            ground truth. Supported options are:

                - ``'mpjpe'``: no alignment will be applied
                - ``'p-mpjpe'``: align in the least-square sense in scale
                - ``'n-mpjpe'``: align in the least-square sense in
                    scale, rotation, and translation.

        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be ``'cpu'`` or
            ``'gpu'``. Default: ``'cpu'``.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Default: ``None``.
        skip_list (list, optional): The list of subject and action combinations
            to be skipped. Kept for API compatibility with other MPJPE
            metrics; it is not consulted by this class's computation.
            Default: ``None`` (treated as an empty list).
    """

    # Maps the user-facing ``mode`` name to the alignment argument expected
    # by ``keypoint_mpjpe``.
    ALIGNMENT = {'mpjpe': 'none', 'p-mpjpe': 'procrustes', 'n-mpjpe': 'scale'}

    def __init__(self,
                 mode: str = 'mpjpe',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 skip_list: Optional[List[str]] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        allowed_modes = self.ALIGNMENT.keys()
        if mode not in allowed_modes:
            raise KeyError("`mode` should be 'mpjpe', 'p-mpjpe', or "
                           f"'n-mpjpe', but got '{mode}'.")

        self.mode = mode
        # Avoid the mutable-default-argument pitfall: normalize ``None`` to a
        # fresh empty list per instance.
        self.skip_list = skip_list if skip_list is not None else []

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data
                from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """
        for data_sample in data_samples:
            # predicted keypoints coordinates, [T, K, D]
            pred_coords = data_sample['pred_instances']['keypoints']
            if pred_coords.ndim == 4:
                # Drop a leading singleton batch axis, if present.
                pred_coords = np.squeeze(pred_coords, axis=0)
            # ground truth data_info
            gt = data_sample['gt_instances']
            # ground truth keypoints coordinates, [T, K, D]
            gt_coords = gt['lifting_target']
            # ground truth keypoints_visible, [T, K, 1] -> boolean [T, K]
            mask = gt['lifting_target_visible'].astype(bool).reshape(
                gt_coords.shape[0], -1)

            result = {
                'pred_coords': pred_coords,
                'gt_coords': gt_coords,
                'mask': mask,
            }

            self.results.append(result)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are the corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        # pred_coords: [N, K, D]
        pred_coords = np.concatenate(
            [result['pred_coords'] for result in results])
        # gt_coords: [N, K, D]
        gt_coords = np.concatenate([result['gt_coords'] for result in results])
        # mask: [N, K]
        mask = np.concatenate([result['mask'] for result in results])

        error_name = self.mode.upper()

        logger.info(f'Evaluating {error_name}...')
        return {
            error_name:
            keypoint_mpjpe(pred_coords, gt_coords, mask,
                           self.ALIGNMENT[self.mode])
        }
diff --git a/tests/data/ubody3d/ubody3d_train.json b/tests/data/ubody3d/ubody3d_train.json
new file mode 100644
index 0000000000..55a4ac5226
--- /dev/null
+++ b/tests/data/ubody3d/ubody3d_train.json
@@ -0,0 +1 @@
+{"images": [{"id": 15, "height": 720, "width": 1280, "file_name": "Magic_show/Magic_show_S1_Trim1/Magic_show_S1_Trim1/000016.png"}], "annotations": [{"id": 0, "image_id": 15, "bbox": [74.55498504638672, 8.571063995361328, 1062.4967727661133, 701.8491630554199], "segmentation": [[]], "area": 0, "iscrowd": 0, "category_id": 1, "score": 1, "person_id": 0, "hand_box1": [336.4236145019531, 321.40362548828125, 473.6637268066406, 452.62567138671875], "hand_box2": [699.218994140625, 50.335018157958984, 533.58251953125, 621.6577186584473], "keypoints": [[585.656005859375, 1398.5216064453125], [699.9061889648438, 1586.966064453125], [450.14288330078125, 1596.144775390625], [878.3228149414062, 2171.27783203125], [252.16543579101562, 2132.398681640625], [793.895263671875, 2988.90771484375], [232.56475830078125, 2939.503173828125], [588.2872314453125, 570.474365234375], [862.1456298828125, 514.33837890625], [373.89849853515625, 519.60888671875], [1073.739990234375, 765.0070190429688], [89.8785400390625, 775.919921875], [1000.2418212890625, 635.8955688476562], [189.44015502929688, 567.993408203125], [891.81298828125, 2948.2041015625], [1013.4824829101562, 3015.250732421875], [819.24658203125, 3122.821533203125], [172.14041137695312, 2868.272705078125], [31.46063232421875, 2937.01025390625], [244.37692260742188, 3111.135009765625], [760.2764282226562, 235.35623168945312], [469.04644775390625, 237.359130859375], [672.689453125, 216.68638610839844], [536.8645629882812, 215.08010864257812], [594.4747924804688, 302.86590576171875], [937.543212890625, 563.2012939453125], [877.2040405273438, 564.7064819335938], [826.8228759765625, 548.8115234375], [768.3922729492188, 532.2924194335938], [945.0330810546875, 433.25579833984375], [887.2977905273438, 411.39129638671875], [854.9716796875, 409.1885986328125], [812.5216064453125, 409.8503112792969], [993.1986083984375, 415.13519287109375], [983.431640625, 352.09503173828125], [976.8125610351562, 306.58990478515625], [967.6991577148438, 
251.8966064453125], [1042.6788330078125, 439.2115783691406], [1061.695068359375, 382.62310791015625], [1078.3428955078125, 336.8554382324219], [1089.8707275390625, 288.113037109375], [1077.3145751953125, 467.8497009277344], [1113.5694580078125, 449.51904296875], [1147.91796875, 434.2681884765625], [1184.372314453125, 406.7205505371094], [262.0787048339844, 512.4108276367188], [314.8291320800781, 495.84429931640625], [355.2375183105469, 463.73870849609375], [400.5841064453125, 429.6348876953125], [290.11627197265625, 385.6371765136719], [334.016357421875, 356.7796325683594], [352.326904296875, 347.6751403808594], [379.92449951171875, 336.6559143066406], [248.99337768554688, 355.2509460449219], [270.441162109375, 294.56085205078125], [283.58990478515625, 247.07943725585938], [298.6072692871094, 191.95077514648438], [194.588623046875, 364.1822509765625], [197.89288330078125, 304.9277038574219], [198.94699096679688, 255.0223846435547], [207.83172607421875, 206.8009490966797], [152.69793701171875, 380.91925048828125], [126.07894897460938, 349.861083984375], [99.02603149414062, 320.67138671875], [75.35498046875, 280.7127380371094], [605.5189819335938, 258.36474609375], [636.6569213867188, 261.03448486328125], [672.689453125, 216.68638610839844], [536.8645629882812, 215.08010864257812], [480.609130859375, 193.2221221923828], [498.7352294921875, 169.0961151123047], [527.0252075195312, 168.48736572265625], [556.564453125, 174.32501220703125], [582.2213134765625, 183.7449188232422], [619.771728515625, 185.09783935546875], [646.1015625, 177.27572631835938], [678.3016357421875, 172.73214721679688], [709.5665283203125, 174.52818298339844], [730.6221313476562, 199.52928161621094], [600.2632446289062, 215.79234313964844], [598.0828247070312, 240.45635986328125], [596.2218627929688, 264.4862976074219], [594.4674072265625, 287.62481689453125], [572.7188110351562, 305.8975830078125], [583.9725341796875, 311.3199157714844], [596.401123046875, 315.5985107421875], [609.6165771484375, 
311.5094909667969], [622.2186279296875, 306.6711120605469], [512.6423950195312, 211.75982666015625], [528.5633544921875, 204.07089233398438], [548.4610595703125, 205.9830780029297], [565.9568481445312, 217.66900634765625], [548.8089599609375, 222.94613647460938], [530.2134399414062, 222.75762939453125], [639.6070556640625, 219.82444763183594], [655.8860473632812, 209.6044158935547], [676.3201904296875, 208.3985595703125], [694.9487915039062, 217.1615753173828], [674.3418579101562, 226.85595703125], [655.4156494140625, 225.6745147705078], [551.7490234375, 353.2354736328125], [564.1500244140625, 346.4883728027344], [583.2034912109375, 344.99609375], [595.4065551757812, 347.21868896484375], [607.8397216796875, 345.721435546875], [629.6182250976562, 348.2886047363281], [648.6402587890625, 353.0809631347656], [634.0433349609375, 361.12738037109375], [612.543212890625, 365.1044921875], [598.9017333984375, 366.5699768066406], [585.4385375976562, 366.0231018066406], [566.12353515625, 362.2437744140625], [553.4495239257812, 352.7164001464844], [583.9151000976562, 355.8670654296875], [596.3876342773438, 356.340576171875], [608.99560546875, 356.22100830078125], [648.081787109375, 352.85076904296875], [612.7412719726562, 351.5333251953125], [598.9871215820312, 351.8242492675781], [585.3312377929688, 352.4969482421875], [464.1539001464844, 202.29954528808594], [465.8164978027344, 244.8143768310547], [469.96026611328125, 282.73333740234375], [474.998779296875, 318.5062255859375], [485.900390625, 354.82257080078125], [503.9440002441406, 389.1557922363281], [533.9607543945312, 420.1808776855469], [569.1990356445312, 439.69488525390625], [604.7715454101562, 445.1242370605469], [641.609130859375, 438.5807189941406], [677.1731567382812, 419.1774597167969], [709.558349609375, 390.3476867675781], [728.9358520507812, 358.6229553222656], [743.6824951171875, 323.7010192871094], [752.355224609375, 286.009033203125], [756.031494140625, 248.0742645263672], [756.6275634765625, 
206.8378448486328]], "foot_kpts": [1166.72314453125, 38.096336364746094, 0, 1002.4937744140625, 109.48077392578125, 0, 1049.140869140625, 663.1453857421875, 0, 317.3815002441406, 32.0361328125, 0, 402.523681640625, 303.2774963378906, 0, 177.21731567382812, 665.190673828125, 0], "face_kpts": [482.1813659667969, 206.51531982421875, 0, 474.4501037597656, 248.23251342773438, 1, 482.5657043457031, 282.5651550292969, 1, 490.3671569824219, 326.8166198730469, 1, 498.9546813964844, 355.2204895019531, 1, 519.25634765625, 390.5085754394531, 1, 543.9222412109375, 417.4048156738281, 1, 574.4150390625, 437.6228332519531, 1, 614.6944580078125, 442.5209045410156, 1, 648.99267578125, 436.2539978027344, 1, 682.6341552734375, 416.4512023925781, 1, 702.5023193359375, 392.0824279785156, 1, 725.9093017578125, 358.3260803222656, 1, 739.4346923828125, 328.9374084472656, 1, 746.7598876953125, 285.0207824707031, 1, 748.8603515625, 251.59585571289062, 1, 755.915771484375, 212.4534149169922, 0, 496.4743957519531, 188.47494506835938, 1, 514.8231201171875, 177.99856567382812, 1, 535.214111328125, 176.0469970703125, 1, 556.4619140625, 177.9375, 1, 576.8843994140625, 183.35317993164062, 1, 631.4595947265625, 183.65673828125, 1, 652.4815673828125, 180.27340698242188, 1, 676.221923828125, 180.07711791992188, 1, 698.4794921875, 184.41073608398438, 1, 718.5443115234375, 196.21084594726562, 1, 604.396484375, 218.71194458007812, 1, 602.6702880859375, 245.68115234375, 1, 600.9422607421875, 271.4402770996094, 1, 599.4947509765625, 297.5359802246094, 1, 571.33203125, 313.3100891113281, 1, 586.1724853515625, 317.1542663574219, 1, 601.4893798828125, 320.0868835449219, 1, 617.738525390625, 316.9916687011719, 1, 632.822509765625, 313.9440002441406, 1, 524.906005859375, 216.0177001953125, 1, 542.880859375, 206.15841674804688, 1, 563.9365234375, 208.03213500976562, 1, 578.5321044921875, 222.44454956054688, 1, 559.7491455078125, 226.11843872070312, 1, 541.22607421875, 225.11203002929688, 1, 636.491943359375, 
223.62353515625, 1, 652.7271728515625, 210.68789672851562, 1, 674.761474609375, 209.86370849609375, 1, 692.972900390625, 221.53323364257812, 1, 674.9864501953125, 228.75543212890625, 1, 656.0750732421875, 229.04306030273438, 1, 560.0743408203125, 351.4398498535156, 1, 577.081787109375, 347.0306091308594, 1, 594.04638671875, 345.2702941894531, 1, 604.1793212890625, 346.1555480957031, 1, 614.151611328125, 344.8525695800781, 1, 634.447509765625, 345.7118225097656, 1, 656.1597900390625, 347.9260559082031, 1, 640.6773681640625, 358.7562561035156, 1, 624.00732421875, 366.7438049316406, 1, 605.445556640625, 369.8896789550781, 1, 588.646484375, 367.5843811035156, 1, 573.5023193359375, 360.9281921386719, 1, 565.385498046875, 352.2278137207031, 1, 585.1085205078125, 353.1212463378906, 1, 604.616943359375, 355.0426330566406, 1, 626.8272705078125, 351.8833312988281, 1, 650.2919921875, 349.2644958496094, 1, 627.5924072265625, 353.0104675292969, 1, 604.7803955078125, 355.8074645996094, 1, 584.6986083984375, 354.2829284667969, 1], "lefthand_kpts": [942.7679443359375, 607.469482421875, 1, 888.291259765625, 539.277587890625, 1, 832.873291015625, 483.5708923339844, 1, 787.126953125, 436.6972351074219, 1, 710.735107421875, 413.7229309082031, 1, 888.9903564453125, 319.5710754394531, 1, 868.0140380859375, 280.7148742675781, 1, 830.3096923828125, 266.0387268066406, 1, 778.9337158203125, 271.2351379394531, 1, 962.7294921875, 272.7072448730469, 1, 955.781005859375, 187.65567016601562, 1, 953.9222412109375, 103.62838745117188, 1, 959.151611328125, 29.267608642578125, 1, 1047.009033203125, 294.3193664550781, 1, 1056.5989990234375, 215.84146118164062, 1, 1066.36865234375, 147.68014526367188, 1, 1081.0699462890625, 65.11972045898438, 1, 1107.0172119140625, 358.7002258300781, 1, 1159.4434814453125, 319.2156677246094, 1, 1206.9718017578125, 272.8797912597656, 1, 1261.1082763671875, 224.43637084960938, 1], "righthand_kpts": [233.142822265625, 582.3209228515625, 1, 300.6414794921875, 
508.47479248046875, 1, 362.43896484375, 455.85186767578125, 1, 377.3603515625, 404.19744873046875, 1, 446.76416015625, 377.29241943359375, 1, 342.8802490234375, 310.6497802734375, 1, 368.6904296875, 284.673095703125, 1, 381.802734375, 251.73486328125, 1, 421.5467529296875, 225.363525390625, 1, 283.64288330078125, 254.122802734375, 1, 304.9996337890625, 170.8004150390625, 1, 320.6651611328125, 98.6851806640625, 1, 335.6553955078125, 28.2318115234375, 1, 199.05755615234375, 256.80859375, 1, 206.0360107421875, 177.01025390625, 1, 215.68804931640625, 106.7457275390625, 1, 224.53521728515625, 32.276611328125, 1, 128.827392578125, 294.99359130859375, 1, 99.0606689453125, 239.12982177734375, 1, 65.53125, 189.2431640625, 1, 37.63360595703125, 116.657958984375, 1], "center": [605.8033447265625, 359.4956359863281], "scale": [6.6406049728393555, 8.854140281677246], "keypoints_score": [0.9791078567504883, 0.9932481050491333, 1.0011144876480103, 0.973096489906311, 0.972457766532898, 0.866172194480896, 0.8760361671447754, 0.3526427149772644, 0.3903506398200989, 0.921836793422699, 0.9433825016021729, 0.20496317744255066, 0.2460474669933319, 0.20729553699493408, 0.17142903804779053, 0.18208564817905426, 0.22269707918167114], "face_kpts_score": [0.3680439293384552, 0.5355573892593384, 0.6418813467025757, 0.6644495725631714, 0.7590401768684387, 0.5538617372512817, 0.5907169580459595, 0.5878690481185913, 0.6348617076873779, 0.7361799478530884, 0.6556291580200195, 0.618322491645813, 0.6537319421768188, 0.5892513394355774, 0.7059171199798584, 0.645734429359436, 0.4574907422065735, 0.9639992713928223, 0.9263820648193359, 0.8876979351043701, 0.9284569621086121, 0.9739065170288086, 0.9502178430557251, 0.9174821376800537, 0.918608546257019, 0.9061530232429504, 0.862210750579834, 0.9776759147644043, 0.973875105381012, 0.974762499332428, 0.9565852880477905, 0.9716235399246216, 1.0059518814086914, 0.946382999420166, 0.9594531059265137, 0.9658107757568359, 1.0158061981201172, 
0.9708306789398193, 0.9969902634620667, 0.9845597743988037, 0.9349627494812012, 0.9380444288253784, 0.9717998504638672, 0.9871775507926941, 0.9774664640426636, 0.9537898898124695, 0.9465979933738708, 0.9661000967025757, 0.9713011980056763, 0.9717509746551514, 0.956028938293457, 1.000832438468933, 0.9808722734451294, 0.9960898160934448, 0.9364079236984253, 1.0011546611785889, 0.9167187213897705, 0.9541155099868774, 0.9244742393493652, 0.988551139831543, 0.9954862594604492, 0.9832127094268799, 0.978826642036438, 0.9751479625701904, 0.956895112991333, 0.9974040985107422, 0.9864891767501831, 0.9898920655250549], "foot_kpts_score": [0.24755269289016724, 0.1599443256855011, 0.25949808955192566, 0.2688680589199066, 0.14811083674430847, 0.23364056646823883], "lefthand_kpts_score": [0.603957986831665, 0.46176729202270506, 0.5001004695892334, 0.6286116600036621, 0.7983541250228882, 0.7467568874359131, 0.7094749569892883, 0.7889106035232544, 0.8908322811126709, 0.8638974189758301, 1.0441084861755372, 0.9282500505447387, 0.9102095127105713, 0.7738837957382202, 0.94963458776474, 0.8981462478637695, 0.9926700949668884, 0.7828058958053589, 0.9498528003692627, 0.9387582302093506, 0.8471795082092285], "righthand_kpts_score": [0.6722876787185669, 0.60037282705307, 0.5398626983165741, 0.7077780723571777, 0.7050052642822265, 0.6411999225616455, 0.725990629196167, 0.758279001712799, 0.8829087972640991, 0.889958119392395, 0.9569337129592895, 0.9145335912704468, 0.9213766813278198, 0.8925279140472412, 0.9955486416816711, 1.0033048152923585, 1.0014301896095277, 0.9033888339996338, 0.9002806305885315, 0.8902452945709228, 0.888652241230011], "face_box": [445.3220458984375, 145.05938720703125, 348.63178710937495, 332.0302734375], "face_valid": true, "leftfoot_valid": false, "rightfoot_valid": false, "lefthand_valid": true, "righthand_valid": true, "lefthand_box": [699.218994140625, 50.335018157958984, 533.58251953125, 621.6577186584473], "righthand_box": [81.47227172851564, 
-7.12115478515625, 398.4362548828125, 664.060546875], "lefthand_update": true, "righthand_update": true, "lefthand_kpts_vitposehand": [942.7679443359375, 607.469482421875, 1, 888.291259765625, 539.277587890625, 1, 832.873291015625, 483.5708923339844, 1, 787.126953125, 436.6972351074219, 1, 710.735107421875, 413.7229309082031, 1, 888.9903564453125, 319.5710754394531, 1, 868.0140380859375, 280.7148742675781, 1, 830.3096923828125, 266.0387268066406, 1, 778.9337158203125, 271.2351379394531, 1, 962.7294921875, 272.7072448730469, 1, 955.781005859375, 187.65567016601562, 1, 953.9222412109375, 103.62838745117188, 1, 959.151611328125, 29.267608642578125, 1, 1047.009033203125, 294.3193664550781, 1, 1056.5989990234375, 215.84146118164062, 1, 1066.36865234375, 147.68014526367188, 1, 1081.0699462890625, 65.11972045898438, 1, 1107.0172119140625, 358.7002258300781, 1, 1159.4434814453125, 319.2156677246094, 1, 1206.9718017578125, 272.8797912597656, 1, 1261.1082763671875, 224.43637084960938, 1], "righthand_kpts_vitposehand": [233.142822265625, 582.3209228515625, 1, 300.6414794921875, 508.47479248046875, 1, 362.43896484375, 455.85186767578125, 1, 377.3603515625, 404.19744873046875, 1, 446.76416015625, 377.29241943359375, 1, 342.8802490234375, 310.6497802734375, 1, 368.6904296875, 284.673095703125, 1, 381.802734375, 251.73486328125, 1, 421.5467529296875, 225.363525390625, 1, 283.64288330078125, 254.122802734375, 1, 304.9996337890625, 170.8004150390625, 1, 320.6651611328125, 98.6851806640625, 1, 335.6553955078125, 28.2318115234375, 1, 199.05755615234375, 256.80859375, 1, 206.0360107421875, 177.01025390625, 1, 215.68804931640625, 106.7457275390625, 1, 224.53521728515625, 32.276611328125, 1, 128.827392578125, 294.99359130859375, 1, 99.0606689453125, 239.12982177734375, 1, 65.53125, 189.2431640625, 1, 37.63360595703125, 116.657958984375, 1], "num_keypoints": 9, "full_body": false, "valid_label": 2, "keypoints_3d": [[585.656005859375, 1398.5216064453125, 8.0], [699.9061889648438, 
1586.966064453125, 7.7132415771484375], [450.14288330078125, 1596.144775390625, 7.6570892333984375], [878.3228149414062, 2171.27783203125, 5.664215087890625], [252.16543579101562, 2132.398681640625, 5.6501007080078125], [793.895263671875, 2988.90771484375, 4.6084747314453125], [232.56475830078125, 2939.503173828125, 4.28839111328125], [588.2872314453125, 570.474365234375, 9.544265747070312], [862.1456298828125, 514.33837890625, 8.8726806640625], [373.89849853515625, 519.60888671875, 9.171127319335938], [1073.739990234375, 765.0070190429688, 7.1384735107421875], [89.8785400390625, 775.919921875, 7.5379791259765625], [1000.2418212890625, 635.8955688476562, 5.19927978515625], [189.44015502929688, 567.993408203125, 5.757049560546875], [891.81298828125, 2948.2041015625, 3.0384368896484375], [1013.4824829101562, 3015.250732421875, 3.43035888671875], [819.24658203125, 3122.821533203125, 4.943603515625], [172.14041137695312, 2868.272705078125, 2.809112548828125], [31.46063232421875, 2937.01025390625, 3.1867828369140625], [244.37692260742188, 3111.135009765625, 4.5428619384765625], [760.2764282226562, 235.35623168945312, 9.170547485351562], [469.04644775390625, 237.359130859375, 9.270904541015625], [672.689453125, 216.68638610839844, 8.436477661132812], [536.8645629882812, 215.08010864257812, 8.477508544921875], [594.4747924804688, 302.86590576171875, 8.231826782226562], [937.543212890625, 563.2012939453125, 7.81884765625], [877.2040405273438, 564.7064819335938, 7.746490478515625], [826.8228759765625, 548.8115234375, 7.6898651123046875], [768.3922729492188, 532.2924194335938, 7.540069580078125], [945.0330810546875, 433.25579833984375, 7.78143310546875], [887.2977905273438, 411.39129638671875, 7.68023681640625], [854.9716796875, 409.1885986328125, 7.548248291015625], [812.5216064453125, 409.8503112792969, 7.41748046875], [993.1986083984375, 415.13519287109375, 7.762298583984375], [983.431640625, 352.09503173828125, 7.7212677001953125], [976.8125610351562, 306.58990478515625, 
7.644317626953125], [967.6991577148438, 251.8966064453125, 7.58074951171875], [1042.6788330078125, 439.2115783691406, 7.7346954345703125], [1061.695068359375, 382.62310791015625, 7.7144622802734375], [1078.3428955078125, 336.8554382324219, 7.6671142578125], [1089.8707275390625, 288.113037109375, 7.64324951171875], [1077.3145751953125, 467.8497009277344, 7.6988525390625], [1113.5694580078125, 449.51904296875, 7.6714019775390625], [1147.91796875, 434.2681884765625, 7.6133880615234375], [1184.372314453125, 406.7205505371094, 7.566802978515625], [262.0787048339844, 512.4108276367188, 7.7939453125], [314.8291320800781, 495.84429931640625, 7.6787109375], [355.2375183105469, 463.73870849609375, 7.6097564697265625], [400.5841064453125, 429.6348876953125, 7.4446563720703125], [290.11627197265625, 385.6371765136719, 7.82208251953125], [334.016357421875, 356.7796325683594, 7.663116455078125], [352.326904296875, 347.6751403808594, 7.499725341796875], [379.92449951171875, 336.6559143066406, 7.330535888671875], [248.99337768554688, 355.2509460449219, 7.84161376953125], [270.441162109375, 294.56085205078125, 7.848602294921875], [283.58990478515625, 247.07943725585938, 7.8173370361328125], [298.6072692871094, 191.95077514648438, 7.8151092529296875], [194.588623046875, 364.1822509765625, 7.8341217041015625], [197.89288330078125, 304.9277038574219, 7.8556976318359375], [198.94699096679688, 255.0223846435547, 7.8529815673828125], [207.83172607421875, 206.8009490966797, 7.8715667724609375], [152.69793701171875, 380.91925048828125, 7.8072052001953125], [126.07894897460938, 349.861083984375, 7.8142547607421875], [99.02603149414062, 320.67138671875, 7.79296875], [75.35498046875, 280.7127380371094, 7.79833984375], [605.5189819335938, 258.36474609375, 7.6539459228515625], [636.6569213867188, 261.03448486328125, 7.6003265380859375], [672.689453125, 216.68638610839844, 6.8922119140625], [536.8645629882812, 215.08010864257812, 6.9332427978515625], [480.609130859375, 193.2221221923828, 
7.156890869140625], [498.7352294921875, 169.0961151123047, 7.0008087158203125], [527.0252075195312, 168.48736572265625, 6.879364013671875], [556.564453125, 174.32501220703125, 6.8116912841796875], [582.2213134765625, 183.7449188232422, 6.796417236328125], [619.771728515625, 185.09783935546875, 6.7884368896484375], [646.1015625, 177.27572631835938, 6.788299560546875], [678.3016357421875, 172.73214721679688, 6.8334197998046875], [709.5665283203125, 174.52818298339844, 6.94036865234375], [730.6221313476562, 199.52928161621094, 7.08001708984375], [600.2632446289062, 215.79234313964844, 6.797698974609375], [598.0828247070312, 240.45635986328125, 6.753753662109375], [596.2218627929688, 264.4862976074219, 6.70782470703125], [594.4674072265625, 287.62481689453125, 6.66571044921875], [572.7188110351562, 305.8975830078125, 6.8535308837890625], [583.9725341796875, 311.3199157714844, 6.8229217529296875], [596.401123046875, 315.5985107421875, 6.804962158203125], [609.6165771484375, 311.5094909667969, 6.8159027099609375], [622.2186279296875, 306.6711120605469, 6.8405303955078125], [512.6423950195312, 211.75982666015625, 7.02471923828125], [528.5633544921875, 204.07089233398438, 6.9400634765625], [548.4610595703125, 205.9830780029297, 6.92816162109375], [565.9568481445312, 217.66900634765625, 6.9529266357421875], [548.8089599609375, 222.94613647460938, 6.9491424560546875], [530.2134399414062, 222.75762939453125, 6.9624176025390625], [639.6070556640625, 219.82444763183594, 6.930755615234375], [655.8860473632812, 209.6044158935547, 6.8970184326171875], [676.3201904296875, 208.3985595703125, 6.8957061767578125], [694.9487915039062, 217.1615753173828, 6.9696502685546875], [674.3418579101562, 226.85595703125, 6.9189300537109375], [655.4156494140625, 225.6745147705078, 6.91705322265625], [551.7490234375, 353.2354736328125, 6.971923828125], [564.1500244140625, 346.4883728027344, 6.88177490234375], [583.2034912109375, 344.99609375, 6.8333587646484375], [595.4065551757812, 
347.21868896484375, 6.8253173828125], [607.8397216796875, 345.721435546875, 6.82666015625], [629.6182250976562, 348.2886047363281, 6.8668060302734375], [648.6402587890625, 353.0809631347656, 6.940582275390625], [634.0433349609375, 361.12738037109375, 6.8939056396484375], [612.543212890625, 365.1044921875, 6.8557891845703125], [598.9017333984375, 366.5699768066406, 6.8533477783203125], [585.4385375976562, 366.0231018066406, 6.8624725341796875], [566.12353515625, 362.2437744140625, 6.9132232666015625], [553.4495239257812, 352.7164001464844, 6.97503662109375], [583.9151000976562, 355.8670654296875, 6.8811187744140625], [596.3876342773438, 356.340576171875, 6.8712615966796875], [608.99560546875, 356.22100830078125, 6.8746795654296875], [648.081787109375, 352.85076904296875, 6.94110107421875], [612.7412719726562, 351.5333251953125, 6.865570068359375], [598.9871215820312, 351.8242492675781, 6.8616485595703125], [585.3312377929688, 352.4969482421875, 6.87408447265625], [464.1539001464844, 202.29954528808594, 7.4058380126953125], [465.8164978027344, 244.8143768310547, 7.313018798828125], [469.96026611328125, 282.73333740234375, 7.331451416015625], [474.998779296875, 318.5062255859375, 7.377685546875], [485.900390625, 354.82257080078125, 7.34814453125], [503.9440002441406, 389.1557922363281, 7.29644775390625], [533.9607543945312, 420.1808776855469, 7.2111968994140625], [569.1990356445312, 439.69488525390625, 7.0761260986328125], [604.7715454101562, 445.1242370605469, 7.0256805419921875], [641.609130859375, 438.5807189941406, 7.05670166015625], [677.1731567382812, 419.1774597167969, 7.1628265380859375], [709.558349609375, 390.3476867675781, 7.262908935546875], [728.9358520507812, 358.6229553222656, 7.3195648193359375], [743.6824951171875, 323.7010192871094, 7.3823699951171875], [752.355224609375, 286.009033203125, 7.3757171630859375], [756.031494140625, 248.0742645263672, 7.3575439453125], [756.6275634765625, 206.8378448486328, 7.39019775390625]], "keypoints_valid": [[1.0], 
[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0]], "camera_param": {"focal": [34553.93155415853, 34553.93075942993], "princpt": [605.3033752441406, 358.99560546875]}}], "categories": [{"supercategory": "person", "id": 1, "name": "person"}]}
\ No newline at end of file
diff --git a/tests/test_datasets/test_datasets/test_body_datasets/test_ubody_dataset.py b/tests/test_datasets/test_datasets/test_body_datasets/test_ubody_dataset.py
new file mode 100644
index 0000000000..12f780e1a0
--- /dev/null
+++ b/tests/test_datasets/test_datasets/test_body_datasets/test_ubody_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
+
+from mmpose.datasets.datasets.body3d import UBody3dDataset
+
+
+class TestUBody3dDataset(TestCase):
# Unit tests for ``UBody3dDataset``. The fixture under
# ``tests/data/ubody3d`` contains a single annotated sample, which is why
# every length check below expects ``len(dataset) == 1``.
+
+ def build_ubody3d_dataset(self, **kwargs):
# Build a dataset instance from a default topdown config; any keyword
# arguments override individual config entries.
+
+ cfg = dict(
+ ann_file='ubody3d_train.json',
+ data_mode='topdown',
+ data_root='tests/data/ubody3d',
+ pipeline=[],
+ test_mode=False)
+
+ cfg.update(kwargs)
+ return UBody3dDataset(**cfg)
+
+ def check_data_info_keys(self, data_info: dict):
# Assert that a single data_info sample exposes the expected keys with
# the expected value types.
+ expected_keys = dict(
+ img_paths=list,
+ keypoints=np.ndarray,
+ keypoints_3d=np.ndarray,
+ scale=np.ndarray,
+ center=np.ndarray,
+ id=int)
+
+ for key, type_ in expected_keys.items():
+ self.assertIn(key, data_info)
+ self.assertIsInstance(data_info[key], type_, key)
+
+ def test_metainfo(self):
+ dataset = self.build_ubody3d_dataset()
+ # test dataset_name
+ self.assertEqual(dataset.metainfo['dataset_name'], 'ubody3d')
+
+ # test number of keypoints
+ num_keypoints = 137
+ self.assertEqual(dataset.metainfo['num_keypoints'], num_keypoints)
+ self.assertEqual(
+ len(dataset.metainfo['keypoint_colors']), num_keypoints)
+ self.assertEqual(
+ len(dataset.metainfo['dataset_keypoint_weights']), num_keypoints)
+
+ # test some extra metainfo
+ self.assertEqual(
+ len(dataset.metainfo['skeleton_links']),
+ len(dataset.metainfo['skeleton_link_colors']))
+
+ def test_topdown(self):
+ # test topdown training
+ dataset = self.build_ubody3d_dataset(data_mode='topdown')
+ dataset.full_init()
+ self.assertEqual(len(dataset), 1)
+ self.check_data_info_keys(dataset[0])
+
+ # test topdown testing
+ dataset = self.build_ubody3d_dataset(
+ data_mode='topdown', test_mode=True)
+ dataset.full_init()
+ self.assertEqual(len(dataset), 1)
+ self.check_data_info_keys(dataset[0])
+
+ # test topdown training with sequence config
+ dataset = self.build_ubody3d_dataset(
+ data_mode='topdown',
+ seq_len=1,
+ seq_step=1,
+ causal=False,
+ pad_video_seq=True)
+ dataset.full_init()
+ self.assertEqual(len(dataset), 1)
+ self.check_data_info_keys(dataset[0])
diff --git a/tests/test_datasets/test_transforms/test_converting.py b/tests/test_datasets/test_transforms/test_converting.py
index 5cce813b70..dc4376baf9 100644
--- a/tests/test_datasets/test_transforms/test_converting.py
+++ b/tests/test_datasets/test_transforms/test_converting.py
@@ -81,6 +81,33 @@ def test_transform(self):
self.data_info['keypoints_visible'][:,
source_index]).all())
+ # check 3d keypoint
+ self.data_info['keypoints_3d'] = np.random.random((4, 17, 3))
+ self.data_info['target_idx'] = [-1]
+ mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
+ transform = KeypointConverter(num_keypoints=5, mapping=mapping)
+ results = transform(self.data_info.copy())
+
+ # check shape
+ self.assertEqual(results['keypoints_3d'].shape[0],
+ self.data_info['keypoints_3d'].shape[0])
+ self.assertEqual(results['keypoints_3d'].shape[1], 5)
+ self.assertEqual(results['keypoints_3d'].shape[2], 3)
+ self.assertEqual(results['keypoints_visible'].shape[0],
+ self.data_info['keypoints_visible'].shape[0])
+ self.assertEqual(results['keypoints_visible'].shape[1], 5)
+
+ # check value
+ for source_index, target_index in mapping:
+ self.assertTrue(
+ (results['keypoints_3d'][:, target_index] ==
+ self.data_info['keypoints_3d'][:, source_index]).all())
+ self.assertEqual(results['keypoints_visible'].ndim, 3)
+ self.assertEqual(results['keypoints_visible'].shape[2], 2)
+ self.assertTrue(
+ (results['keypoints_visible'][:, target_index, 0] ==
+ self.data_info['keypoints_visible'][:, source_index]).all())
+
def test_transform_sigmas(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
diff --git a/tools/dataset_converters/ubody_smplx_to_coco.py b/tools/dataset_converters/ubody_smplx_to_coco.py
new file mode 100644
index 0000000000..16f827fce1
--- /dev/null
+++ b/tools/dataset_converters/ubody_smplx_to_coco.py
@@ -0,0 +1,430 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import os
+import os.path as osp
+from functools import partial
+from typing import Dict, List
+
+import mmengine
+import numpy as np
+import smplx
+import torch
+from pycocotools.coco import COCO
+
+
+class SMPLX(object):
# Thin wrapper around the neutral SMPL-X body model plus the joint-set
# metadata (joint names, left/right flip pairs, part index ranges, and the
# selector from raw SMPL-X output joints to the 137 supervised keypoints)
# used by the UBody conversion below.
+
+ def __init__(self, human_model_path):
+ self.human_model_path = human_model_path
# All pose/shape parameters are supplied per-frame by the caller, so
# disable the layer's default parameter creation entirely.
+ self.layer_args = {
+ 'create_global_orient': False,
+ 'create_body_pose': False,
+ 'create_left_hand_pose': False,
+ 'create_right_hand_pose': False,
+ 'create_jaw_pose': False,
+ 'create_leye_pose': False,
+ 'create_reye_pose': False,
+ 'create_betas': False,
+ 'create_expression': False,
+ 'create_transl': False,
+ }
+
+ self.neutral_model = smplx.create(
+ self.human_model_path,
+ 'smplx',
+ gender='NEUTRAL',
+ use_pca=False,
+ use_face_contour=True,
+ **self.layer_args)
# NOTE(review): hard-codes 'cuda:0' whenever any GPU is visible —
# confirm this is intended on multi-GPU hosts.
+ if torch.cuda.is_available():
+ self.neutral_model = self.neutral_model.to('cuda:0')
+
+ self.vertex_num = 10475
+ self.face = self.neutral_model.faces
+ self.shape_param_dim = 10
+ self.expr_code_dim = 10
+ # 22 (body joints) + 30 (hand joints) + 1 (face jaw joint)
+ self.orig_joint_num = 53
+
# NOTE(review): 'Spine2' (vs 'Spine_1'/'Spine_3') and, further below,
# 'L_Thumb4' (vs 'R_Thumb_4') look like naming typos. They are only
# referenced via .index() on *other* names here, so behavior is
# unaffected — confirm downstream consumers before renaming.
+ # yapf: disable
+ self.orig_joints_name = (
+ # 22 body joints
+ 'Pelvis', 'L_Hip', 'R_Hip', 'Spine_1', 'L_Knee', 'R_Knee',
+ 'Spine2', 'L_Ankle', 'R_Ankle', 'Spine_3', 'L_Foot', 'R_Foot',
+ 'Neck', 'L_Collar', 'R_Collar', 'Head', 'L_Shoulder',
+ 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist',
+ # left hand joints
+ 'L_Index_1', 'L_Index_2', 'L_Index_3', 'L_Middle_1', 'L_Middle_2',
+ 'L_Middle_3', 'L_Pinky_1', 'L_Pinky_2', 'L_Pinky_3', 'L_Ring_1',
+ 'L_Ring_2', 'L_Ring_3', 'L_Thumb_1', 'L_Thumb_2', 'L_Thumb_3',
+ # right hand joints
+ 'R_Index_1', 'R_Index_2', 'R_Index_3', 'R_Middle_1', 'R_Middle_2',
+ 'R_Middle_3', 'R_Pinky_1', 'R_Pinky_2', 'R_Pinky_3', 'R_Ring_1',
+ 'R_Ring_2', 'R_Ring_3', 'R_Thumb_1', 'R_Thumb_2', 'R_Thumb_3',
+ # 1 face jaw joint
+ 'Jaw',
+ )
# Left/right symmetric joint index pairs within the original
# 53-joint parameter set above.
+ self.orig_flip_pairs = (
+ # body joints
+ (1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19),
+ (20, 21),
+ # hand joints
+ (22, 37), (23, 38), (24, 39), (25, 40), (26, 41), (27, 42),
+ (28, 43), (29, 44), (30, 45), (31, 46), (32, 47), (33, 48),
+ (34, 49), (35, 50), (36, 51),
+ )
+ # yapf: enable
+ self.orig_root_joint_idx = self.orig_joints_name.index('Pelvis')
# Contiguous index ranges of each body part in the 53-joint set;
# used to zero out pose/validity for missing hand/face annotations.
+ self.orig_joint_part = {
+ 'body':
+ range(
+ self.orig_joints_name.index('Pelvis'),
+ self.orig_joints_name.index('R_Wrist') + 1),
+ 'lhand':
+ range(
+ self.orig_joints_name.index('L_Index_1'),
+ self.orig_joints_name.index('L_Thumb_3') + 1),
+ 'rhand':
+ range(
+ self.orig_joints_name.index('R_Index_1'),
+ self.orig_joints_name.index('R_Thumb_3') + 1),
+ 'face':
+ range(
+ self.orig_joints_name.index('Jaw'),
+ self.orig_joints_name.index('Jaw') + 1)
+ }
+
+ # changed SMPLX joint set for the supervision
+ self.joint_num = (
+ 137 # 25 (body joints) + 40 (hand joints) + 72 (face keypoints)
+ )
+ # yapf: disable
+ self.joints_name = (
+ # 25 body joints
+ 'Pelvis', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle',
+ 'R_Ankle', 'Neck', 'L_Shoulder', 'R_Shoulder', 'L_Elbow',
+ 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Big_toe', 'L_Small_toe',
+ 'L_Heel', 'R_Big_toe', 'R_Small_toe', 'R_Heel', 'L_Ear', 'R_Ear',
+ 'L_Eye', 'R_Eye', 'Nose',
+ # left hand joints
+ 'L_Thumb_1', 'L_Thumb_2', 'L_Thumb_3', 'L_Thumb4', 'L_Index_1',
+ 'L_Index_2', 'L_Index_3', 'L_Index_4', 'L_Middle_1', 'L_Middle_2',
+ 'L_Middle_3', 'L_Middle_4', 'L_Ring_1', 'L_Ring_2', 'L_Ring_3',
+ 'L_Ring_4', 'L_Pinky_1', 'L_Pinky_2', 'L_Pinky_3', 'L_Pinky_4',
+ # right hand joints
+ 'R_Thumb_1', 'R_Thumb_2', 'R_Thumb_3', 'R_Thumb_4', 'R_Index_1',
+ 'R_Index_2', 'R_Index_3', 'R_Index_4', 'R_Middle_1', 'R_Middle_2',
+ 'R_Middle_3', 'R_Middle_4', 'R_Ring_1', 'R_Ring_2', 'R_Ring_3',
+ 'R_Ring_4', 'R_Pinky_1', 'R_Pinky_2', 'R_Pinky_3', 'R_Pinky_4',
+ # 72 face keypoints
+ *[
+ f'Face_{i}' for i in range(1, 73)
+ ],
+ )
+
+ self.root_joint_idx = self.joints_name.index('Pelvis')
+ self.lwrist_idx = self.joints_name.index('L_Wrist')
+ self.rwrist_idx = self.joints_name.index('R_Wrist')
+ self.neck_idx = self.joints_name.index('Neck')
# Left/right symmetric index pairs in the 137-joint target set.
+ self.flip_pairs = (
+ # body joints
+ (1, 2), (3, 4), (5, 6), (8, 9), (10, 11), (12, 13), (14, 17),
+ (15, 18), (16, 19), (20, 21), (22, 23),
+ # hand joints
+ (25, 45), (26, 46), (27, 47), (28, 48), (29, 49), (30, 50),
+ (31, 51), (32, 52), (33, 53), (34, 54), (35, 55), (36, 56),
+ (37, 57), (38, 58), (39, 59), (40, 60), (41, 61), (42, 62),
+ (43, 63), (44, 64),
+ # face eyebrow
+ (67, 68), (69, 78), (70, 77), (71, 76), (72, 75), (73, 74),
+ # face below nose
+ (83, 87), (84, 86),
+ # face eyes
+ (88, 97), (89, 96), (90, 95), (91, 94), (92, 99), (93, 98),
+ # face mouse
+ (100, 106), (101, 105), (102, 104), (107, 111), (108, 110),
+ # face lip
+ (112, 116), (113, 115), (117, 119),
+ # face contours
+ (120, 136), (121, 135), (122, 134), (123, 133), (124, 132),
+ (125, 131), (126, 130), (127, 129)
+ )
# Row indices into the SMPL-X layer's output joints that select the
# 137 supervised keypoints, in joints_name order (see the
# ``output.joints[0][..., self.joint_idx, :]`` use in
# process_scene_anno).
+ self.joint_idx = (
+ 0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 60, 61, 62, 63,
+ 64, 65, 59, 58, 57, 56, 55, # body joints
+ 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68, 34, 35, 36, 69, 31,
+ 32, 33, 70, # left hand joints
+ 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45, 73, 49, 50, 51, 74, 46,
+ 47, 48, 75, # right hand joints
+ 22, 15, # jaw, head
+ 57, 56, # eyeballs
+ 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, # eyebrow
+ 86, 87, 88, 89, # nose
+ 90, 91, 92, 93, 94, # below nose
+ 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, # eyes
+ 107, # right mouth
+ 108, 109, 110, 111, 112, # upper mouth
+ 113, # left mouth
+ 114, 115, 116, 117, 118, # lower mouth
+ 119, # right lip
+ 120, 121, 122, # upper lip
+ 123, # left lip
+ 124, 125, 126, # lower lip
+ 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
+ 140, 141, 142, 143, # face contour
+ )
+ # yapf: enable
+
# Contiguous index ranges of each part within the 137-joint set;
# used for the part-wise root-relative normalization below.
+ self.joint_part = {
+ 'body':
+ range(
+ self.joints_name.index('Pelvis'),
+ self.joints_name.index('Nose') + 1),
+ 'lhand':
+ range(
+ self.joints_name.index('L_Thumb_1'),
+ self.joints_name.index('L_Pinky_4') + 1),
+ 'rhand':
+ range(
+ self.joints_name.index('R_Thumb_1'),
+ self.joints_name.index('R_Pinky_4') + 1),
+ 'hand':
+ range(
+ self.joints_name.index('L_Thumb_1'),
+ self.joints_name.index('R_Pinky_4') + 1),
+ 'face':
+ range(
+ self.joints_name.index('Face_1'),
+ self.joints_name.index('Face_72') + 1)
+ }
+
+
def read_annotation_file(annotation_file: str) -> Dict:
    """Load a JSON annotation file and return its deserialized content.

    Args:
        annotation_file: Path to the JSON file to read.

    Returns:
        The deserialized annotation mapping. The UBody SMPL-X annotation
        files are dicts keyed by annotation id (as a string), which is how
        the caller indexes the result (``annos[str(aid)]``).
    """
    # NOTE(review): the original return annotation was ``List[Dict]``, but
    # the sole caller treats the result as a dict with string keys, so the
    # annotation is corrected to ``Dict``.
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)
    return annotations
+
+
def cam2pixel(cam_coord, f, c):
    """Project camera-space points onto the image plane.

    Applies the pinhole model ``u = X / Z * fx + cx`` (and the analogue
    for ``v``), keeping the original depth as the third component.

    Args:
        cam_coord: (N, 3) array of camera-space XYZ coordinates.
        f: Focal lengths ``(fx, fy)``.
        c: Principal point ``(cx, cy)``.

    Returns:
        (N, 3) array of ``(u, v, Z)`` pixel coordinates.
    """
    depth = cam_coord[:, 2]
    u = cam_coord[:, 0] / depth * f[0] + c[0]
    v = cam_coord[:, 1] / depth * f[1] + c[1]
    return np.stack((u, v, depth), 1)
+
+
+def process_scene_anno(scene: str, annotation_root: str, splits: np.ndarray,
+ human_model_path: str):
# Convert one UBody scene's SMPL-X annotations into COCO-style 2D/3D
# keypoint annotations and dump train/val JSON files next to the inputs.
# ``splits`` holds video names assigned to the validation split
# (presumably the intra-scene test list loaded in __main__) — verify
# against the caller.
+ annos = read_annotation_file(
+ osp.join(annotation_root, scene, 'smplx_annotation.json'))
+ keypoint_annos = COCO(
+ osp.join(annotation_root, scene, 'keypoint_annotation.json'))
+ human_model = SMPLX(human_model_path)
+
+ train_annos = []
+ val_annos = []
+ train_imgs = []
+ val_imgs = []
+
# One progress tick per annotation, including the skipped ones below.
+ progress_bar = mmengine.ProgressBar(len(keypoint_annos.anns.keys()))
+ for aid in keypoint_annos.anns.keys():
+ ann = keypoint_annos.anns[aid]
+ img = keypoint_annos.loadImgs(ann['image_id'])[0]
# Normalize absolute-looking file names to relative ones.
+ if img['file_name'].startswith('/'):
+ file_name = img['file_name'][1:]
+ else:
+ file_name = img['file_name']
+
# The parent directory is the video name; strip trimming suffixes so
# the name matches the entries in ``splits``.
+ video_name = file_name.split('/')[-2]
+ if 'Trim' in video_name:
+ video_name = video_name.split('_Trim')[0]
+
# Images live in a parallel tree: .../images/<scene>/<file_name>.
+ img_path = os.path.join(
+ annotation_root.replace('annotations', 'images'), scene, file_name)
+ if not os.path.exists(img_path):
+ progress_bar.update()
+ continue
# Skip frames without a matching SMPL-X parameter entry.
+ if str(aid) not in annos:
+ progress_bar.update()
+ continue
+
+ smplx_param = annos[str(aid)]
+ human_model_param = smplx_param['smplx_param']
+ cam_param = smplx_param['cam_param']
# Older annotation files lack per-part validity flags; fall back to
# the flags from the COCO keypoint annotation.
+ if 'lhand_valid' not in human_model_param:
+ human_model_param['lhand_valid'] = ann['lefthand_valid']
+ human_model_param['rhand_valid'] = ann['righthand_valid']
+ human_model_param['face_valid'] = ann['face_valid']
+
# NOTE(review): rotation_valid is zeroed per-part below but never read
# or exported by this function — confirm whether it should also be
# saved into the output annotation.
+ rotation_valid = np.ones((human_model.orig_joint_num),
+ dtype=np.float32)
+ coord_valid = np.ones((human_model.joint_num), dtype=np.float32)
+
+ root_pose = human_model_param['root_pose']
+ body_pose = human_model_param['body_pose']
+ shape = human_model_param['shape']
+ trans = human_model_param['trans']
+
# For each optional part (left hand, right hand, face): use the
# annotated pose when present and valid, otherwise substitute a zero
# pose and mark the part's joints invalid.
+ if 'lhand_pose' in human_model_param and human_model_param.get(
+ 'lhand_valid', False):
+ lhand_pose = human_model_param['lhand_pose']
+ else:
+ lhand_pose = np.zeros(
+ (3 * len(human_model.orig_joint_part['lhand'])),
+ dtype=np.float32)
+ rotation_valid[human_model.orig_joint_part['lhand']] = 0
+ coord_valid[human_model.orig_joint_part['lhand']] = 0
+
+ if 'rhand_pose' in human_model_param and human_model_param.get(
+ 'rhand_valid', False):
+ rhand_pose = human_model_param['rhand_pose']
+ else:
+ rhand_pose = np.zeros(
+ (3 * len(human_model.orig_joint_part['rhand'])),
+ dtype=np.float32)
+ rotation_valid[human_model.orig_joint_part['rhand']] = 0
+ coord_valid[human_model.orig_joint_part['rhand']] = 0
+
+ if 'jaw_pose' in human_model_param and \
+ 'expr' in human_model_param and \
+ human_model_param.get('face_valid', False):
+ jaw_pose = human_model_param['jaw_pose']
+ expr = human_model_param['expr']
+ else:
+ jaw_pose = np.zeros((3), dtype=np.float32)
+ expr = np.zeros((human_model.expr_code_dim), dtype=np.float32)
+ rotation_valid[human_model.orig_joint_part['face']] = 0
+ coord_valid[human_model.orig_joint_part['face']] = 0
+
+ # init human model inputs
+ device = torch.device(
+ 'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ root_pose = torch.FloatTensor(root_pose).to(device).view(1, 3)
+ body_pose = torch.FloatTensor(body_pose).to(device).view(-1, 3)
+ lhand_pose = torch.FloatTensor(lhand_pose).to(device).view(-1, 3)
+ rhand_pose = torch.FloatTensor(rhand_pose).to(device).view(-1, 3)
+ jaw_pose = torch.FloatTensor(jaw_pose).to(device).view(-1, 3)
+ shape = torch.FloatTensor(shape).to(device).view(1, -1)
+ expr = torch.FloatTensor(expr).to(device).view(1, -1)
+ trans = torch.FloatTensor(trans).to(device).view(1, -1)
+ zero_pose = torch.zeros((1, 3), dtype=torch.float32, device=device)
+
# Forward the SMPL-X layer without gradients to obtain model joints.
+ with torch.no_grad():
+ output = human_model.neutral_model(
+ betas=shape,
+ body_pose=body_pose.view(1, -1),
+ global_orient=root_pose,
+ transl=trans,
+ left_hand_pose=lhand_pose.view(1, -1),
+ right_hand_pose=rhand_pose.view(1, -1),
+ jaw_pose=jaw_pose.view(1, -1),
+ leye_pose=zero_pose,
+ reye_pose=zero_pose,
+ expression=expr)
+
# Select the 137 supervised keypoints and project to pixel space.
+ joint_cam = output.joints[0].cpu().numpy()[human_model.joint_idx, :]
+ joint_img = cam2pixel(joint_cam, cam_param['focal'],
+ cam_param['princpt'])
+
+ joint_cam = (joint_cam - joint_cam[human_model.root_joint_idx, None, :]
+ ) # root-relative
+ joint_cam[human_model.joint_part['lhand'], :] = (
+ joint_cam[human_model.joint_part['lhand'], :] -
+ joint_cam[human_model.lwrist_idx, None, :]
+ ) # left hand root-relative
+ joint_cam[human_model.joint_part['rhand'], :] = (
+ joint_cam[human_model.joint_part['rhand'], :] -
+ joint_cam[human_model.rwrist_idx, None, :]
+ ) # right hand root-relative
+ joint_cam[human_model.joint_part['face'], :] = (
+ joint_cam[human_model.joint_part['face'], :] -
+ joint_cam[human_model.neck_idx, None, :]) # face root-relative
+
# Rescale each part's root-relative depth from the metric range
# [-body_3d_size/2, +body_3d_size/2] into [0, output_hm_shape[0]]
# heatmap z-coordinates.
+ body_3d_size = 2
+ output_hm_shape = (16, 16, 12)
+ joint_img[human_model.joint_part['body'],
+ 2] = ((joint_cam[human_model.joint_part['body'], 2].copy() /
+ (body_3d_size / 2) + 1) / 2.0 * output_hm_shape[0])
+ joint_img[human_model.joint_part['lhand'],
+ 2] = ((joint_cam[human_model.joint_part['lhand'], 2].copy() /
+ (body_3d_size / 2) + 1) / 2.0 * output_hm_shape[0])
+ joint_img[human_model.joint_part['rhand'],
+ 2] = ((joint_cam[human_model.joint_part['rhand'], 2].copy() /
+ (body_3d_size / 2) + 1) / 2.0 * output_hm_shape[0])
+ joint_img[human_model.joint_part['face'],
+ 2] = ((joint_cam[human_model.joint_part['face'], 2].copy() /
+ (body_3d_size / 2) + 1) / 2.0 * output_hm_shape[0])
+
+ keypoints_2d = joint_img[:, :2].copy()
+ keypoints_3d = joint_img.copy()
+ keypoints_valid = coord_valid.reshape((-1, 1))
+
+ ann['keypoints'] = keypoints_2d.tolist()
+ ann['keypoints_3d'] = keypoints_3d.tolist()
+ ann['keypoints_valid'] = keypoints_valid.tolist()
+ ann['camera_param'] = cam_param
+ img['file_name'] = os.path.join(scene, file_name)
# Route the sample by its video name: validation if listed in splits.
+ if video_name in splits:
+ val_annos.append(ann)
+ val_imgs.append(img)
+ else:
+ train_annos.append(ann)
+ train_imgs.append(img)
+ progress_bar.update()
+
+ categories = [{
+ 'supercategory': 'person',
+ 'id': 1,
+ 'name': 'person',
+ 'keypoints': human_model.joints_name,
+ 'skeleton': human_model.flip_pairs
+ }]
+ train_data = {
+ 'images': train_imgs,
+ 'annotations': train_annos,
+ 'categories': categories
+ }
+ val_data = {
+ 'images': val_imgs,
+ 'annotations': val_annos,
+ 'categories': categories
+ }
+
+ mmengine.dump(
+ train_data,
+ osp.join(annotation_root, scene, 'train_3dkeypoint_annotation.json'))
+ mmengine.dump(
+ val_data,
+ osp.join(annotation_root, scene, 'val_3dkeypoint_annotation.json'))
+
+
+if __name__ == '__main__':
# CLI entry point: convert every scene folder under
# <data-root>/annotations, optionally fanning out across processes.
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data-root', type=str, default='data/UBody')
+ parser.add_argument('--human-model-path', type=str, default='data/SMPLX')
+ parser.add_argument(
+ '--nproc', default=8, type=int, help='number of process')
+ args = parser.parse_args()
+
# Video names stored in this .npy file select the validation split
# inside process_scene_anno.
+ split_path = f'{args.data_root}/splits/intra_scene_test_list.npy'
+ annotation_path = f'{args.data_root}/annotations'
+
# Each sub-directory of the annotation root is one scene.
+ folders = os.listdir(annotation_path)
+ folders = [f for f in folders if osp.isdir(osp.join(annotation_path, f))]
+ human_model_path = args.human_model_path
+ splits = np.load(split_path)
+
# One worker per scene when nproc > 1; serial fallback otherwise.
+ if args.nproc > 1:
+ mmengine.track_parallel_progress(
+ partial(
+ process_scene_anno,
+ annotation_root=annotation_path,
+ splits=splits,
+ human_model_path=human_model_path), folders, args.nproc)
+ else:
+ mmengine.track_progress(
+ partial(
+ process_scene_anno,
+ annotation_root=annotation_path,
+ splits=splits,
+ human_model_path=human_model_path), folders)