Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Face Detection Task (task-a-thon) (#606)
Browse files Browse the repository at this point in the history
* .

* merging taskathon PR code

* working

* pep8

* imports

* backbones registry

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests

* tests

* more coverage

* final

* .

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* .

* .

* .

* Update flash/image/face_detection/model.py

Co-authored-by: Sean Naren <sean@grid.ai>

* comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* .

* .

* added comments to clearfy some steps in the face detection task

* imports

* .

* .

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* .

* tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* .

* .

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sean Naren <sean@grid.ai>
Co-authored-by: thomas chaton <thomas@grid.ai>
  • Loading branch information
5 people authored Sep 30, 2021
1 parent 892f759 commit 538c972
Show file tree
Hide file tree
Showing 14 changed files with 544 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

- Added `FastFace` integration ([#606](https://github.com/PyTorchLightning/lightning-flash/pull/606))

- Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

### Changed
Expand Down
14 changes: 14 additions & 0 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os.path
import tarfile
import zipfile
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Type

Expand Down Expand Up @@ -148,10 +149,23 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str):
if os.path.exists(file_path):
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
raise PermissionError(f"Could not extract tar file {file_path}")

if ".zip" in local_filename:
if os.path.exists(local_filename):
with zipfile.ZipFile(local_filename, "r") as zip_ref:
zip_ref.extractall(path)
elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
extract_tarfile(local_filename, path, "r:gz")
elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
extract_tarfile(local_filename, path, "r:bz2")


def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
Expand Down
1 change: 1 addition & 0 deletions flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401
from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401
from flash.image.embedding import ImageEmbedder # noqa: F401
from flash.image.face_detection import FaceDetectionData, FaceDetector # noqa: F401
from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401
from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401
from flash.image.segmentation import ( # noqa: F401
Expand Down
2 changes: 2 additions & 0 deletions flash/image/face_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.image.face_detection.data import FaceDetectionData # noqa: F401
from flash.image.face_detection.model import FaceDetector # noqa: F401
5 changes: 5 additions & 0 deletions flash/image/face_detection/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.face_detection.backbones.fastface_backbones import register_ff_backbones # noqa: F401

FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones")
register_ff_backbones(FACE_DETECTION_BACKBONES)
44 changes: 44 additions & 0 deletions flash/image/face_detection/backbones/fastface_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FASTFACE_AVAILABLE

if _FASTFACE_AVAILABLE:
import fastface as ff

_MODEL_NAMES = ff.list_pretrained_models()
else:
_MODEL_NAMES = []


def fastface_backbone(model_name: str, pretrained: bool, **kwargs):
if pretrained:
pl_model = ff.FaceDetector.from_pretrained(model_name, **kwargs)
else:
arch, config = model_name.split("_")
pl_model = ff.FaceDetector.build(arch, config, **kwargs)

backbone = getattr(pl_model, "arch")

return backbone, pl_model


def register_ff_backbones(register: FlashRegistry):
if _FASTFACE_AVAILABLE:
backbones = [partial(fastface_backbone, model_name=name) for name in _MODEL_NAMES]

for idx, backbone in enumerate(backbones):
register(backbone, name=_MODEL_NAMES[idx])
172 changes: 172 additions & 0 deletions flash/image/face_detection/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource
from flash.image.detection import ObjectDetectionData

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import default_loader

if _FASTFACE_AVAILABLE:
import fastface as ff


def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]:
"""Collate function from fastface.
Organizes individual elements in a batch, calls prepare_batch from fastface and prepares the targets.
"""
samples = {key: [sample[key] for sample in samples] for key in samples[0]}

images, scales, paddings = ff.utils.preprocess.prepare_batch(
samples[DefaultDataKeys.INPUT], None, adaptive_batch=True
)

samples["scales"] = scales
samples["paddings"] = paddings

if DefaultDataKeys.TARGET in samples.keys():
targets = samples[DefaultDataKeys.TARGET]
targets = [{"target_boxes": target["boxes"]} for target in targets]

for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
target["target_boxes"] *= scale
target["target_boxes"][:, [0, 2]] += padding[0]
target["target_boxes"][:, [1, 3]] += padding[1]
targets[i]["target_boxes"] = target["target_boxes"]

samples[DefaultDataKeys.TARGET] = targets
samples[DefaultDataKeys.INPUT] = images

return samples


class FastFaceDataSource(DatasetDataSource):
"""Logic for loading from FDDBDataset."""

def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
new_data = []
for img_file_path, targets in zip(data.ids, data.targets):
new_data.append(
super().load_sample(
(
img_file_path,
dict(
boxes=targets["target_boxes"],
# label `1` indicates positive sample
labels=[1 for _ in range(targets["target_boxes"].shape[0])],
),
)
)
)

return new_data

def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
filepath = sample[DefaultDataKeys.INPUT]
img = default_loader(filepath)
sample[DefaultDataKeys.INPUT] = img

w, h = img.size # WxH
sample[DefaultDataKeys.METADATA] = {
"filepath": filepath,
"size": (h, w),
}

return sample


class FaceDetectionPreprocess(Preprocess):
"""Applies default transform and collate_fn for fastface on FastFaceDataSource."""

def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
image_size: Tuple[int, int] = (128, 128),
):
self.image_size = image_size

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.DATASETS: FastFaceDataSource(),
},
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
return {**self.transforms}

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def default_transforms(self) -> Dict[str, Callable]:
return {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(
DefaultDataKeys.TARGET,
nn.Sequential(
ApplyToKeys("boxes", torch.as_tensor),
ApplyToKeys("labels", torch.as_tensor),
),
),
),
"collate": fastface_collate_fn,
}


class FaceDetectionPostProcess(Postprocess):
"""Generates preds from model output."""

@staticmethod
def per_batch_transform(batch: Any) -> Any:
scales = batch["scales"]
paddings = batch["paddings"]

batch.pop("scales", None)
batch.pop("paddings", None)

preds = batch[DefaultDataKeys.PREDS]

# preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))]
preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
batch[DefaultDataKeys.PREDS] = preds

return batch


class FaceDetectionData(ObjectDetectionData):
preprocess_cls = FaceDetectionPreprocess
postprocess_cls = FaceDetectionPostProcess
Loading

0 comments on commit 538c972

Please sign in to comment.