Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

support RLE and binary mask #150

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 91 additions & 32 deletions maskrcnn_benchmark/structures/segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

import numpy as np
from torch.nn.functional import interpolate
import pycocotools.mask as mask_utils

# transpose
Expand All @@ -15,8 +16,36 @@ class Mask(object):
a 2d tensor
"""

def __init__(self, masks, size, mode):
self.masks = masks
def __init__(self, segm, size, mode):
width, height = size
if isinstance(segm, Mask):
mask = segm.mask
else:
if type(segm) == list:
wangg12 marked this conversation as resolved.
Show resolved Hide resolved
# polygons
mask = (
Polygons(segm, size, "polygon")
.convert("mask")
.to(dtype=torch.float32)
)
elif type(segm) == dict and "counts" in segm:
if type(segm["counts"]) == list:
# uncompressed RLE
h, w = segm["size"]
rle = mask_utils.frPyObjects(segm, h, w)
mask = mask_utils.decode(rle)
mask = torch.from_numpy(mask).to(dtype=torch.float32)
else:
# compressed RLE
mask = mask_utils.decode(segm)
mask = torch.from_numpy(mask).to(dtype=torch.float32)
else:
# binary mask
if type(segm) == np.ndarray:
mask = torch.from_numpy(segm).to(dtype=torch.float32)
else: # torch.Tensor
mask = segm.to(dtype=torch.float32)
self.mask = mask
self.size = size
self.mode = mode

Expand All @@ -28,24 +57,45 @@ def transpose(self, method):

width, height = self.size
if method == FLIP_LEFT_RIGHT:
dim = width
idx = 2
max_idx = width
dim = 1
elif method == FLIP_TOP_BOTTOM:
dim = height
idx = 1
max_idx = height
dim = 0

flip_idx = list(range(dim)[::-1])
flipped_masks = self.masks.index_select(dim, flip_idx)
return Mask(flipped_masks, self.size, self.mode)
flip_idx = torch.tensor(list(range(max_idx)[::-1]))
flipped_mask = self.mask.index_select(dim, flip_idx)
return Mask(flipped_mask, self.size, self.mode)

def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1]

cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]]
return Mask(cropped_masks, size=(w, h), mode=self.mode)
box = [round(float(b)) for b in box]
w, h = box[2] - box[0] + 1, box[3] - box[1] + 1
w = max(w, 1)
h = max(h, 1)
cropped_mask = self.mask[box[1] : box[3], box[0] : box[2]]
return Mask(cropped_mask, size=(w, h), mode=self.mode)

def resize(self, size, *args, **kwargs):
pass
width, height = size
scaled_mask = interpolate(
self.mask[None, None, :, :], (height, width), mode="bilinear"
)[0, 0]
return Mask(scaled_mask, size=size, mode=self.mode)

def convert(self, mode):
mask = self.mask.to(dtype=torch.uint8)
return mask

def __iter__(self):
return iter(self.mask)

def __repr__(self):
s = self.__class__.__name__ + "("
# s += "num_mask={}, ".format(len(self.mask))
s += "image_width={}, ".format(self.size[0])
s += "image_height={}, ".format(self.size[1])
s += "mode={})".format(self.mode)
return s


class Polygons(object):
Copy link
Contributor

@botcs botcs Feb 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the mode field for the Polygon is completely irrelevant.
It is never used, but causes:

  • additional argument passing when constructing
  • a headache when trying to find out what Polygon really is

Question:
Wouldn't it be more consistent if the convert method of a Polygon would be renamed to convert_to_mask and would return a Mask instance?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

I agree with your points.

The reason why this is currently the case was that I wanted to keep the same interface between Polygons and Box (which is not implemented, but is the single-box equivalent of BoxList).
And my original idea was that we would be able to specify what was the underlying type of the data via the mode: is it a polygon, or a mask?

I'm not sure about changing the convert name of the method though.

In general, I think both box_list and segmentation_mask could benefit from some better design / cleanup, but I'm not sure what that would be

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And my original idea was that we would be able to specify what was the underlying type of the data via the mode: is it a polygon, or a mask?

I think I cannot follow this part:

To specify what was the underlying type of the data via the mode

As in the current implementation, a Polygon instance:

  1. can be initialized either with a list of polygons
  2. can be initialized either with a Polygon instance (which is referenced now, but should be hard-copied IMO)
  3. cannot be initialized with a Mask, which feature could be added if necessary (I am doing this to convert GTA binary masks to COCO Polygon format, but only because the binary masks are not supported).

So the underlying data would be specified: Polygon.

On the other hand, about the convert function:

I'm not sure about changing the convert name of the method though.

  1. The convert function takes an argument for the target mode, but actually it accepts just a single answer, which is odd.
  2. If I assume that a Polygon can be only convert-ed to a Mask than convert name is OK, but relies on the assumption that the data can be represented either in Polygons or Masks and nothing else, which is not necessary a trivial assumption, so changing the name to convert_to_mask would be clear from the very first encounter.

keep the same interface

  1. We should in this case add the convert or convert_to_polygon method to the Mask class as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those are all reasonable points, and I'm willing to accept PRs that improve the overall consistency and software design of the codebase

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa Thanks, these points were considered for the refactored version, PR #473

Expand Down Expand Up @@ -148,17 +198,26 @@ class SegmentationMask(object):
This class stores the segmentations for all objects in the image
"""

def __init__(self, polygons, size, mode=None):
def __init__(self, segms, size, mode=None):
"""
Arguments:
polygons: a list of list of lists of numbers. The first
segms: three types
(1) polygons: a list of list of lists of numbers. The first
level of the list correspond to individual instances,
the second level to all the polygons that compose the
object, and the third level to the polygon coordinates.
(2) rles: COCO's run length encoding format, uncompressed or compressed
(3) binary masks
size: (width, height)
mode: 'polygon', 'mask'. if mode is 'mask', convert mask of any format to binary mask
"""
assert isinstance(polygons, list)

self.polygons = [Polygons(p, size, mode) for p in polygons]
assert isinstance(segms, list)
if not isinstance(segms[0], (list, Polygons)):
mode = "mask"
if mode == "mask":
self.masks = [Mask(m, size, mode) for m in segms]
else: # polygons
self.masks = [Polygons(p, size, mode) for p in segms]
self.size = size
self.mode = mode

Expand All @@ -169,46 +228,46 @@ def transpose(self, method):
)

flipped = []
for polygon in self.polygons:
flipped.append(polygon.transpose(method))
for mask in self.masks:
flipped.append(mask.transpose(method))
return SegmentationMask(flipped, size=self.size, mode=self.mode)

def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1]
cropped = []
for polygon in self.polygons:
cropped.append(polygon.crop(box))
for mask in self.masks:
cropped.append(mask.crop(box))
return SegmentationMask(cropped, size=(w, h), mode=self.mode)

def resize(self, size, *args, **kwargs):
scaled = []
for polygon in self.polygons:
scaled.append(polygon.resize(size, *args, **kwargs))
for mask in self.masks:
scaled.append(mask.resize(size, *args, **kwargs))
return SegmentationMask(scaled, size=size, mode=self.mode)

def to(self, *args, **kwargs):
return self

def __getitem__(self, item):
if isinstance(item, (int, slice)):
selected_polygons = [self.polygons[item]]
selected_masks = [self.masks[item]]
else:
# advanced indexing on a single dimension
selected_polygons = []
selected_masks = []
if isinstance(item, torch.Tensor) and item.dtype == torch.uint8:
item = item.nonzero()
item = item.squeeze(1) if item.numel() > 0 else item
item = item.tolist()
for i in item:
selected_polygons.append(self.polygons[i])
return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
selected_masks.append(self.masks[i])
return SegmentationMask(selected_masks, size=self.size, mode=self.mode)

def __iter__(self):
return iter(self.polygons)
return iter(self.masks)

def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.polygons))
s += "num_instances={}, ".format(len(self.masks))
s += "image_width={}, ".format(self.size[0])
s += "image_height={})".format(self.size[1])
return s
54 changes: 54 additions & 0 deletions tests/test_segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import numpy as np
import unittest
from maskrcnn_benchmark.structures.segmentation_mask import Mask, Polygons, SegmentationMask


class TestSegmentationMask(unittest.TestCase):
def __init__(self, method_name='runTest'):
super(TestSegmentationMask, self).__init__(method_name)
self.poly = [[423.0, 306.5, 406.5, 277.0, 400.0, 271.5, 389.5, 277.0, 387.5, 292.0,
384.5, 295.0, 374.5, 220.0, 378.5, 210.0, 391.0, 200.5, 404.0, 199.5,
414.0, 203.5, 425.5, 221.0, 438.5, 297.0, 423.0, 306.5],
[385.5, 240.0, 404.0, 234.5, 419.5, 234.0, 416.5, 219.0, 409.0, 209.5,
394.0, 207.5, 385.5, 213.0, 382.5, 221.0, 385.5, 240.0]]
self.width = 640
self.height = 480
self.size = (self.width, self.height)
self.box = [35, 55, 540, 400] # xyxy

self.polygon = Polygons(self.poly, self.size, 'polygon')
self.mask = Mask(self.poly, self.size, 'mask')

def test_crop(self):
poly_crop = self.polygon.crop(self.box)
mask_from_poly_crop = poly_crop.convert('mask')
mask_crop = self.mask.crop(self.box).convert('mask')

self.assertTrue(torch.equal(mask_from_poly_crop, mask_crop))

def test_convert(self):
mask_from_poly_convert = self.polygon.convert('mask')
mask = self.mask.convert('mask')
self.assertTrue(torch.equal(mask_from_poly_convert, mask))

def test_transpose(self):
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
methods = (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM)
for method in methods:
mask_from_poly_flip = self.polygon.transpose(method).convert('mask')
mask_flip = self.mask.transpose(method).convert('mask')
print(method, torch.abs(mask_flip.float() - mask_from_poly_flip.float()).sum())
self.assertTrue(torch.equal(mask_flip, mask_from_poly_flip))

def test_resize(self):
new_size = (600, 500)
mask_from_poly_resize = self.polygon.resize(new_size).convert('mask')
mask_resize = self.mask.resize(new_size).convert('mask')
print('diff resize: ', torch.abs(mask_from_poly_resize.float() - mask_resize.float()).sum())
self.assertTrue(torch.equal(mask_from_poly_resize, mask_resize))

if __name__ == "__main__":
unittest.main()