Skip to content

Commit

Permalink
Merge pull request #57 from BloodAxe/develop
Browse files Browse the repository at this point in the history
PyTorch Toolbelt 0.4.3
  • Loading branch information
BloodAxe authored Apr 2, 2021
2 parents a04e28b + d8e2a30 commit f3acfca
Show file tree
Hide file tree
Showing 24 changed files with 700 additions and 536 deletions.
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.4.2"
__version__ = "0.4.3"
1 change: 1 addition & 0 deletions pytorch_toolbelt/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .common import *
from .classification import *
from .segmentation import *
from .wrappers import *
32 changes: 14 additions & 18 deletions pytorch_toolbelt/datasets/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import cv2

__all__ = [
"IGNORE_LABEL",
"INPUT_IMAGE_ID_KEY",
"INPUT_IMAGE_KEY",
"INPUT_INDEX_KEY",
"INPUT_MASK_16_KEY",
"INPUT_MASK_32_KEY",
"INPUT_MASK_4_KEY",
"INPUT_MASK_64_KEY",
"INPUT_MASK_8_KEY",
"OUTPUT_EMBEDDINGS_KEY",
"OUTPUT_LOGITS_KEY",
"OUTPUT_MASK_16_KEY",
Expand All @@ -19,15 +15,19 @@
"OUTPUT_MASK_KEY",
"TARGET_CLASS_KEY",
"TARGET_LABELS_KEY",
"TARGET_MASK_16_KEY",
"TARGET_MASK_2_KEY",
"TARGET_MASK_32_KEY",
"TARGET_MASK_4_KEY",
"TARGET_MASK_64_KEY",
"TARGET_MASK_8_KEY",
"TARGET_MASK_KEY",
"TARGET_MASK_WEIGHT_KEY",
"UNLABELED_SAMPLE",
"name_for_stride",
"read_image_rgb",
]

# Smaller masks for deep supervision

def name_for_stride(name, stride: int):
return f"{name}_{stride}"

Expand All @@ -36,18 +36,17 @@ def name_for_stride(name, stride: int):
INPUT_IMAGE_KEY = "image"
INPUT_IMAGE_ID_KEY = "image_id"

TARGET_MASK_KEY = "true_mask"
TARGET_MASK_WEIGHT_KEY = "true_weights"
TARGET_CLASS_KEY = "true_class"
TARGET_LABELS_KEY = "true_labels"


TARGET_MASK_KEY = "true_mask"
TARGET_MASK_2_KEY = name_for_stride(TARGET_MASK_KEY, 2)
INPUT_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
INPUT_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
INPUT_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
INPUT_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
INPUT_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)
TARGET_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
TARGET_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
TARGET_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
TARGET_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
TARGET_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)

OUTPUT_MASK_KEY = "pred_mask"
OUTPUT_MASK_2_KEY = name_for_stride(OUTPUT_MASK_KEY, 2)
Expand All @@ -60,9 +59,6 @@ def name_for_stride(name, stride: int):
OUTPUT_LOGITS_KEY = "pred_logits"
OUTPUT_EMBEDDINGS_KEY = "pred_embeddings"

UNLABELED_SAMPLE = 127
IGNORE_LABEL = 255


def read_image_rgb(fname: str):
image = cv2.imread(fname)[..., ::-1]
Expand Down
63 changes: 36 additions & 27 deletions pytorch_toolbelt/datasets/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
TARGET_MASK_WEIGHT_KEY,
TARGET_MASK_KEY,
name_for_stride,
UNLABELED_SAMPLE,
)
from ..utils import fs, image_to_tensor

__all__ = ["mask_to_bce_target", "mask_to_ce_target", "SegmentationDataset", "compute_weight_mask"]
__all__ = ["mask_to_bce_target", "mask_to_ce_target", "read_binary_mask", "SegmentationDataset", "compute_weight_mask"]


def mask_to_bce_target(mask):
Expand Down Expand Up @@ -62,8 +61,21 @@ def _block_reduce_dominant_label(x: np.ndarray, axis):


def read_binary_mask(mask_fname: str) -> np.ndarray:
mask = cv2.imread(mask_fname, cv2.IMREAD_COLOR)
return cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY, dst=mask)
"""
Read image as binary mask, all non-zero values are treated as positive labels and converted to 1
Args:
mask_fname: Image with mask
Returns:
Numpy array with {0,1} values
"""

mask = cv2.imread(mask_fname, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise FileNotFoundError(f"Cannot find {mask_fname}")

cv2.threshold(mask, thresh=0, maxval=1, type=cv2.THRESH_BINARY, dst=mask)
return mask


class SegmentationDataset(Dataset):
Expand All @@ -81,11 +93,16 @@ def __init__(
need_weight_mask=False,
need_supervision_masks=False,
make_mask_target_fn: Callable = mask_to_ce_target,
image_ids: Optional[List[str]] = None,
):
if mask_filenames is not None and len(image_filenames) != len(mask_filenames):
raise ValueError("Number of images does not corresponds to number of targets")

self.image_ids = [fs.id_from_fname(fname) for fname in image_filenames]
if image_ids is None:
self.image_ids = [fs.id_from_fname(fname) for fname in image_filenames]
else:
self.image_ids = image_ids

self.need_weight_mask = need_weight_mask
self.need_supervision_masks = need_supervision_masks

Expand All @@ -100,39 +117,31 @@ def __init__(
def __len__(self):
return len(self.images)

def set_target(self, index: int, value: np.ndarray):
mask_fname = self.masks[index]

value = (value * 255).astype(np.uint8)
cv2.imwrite(mask_fname, value)

def __getitem__(self, index):
image = self.read_image(self.images[index])

data = {"image": image}
if self.masks is not None:
mask = self.read_mask(self.masks[index])
else:
mask = np.ones((image.shape[0], image.shape[1], 1), dtype=np.uint8) * UNLABELED_SAMPLE
data["mask"] = self.read_mask(self.masks[index])

data = self.transform(image=image, mask=mask)
data = self.transform(**data)

image = data["image"]
mask = data["mask"]

sample = {
INPUT_INDEX_KEY: index,
INPUT_IMAGE_ID_KEY: self.image_ids[index],
INPUT_IMAGE_KEY: image_to_tensor(image),
TARGET_MASK_KEY: self.make_target(mask),
}

if self.need_weight_mask:
sample[TARGET_MASK_WEIGHT_KEY] = image_to_tensor(compute_weight_mask(mask)).float()

if self.need_supervision_masks:
for i in range(1, 5):
stride = 2 ** i
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)
if self.masks is not None:
mask = data["mask"]
sample[TARGET_MASK_KEY] = self.make_target(mask)
if self.need_weight_mask:
sample[TARGET_MASK_WEIGHT_KEY] = image_to_tensor(compute_weight_mask(mask)).float()

if self.need_supervision_masks:
for i in range(1, 6):
stride = 2 ** i
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)

return sample
55 changes: 55 additions & 0 deletions pytorch_toolbelt/datasets/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import random
from typing import Any

from torch.utils.data import Dataset
import numpy as np

__all__ = ["RandomSubsetDataset", "RandomSubsetWithMaskDataset"]


class RandomSubsetDataset(Dataset):
"""
Wrapper to get desired number of samples from underlying dataset
"""

def __init__(self, dataset, num_samples: int):
self.dataset = dataset
self.num_samples = num_samples

def __len__(self) -> int:
return self.num_samples

def __getitem__(self, _) -> Any:
index = random.randrange(len(self.dataset))
return self.dataset[index]


class RandomSubsetWithMaskDataset(Dataset):
"""
Wrapper to get desired number of samples from underlying dataset only considering
samples P for which mask[P] equals True
"""

def __init__(self, dataset: Dataset, mask: np.ndarray, num_samples: int):
if (
not isinstance(mask, np.ndarray)
or mask.dtype != np.bool
or len(mask.shape) != 1
or len(mask) != len(dataset)
):
raise ValueError("Mask must be boolean 1-D numpy array")

if not mask.any():
raise ValueError("Mask must have at least one positive value")

self.dataset = dataset
self.mask = mask
self.num_samples = num_samples
self.indexes = np.flatnonzero(self.mask)

def __len__(self) -> int:
return self.num_samples

def __getitem__(self, _) -> Any:
index = random.choice(self.indexes)
return self.dataset[index]
34 changes: 16 additions & 18 deletions pytorch_toolbelt/inference/ensembling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch
from torch import nn, Tensor
from typing import List, Union
from typing import List, Union, Iterable, Optional

__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput"]

from pytorch_toolbelt.inference.tta import _deaugment_averaging


class ApplySoftmaxTo(nn.Module):
def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", dim=1, temperature=1):
Expand Down Expand Up @@ -55,40 +58,35 @@ class Ensembler(nn.Module):
Compute sum (or average) of outputs of several models.
"""

def __init__(self, models: List[nn.Module], average=True, outputs=None):
def __init__(self, models: List[nn.Module], reduction: str = "mean", outputs: Optional[Iterable[str]] = None):
"""
:param models:
:param average:
:param reduction: Reduction key ('mean', 'sum', 'gmean', 'hmean' or None)
:param outputs: Name of model outputs to average and return from Ensembler.
If None, all outputs from the first model will be used.
"""
super().__init__()
self.outputs = outputs
self.models = nn.ModuleList(models)
self.average = average
self.reduction = reduction

def forward(self, *input, **kwargs): # skipcq: PYL-W0221
output_0 = self.models[0](*input, **kwargs)
num_models = len(self.models)
outputs = [model(*input, **kwargs) for model in self.models]

if self.outputs:
keys = self.outputs
else:
keys = output_0.keys()

for index in range(1, num_models):
output_i = self.models[index](*input, **kwargs)

# Sum outputs
for key in keys:
output_0[key].add_(output_i[key])
keys = outputs[0].keys()

if self.average:
for key in keys:
output_0[key].mul_(1.0 / num_models)
averaged_output = {}
for key in keys:
predictions = [output[key] for output in outputs]
predictions = torch.stack(predictions)
predictions = _deaugment_averaging(predictions, self.reduction)
averaged_output[key] = predictions

return output_0
return averaged_output


class PickModelOutput(nn.Module):
Expand Down
35 changes: 34 additions & 1 deletion pytorch_toolbelt/inference/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Sized, Iterable
from collections.abc import Sized, Iterable
from typing import Union, Tuple

import torch
Expand All @@ -25,6 +25,8 @@
"pad_image_tensor",
"unpad_image_tensor",
"unpad_xyxy_bboxes",
"geometric_mean",
"harmonic_mean",
]


Expand Down Expand Up @@ -205,3 +207,34 @@ def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
pad = pad.unsqueeze(dim)

return bboxes_tensor - pad


def geometric_mean(x: Tensor, dim: int) -> Tensor:
"""
Compute geometric mean along given dimension.
This implementation assume values are in range (0...1) (Probabilities)
Args:
x: Input tensor of arbitrary shape
dim: Dimension to reduce
Returns:
Tensor
"""
return x.log().mean(dim=dim).exp()


def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
"""
Compute harmonic mean along given dimension.
This implementation assume values are in range (0...1) (Probabilities)
Args:
x: Input tensor of arbitrary shape
dim: Dimension to reduce
Returns:
Tensor
"""
x = torch.reciprocal(x.clamp_min(eps))
x = torch.mean(x, dim=dim)
x = torch.reciprocal(x.clamp_min(eps))
return x
Loading

0 comments on commit f3acfca

Please sign in to comment.