Commit 7f78d0f

Merge pull request #4 from cerebrai/main
Update factory.py to work for batch of images
developer0hye authored Nov 24, 2024
2 parents dcb645f + 9bd9a2d commit 7f78d0f
Showing 2 changed files with 76 additions and 32 deletions.
104 changes: 74 additions & 30 deletions onepose/models/factory.py
@@ -1,7 +1,11 @@
 import numpy as np
+import cv2
 import torch
 import torch.nn as nn
 import onepose.models.vitpose as vitpose
+from PIL import Image
+from typing import Union, Dict, List  # needed by the new forward() annotations
+
 from onepose.utils import read_cfg, download_weights
 from onepose.transforms import ComposeTransforms, BGR2RGB, TopDownAffine, ToTensor, NormalizeTensor, _box2cs
 from onepose.functional import keypoints_from_heatmaps
@@ -73,12 +76,12 @@
 }
 
 class Model(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                  model_name: str = 'ViTPose_huge_simple_coco') -> None:
         super().__init__()
 
         file_path = pathlib.Path(os.path.abspath(__file__)).parent
 
         self.model_cfg = read_cfg(os.path.join(file_path, 'configs', model_config[model_name]['model_cfg']))
         self.model = vitpose.ViTPose(self.model_cfg.model)
 
@@ -93,53 +96,94 @@ def __init__(self,
         weights_folder = os.path.join(file_path, 'weights')
         os.makedirs(weights_folder, exist_ok=True)
         ckpt = os.path.join(weights_folder, model_config[model_name]['url'].split('/')[-1])
-        download_weights(model_config[model_name]['url'], 
-                         ckpt, 
+        download_weights(model_config[model_name]['url'],
+                         ckpt,
                          model_config[model_name]['hash'])
         self.model.load_state_dict(torch.load(ckpt, map_location='cpu'))
         self.model.eval()
 
         dataset_cfg = read_cfg(os.path.join(file_path.parent, 'datasets', model_config[model_name]['dataset_cfg']))
         self.keypoint_info = dataset_cfg.dataset_info['keypoint_info']
         self.skeleton_info = dataset_cfg.dataset_info['skeleton_info']
 
-    @torch.no_grad()
-    def forward(self, x: np.ndarray) -> np.ndarray:
+    @torch.inference_mode()
+    def forward(self, x: Union[np.ndarray, Image.Image, List[Union[np.ndarray, Image.Image]]]) -> Union[Dict, List[Dict]]:
         if self.training:
             self.eval()
 
         device = next(self.parameters()).device
 
-        img_height, img_width = x.shape[:2]
-        center, scale = _box2cs(self.model_cfg.data_cfg['image_size'], [0, 0, img_width, img_height])
-
-        results = {'img': x,
-                   'rotation': 0,
-                   'center': center,
-                   'scale': scale,
-                   'image_size': np.array(self.model_cfg.data_cfg['image_size']),
-                   }
-
-        results = self.transforms(results)
-        results['img'] = results['img'].to(device)
-
-        out = self.model(results['img'][None, ...])
+        single_image = False
+        # Input validation and conversion
+        if isinstance(x, list):
+            if not x:  # empty list check
+                raise ValueError("Input list cannot be empty")
+            if not all(isinstance(img, (np.ndarray, Image.Image)) for img in x):
+                raise TypeError("All elements in the list must be either numpy arrays or PIL Images")
+            # Convert PIL images to numpy arrays with BGR color space;
+            # convert('RGB') also normalizes grayscale/RGBA PIL modes first
+            x = [cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_RGB2BGR) if isinstance(img, Image.Image) else img for img in x]
+        elif isinstance(x, (np.ndarray, Image.Image)):
+            if isinstance(x, Image.Image):
+                x = cv2.cvtColor(np.array(x.convert('RGB')), cv2.COLOR_RGB2BGR)
+            x = [x]
+            single_image = True
+        else:
+            raise TypeError("Input must be either a numpy array, PIL Image, or a list of them")
+
+        # Convert grayscale numpy images to BGR
+        for i, img in enumerate(x):
+            if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
+                x[i] = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        # Transform each image, recording its center/scale for heatmap decoding
+        batch_results = []
+        centers = []
+        scales = []
+        for img in x:
+            img_height, img_width = img.shape[:2]
+            center, scale = _box2cs(self.model_cfg.data_cfg['image_size'],
+                                    [0, 0, img_width, img_height])
+            centers.append(center)
+            scales.append(scale)
+
+            results = {
+                'img': img,
+                'rotation': 0,
+                'center': center,
+                'scale': scale,
+                'image_size': np.array(self.model_cfg.data_cfg['image_size']),
+            }
+
+            results = self.transforms(results)
+            batch_results.append(results['img'])
+
+        # Stack transformed images into a batch
+        batch_tensor = torch.stack(batch_results).to(device)
+
+        # Forward pass
+        out = self.model(batch_tensor)
         out = out.cpu().numpy()
 
-        out, maxvals = keypoints_from_heatmaps(out,
-                                               center=[center],
-                                               scale=[scale],
-                                               unbiased=False,
-                                               post_process='default',
-                                               kernel=11,
-                                               valid_radius_factor=0.0546875,
-                                               use_udp=self.use_udp,
-                                               target_type='GaussianHeatmap')
-        out = out[0]
-        maxvals = maxvals[0]
-        out = {'points': out, 'confidence': maxvals}
-        return out
+        # Decode heatmaps into keypoint coordinates and confidences
+        points, maxvals = keypoints_from_heatmaps(
+            out,
+            center=centers,
+            scale=scales,
+            unbiased=False,
+            post_process='default',
+            kernel=11,
+            valid_radius_factor=0.0546875,
+            use_udp=self.use_udp,
+            target_type='GaussianHeatmap'
+        )
+
+        outputs = [{'points': p, 'confidence': c} for p, c in zip(points, maxvals)]
+
+        # Return a single result for single-image input
+        if single_image:
+            return outputs[0]
+        return outputs
 
 def create_model(model_name: str = 'ViTPose_huge_simple_coco') -> Model:
     model = Model(model_name=model_name)
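With this change, calling the model is the same for one image or many: a single numpy array or PIL image returns one dict with 'points' and 'confidence', while a list of images returns a list of such dicts. A minimal usage sketch, assuming create_model is re-exported at the package level as onepose.create_model and that 'person1.jpg'/'person2.jpg' are placeholder image paths:

import cv2
import onepose

# Build the default ViTPose model (weights download on first use)
model = onepose.create_model('ViTPose_huge_simple_coco')

# Single image in -> single dict out
img = cv2.imread('person1.jpg')     # BGR numpy array
result = model(img)
print(result['points'].shape)       # per-keypoint (x, y) coordinates
print(result['confidence'].shape)   # per-keypoint confidence scores

# List of images in -> list of dicts out, one per image
imgs = [cv2.imread('person1.jpg'), cv2.imread('person2.jpg')]
results = model(imgs)
for r in results:
    print(r['points'], r['confidence'])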
4 changes: 2 additions & 2 deletions setup.py
@@ -9,6 +9,6 @@
 setup(
     name='onepose',
     version='1.0',
-    install_requires=['opencv-python', 'torch', 'torchvision', 'tqdm', 'numpy'],
+    install_requires=['opencv-python', 'torch', 'torchvision', 'tqdm', 'numpy', 'Pillow'],
     packages=find_packages(exclude='notebooks')
-)
+)
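Pillow joins the dependency list because forward now accepts PIL images directly, converting them to BGR arrays internally. A short sketch of that input path, under the same placeholder-path assumption as above:

from PIL import Image
import onepose

model = onepose.create_model()
pil_img = Image.open('person1.jpg')  # RGB PIL image; converted to BGR internally
out = model(pil_img)                 # single input -> single dict
print(out['points'].shape, out['confidence'].shape)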
