Skip to content

Commit

Permalink
add photo clustering demo (PaddlePaddle#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Dec 22, 2022
1 parent 5e3c5f8 commit 2204308
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 26 deletions.
45 changes: 19 additions & 26 deletions plsc/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,12 @@
import numpy as np


def default_preprocess_fn(img,
                          scale=1.0 / 255.0,
                          mean=0.5,
                          std=0.5,
                          swap_rb=True):
    """Normalize an HWC image and convert it to a batched CHW float array.

    Args:
        img: numpy array indexed as (height, width, channel).
        scale: multiplier applied to the pixel values before centring.
        mean: value subtracted after scaling; anything that broadcasts
            against the HWC array is accepted.
        std: divisor applied after ``mean`` has been subtracted.
        swap_rb: when True, reverse the order of the channel axis
            (e.g. BGR <-> RGB).

    Returns:
        A C-contiguous array of shape (1, C, H, W).
    """
    img = (img.astype('float32') * scale - mean) / std
    if swap_rb:
        img = img[:, :, ::-1]
    img = np.expand_dims(img.transpose((2, 0, 1)), 0)
    # The channel flip and the transpose only create strided views (the flip
    # even has a negative stride). Materialise a contiguous buffer so that
    # consumers which copy raw host memory see the elements in NCHW order.
    return np.ascontiguousarray(img)


class Inference(object):
class Predictor(object):
def __init__(self,
model_type='paddle',
model_file=None,
params_file=None,
preprocess_fn=default_preprocess_fn,
preprocess_fn=None,
postprocess_fn=None):

assert model_type in ['paddle', 'onnx']
Expand All @@ -51,12 +38,8 @@ def __init__(self,
config = paddle_infer.Config(model_file, params_file)
self.predictor = paddle_infer.create_predictor(config)

input_names = self.predictor.get_input_names()
self.input_handle = self.predictor.get_input_handle(input_names[0])

output_names = self.predictor.get_output_names()
self.output_handle = self.predictor.get_output_handle(output_names[
0])
self.input_names = self.predictor.get_input_names()
self.output_names = self.predictor.get_output_names()

elif model_type == 'onnx':
assert model_file is not None and os.path.splitext(model_file)[
Expand All @@ -81,17 +64,27 @@ def __init__(self,
def predict(self, img):
    """Run one forward pass and return the (optionally post-processed) outputs.

    Args:
        img: raw input. It is handed to ``self.preprocess_fn`` when one is
            set; otherwise it is used as the model input directly.

    Returns:
        ``self.postprocess_fn(outputs)`` when a post-process function is set,
        otherwise ``outputs``: for the 'paddle' backend a list with one array
        per entry of ``self.output_names``; for the 'onnx' backend whatever
        ``self.predictor.run`` returns.

    Raises:
        TypeError: a bare (non-dict) input was given to a 'paddle' model
            that has more than one input.
        ValueError: ``self.model_type`` is neither 'paddle' nor 'onnx'.
    """
    inputs = img if self.preprocess_fn is None else self.preprocess_fn(img)

    if self.model_type == 'paddle':
        if not isinstance(inputs, dict):
            # Keep the single-array calling convention working: a bare array
            # can only be routed unambiguously when the model has one input.
            if len(self.input_names) != 1:
                raise TypeError(
                    'inputs must be a dict keyed by input name for a model '
                    'with {} inputs'.format(len(self.input_names)))
            inputs = {self.input_names[0]: inputs}

        for input_name in self.input_names:
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[input_name])
        self.predictor.run()
        outputs = [
            self.predictor.get_output_handle(output_name).copy_to_cpu()
            for output_name in self.output_names
        ]

    elif self.model_type == 'onnx':
        if not isinstance(inputs, dict):
            # NOTE(review): assumes __init__ still records the single ONNX
            # input name as ``self.input_name`` (the pre-refactor code read
            # it here) -- confirm against the constructor.
            inputs = {self.input_name: inputs}
        outputs = self.predictor.run(None, inputs)

    else:
        raise ValueError(
            'unsupported model_type: {!r}'.format(self.model_type))

    if self.postprocess_fn is not None:
        return self.postprocess_fn(outputs)
    return outputs
266 changes: 266 additions & 0 deletions task/recognition/face/photo_clustering.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "8162093e-99cb-4646-93be-9e3a63eecc84",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import shutil\n",
"import os\n",
"import glob\n",
"import numpy as np\n",
"import cv2\n",
"from functools import partial\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from io import BytesIO\n",
"import IPython\n",
"from sklearn.cluster import DBSCAN\n",
"\n",
"from plsc.engine.inference import Predictor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cd617b7-4ace-4d03-ac67-ed6c65b96e00",
"metadata": {},
"outputs": [],
"source": [
"# Download models and assets\n",
"!mkdir -p models\n",
"if not os.path.exists('models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel'):\n",
" !wget https://paddle-model-ecology.bj.bcebos.com/model/insight-face/blazeface_fpn_ssh_1000e_v1.0_infer.tar -P models/\n",
" !tar -xzf models/blazeface_fpn_ssh_1000e_v1.0_infer.tar -C models/\n",
" !rm -rf models/blazeface_fpn_ssh_1000e_v1.0_infer.tar\n",
" \n",
"if not os.path.exists('models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel'):\n",
" !wget https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_infer.tgz -P models/\n",
" !tar -xzf models/FaceViT_tiny_patch9_112_infer.tgz -C models/\n",
" !rm -rf models/FaceViT_tiny_patch9_112_infer.tgz\n",
" \n",
"if not os.path.exists('images'):\n",
" !mkdir -p images\n",
" !wget https://plsc.bj.bcebos.com/dataset/BigBang.tgz -P images\n",
" !tar -xzf images/BigBang.tgz --strip-components 1 -C images\n",
" !rm -rf images/BigBang.tgz"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee26a767-0870-4059-86c4-d855f38434fe",
"metadata": {},
"outputs": [],
"source": [
"def draw(img, box_list):\n",
" im = Image.fromarray(img)\n",
" draw = ImageDraw.Draw(im)\n",
"\n",
" for i, dt in enumerate(box_list):\n",
" bbox, score = dt[2:], dt[1]\n",
" color = 'red'\n",
"\n",
" xmin, ymin, xmax, ymax = bbox\n",
" draw.rectangle(\n",
" [(xmin, ymin), (xmax, ymax)], width=2, outline=color)\n",
" return im\n",
"\n",
"def display_img_array(img):\n",
" bio = BytesIO()\n",
" img.save(bio, format='png')\n",
" IPython.display.display(IPython.display.Image(bio.getvalue(), format='png'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce6ef4da-d1f7-4f3a-9a25-52b29cf7aa39",
"metadata": {},
"outputs": [],
"source": [
"def facedetect_preprocess_fn(img, target_size=[640, 640]):\n",
" resize_h, resize_w = target_size\n",
" img_shape = img.shape\n",
" img_scale_x = resize_w / img_shape[1]\n",
" img_scale_y = resize_h / img_shape[0]\n",
" img = cv2.resize(\n",
" img, None, None, fx=img_scale_x, fy=img_scale_y, interpolation=1)\n",
" \n",
" scale = 1. / 255.\n",
" mean = np.array([[[0.485, 0.456, 0.406]]])\n",
" std = np.array([[[0.229, 0.224, 0.225]]])\n",
"\n",
" img = (img.astype('float32') * scale - mean) / std\n",
" img_info = {}\n",
" img_info[\"im_shape\"] = np.array(\n",
" img.shape[:2], dtype=np.float32)[np.newaxis, :]\n",
" img_info[\"scale_factor\"] = np.array(\n",
" [img_scale_y, img_scale_x], dtype=np.float32)[np.newaxis, :]\n",
"\n",
" img = img.transpose((2, 0, 1)).copy()\n",
" img_info[\"image\"] = img[np.newaxis, :, :, :].astype(np.float32)\n",
" return img_info\n",
"\n",
"def facedetect_postprocess_fn(outputs, thresh=0.8):\n",
" np_boxes = outputs[0]\n",
" expect_boxes = (np_boxes[:, 1] > thresh) & (np_boxes[:, 0] > -1)\n",
" return np_boxes[expect_boxes, :]\n",
"\n",
"face_detector = Predictor(\n",
" model_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel',\n",
" params_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdiparams',\n",
" preprocess_fn=facedetect_preprocess_fn,\n",
" postprocess_fn=facedetect_postprocess_fn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94d384f8-da7b-4dc8-8bd0-2e6ce9c63b01",
"metadata": {},
"outputs": [],
"source": [
"def facerecog_preprocess_fn(img):\n",
" scale = 1.0 / 255.0\n",
" mean = 0.5\n",
" std = 0.5\n",
" img = (img.astype('float32') * scale - mean) / std\n",
" img = img[:, :, ::-1]\n",
" img = img.transpose((0, 3, 1, 2))\n",
"\n",
" return {'inputs': img}\n",
"\n",
"def crop_face(img, box_list):\n",
" batch = []\n",
" for idx, box in enumerate(box_list):\n",
" box[box < 0] = 0\n",
" xmin, ymin, xmax, ymax = list(map(int, box[2:]))\n",
" w = xmax - xmin + 1\n",
" h = ymax - ymin + 1\n",
" radius = int(round(max(h, w) / 2.0))\n",
" cx = int(round((xmax + xmin) / 2.0))\n",
" cy = int(round((ymax + ymin) / 2.0))\n",
" xmin = cx - radius\n",
" xmax = cx + radius\n",
" ymin = cy - radius\n",
" ymax = cy + radius\n",
" \n",
" face_img = img[ymin:ymax, xmin:xmax, :]\n",
" face_img = cv2.resize(face_img, (112, 112)).copy()\n",
" batch.append(face_img)\n",
" return np.stack(batch)\n",
"\n",
"face_recog = Predictor(\n",
" model_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel',\n",
" params_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdiparams',\n",
" preprocess_fn=facerecog_preprocess_fn,\n",
" postprocess_fn=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eeea0e11-c565-44ee-a10c-395c8c59097c",
"metadata": {},
"outputs": [],
"source": [
"feats_list = []\n",
"fileid_list = []\n",
"boxes_list = []\n",
"\n",
"filenames = glob.glob('images/*.png')\n",
"for idx, filename in enumerate(filenames):\n",
" img = cv2.imread(filename)\n",
" boxes = face_detector.predict(img)\n",
"\n",
" faces = crop_face(img, boxes)\n",
" feats = face_recog.predict(faces)\n",
" \n",
" feats_list.append(feats[0])\n",
" fileid = np.empty(faces.shape[0], dtype=np.int32)\n",
" fileid.fill(idx)\n",
" fileid_list.append(fileid)\n",
" boxes_list.append(boxes)\n",
" \n",
"face_feat = np.concatenate(feats_list, axis=0)\n",
"face_file = np.concatenate(fileid_list, axis=0)\n",
"face_boxes = np.concatenate(boxes_list, axis=0)\n",
"\n",
"X = face_feat / np.linalg.norm(face_feat, axis=-1, keepdims=True)\n",
"\n",
"db = DBSCAN(eps=0.5, min_samples=2, metric=\"cosine\").fit(X)  # DBSCAN's default metric is Euclidean; cosine distance is used here\n",
"core_samples_mask = np.zeros_like(db.labels_, dtype=bool)\n",
"core_samples_mask[db.core_sample_indices_] = True\n",
"labels = db.labels_"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2c8f5a7-73a6-4341-b794-f11f35960ec6",
"metadata": {},
"outputs": [],
"source": [
"show_image = True\n",
"copy_image = False\n",
"\n",
"clusters = set(labels)\n",
"output_root = 'clusters'\n",
"for clusters_id in clusters:\n",
" # noise cluster\n",
" # if int(clusters_id) == -1:\n",
" # continue\n",
" face_idx = np.where(labels == clusters_id)\n",
" \n",
" sel_fileids = face_file[face_idx]\n",
" sel_boxes = face_boxes[face_idx]\n",
" print()\n",
" print('='*20, f'face id {clusters_id}', '='*20)\n",
" for idx in range(sel_fileids.shape[0]):\n",
" filename = filenames[sel_fileids[idx]]\n",
" img = cv2.imread(filename)\n",
" img_drawed = draw(img[:,:,::-1], [sel_boxes[idx]])\n",
" \n",
" if show_image:\n",
" display_img_array(img_drawed)\n",
"\n",
" if copy_image:\n",
" output_dir = os.path.join(output_root, str(clusters_id))\n",
" if not os.path.exists(output_dir):\n",
" os.makedirs(output_dir)\n",
" shutil.copyfile(filename, os.path.join(output_dir, filename.split('/')[-1]))\n",
"\n",
" if idx == 0:\n",
" cropped = crop_face(img, [sel_boxes[idx]])[0]\n",
" cv2.imwrite(os.path.join(output_dir, 'thumbnail.png'), cropped)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 2204308

Please sign in to comment.