Skip to content

Commit

Permalink
Input data format (#25464)
Browse files Browse the repository at this point in the history
* Add copied from statements for image processors

* Move out rescale and normalize to base image processor

* Remove rescale and normalize from vit (post rebase)

* Update docstrings and tidy up

* PR comments

* Add input_data_format as preprocess argument

* Resolve tests and tidy up

* Remove num_channels argument

* Update doc strings -> default ints not in code formatting
  • Loading branch information
amyeroberts authored Aug 16, 2023
1 parent a6609ca commit 6bca43b
Show file tree
Hide file tree
Showing 55 changed files with 3,054 additions and 591 deletions.
43 changes: 38 additions & 5 deletions src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,12 @@ def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")

def rescale(
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
self,
image: np.ndarray,
scale: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Expand All @@ -536,18 +541,24 @@ def rescale(
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -565,17 +576,25 @@ def normalize(
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
return normalize(
image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
)

def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -588,12 +607,26 @@ def center_crop(
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
return center_crop(
image,
size=(size["height"], size["width"]),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)


VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
Expand Down
44 changes: 36 additions & 8 deletions src/transformers/models/beit/image_processing_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
Expand Down Expand Up @@ -145,6 +146,7 @@ def resize(
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -159,12 +161,19 @@ def resize(
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=True, param_name="size")
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
image,
size=(size["height"], size["width"]),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def reduce_label(self, label: ImageInput) -> np.ndarray:
Expand All @@ -189,21 +198,22 @@ def _preprocess(
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)

if do_resize:
image = self.resize(image=image, size=size, resample=resample)
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

if do_center_crop:
image = self.center_crop(image=image, size=crop_size)
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

return image

Expand All @@ -221,10 +231,13 @@ def _preprocess_image(
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image,
do_reduce_labels=False,
Expand All @@ -238,9 +251,10 @@ def _preprocess_image(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image

def _preprocess_segmentation_map(
Expand All @@ -252,6 +266,7 @@ def _preprocess_segmentation_map(
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_reduce_labels: bool = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""Preprocesses a single segmentation map."""
# All transformations expect numpy arrays.
Expand All @@ -260,8 +275,11 @@ def _preprocess_segmentation_map(
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
added_dimension = True
input_data_format = ChannelDimension.FIRST
else:
added_dimension = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
Expand All @@ -272,6 +290,7 @@ def _preprocess_segmentation_map(
crop_size=crop_size,
do_normalize=False,
do_rescale=False,
input_data_format=ChannelDimension.FIRST,
)
# Remove extra axis if added
if added_dimension:
Expand Down Expand Up @@ -301,6 +320,7 @@ def preprocess(
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Expand Down Expand Up @@ -344,8 +364,15 @@ def preprocess(
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
Expand Down Expand Up @@ -403,6 +430,7 @@ def preprocess(
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
]
Expand Down
57 changes: 47 additions & 10 deletions src/transformers/models/bit/image_processing_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
Expand Down Expand Up @@ -125,6 +126,7 @@ def resize(
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -140,12 +142,23 @@ def resize(
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
output_size = get_resize_output_image_size(
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
self,
Expand All @@ -163,6 +176,7 @@ def preprocess(
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Expand Down Expand Up @@ -205,9 +219,15 @@ def preprocess(
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: defaults to the channel dimension format of the input image.
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
Expand Down Expand Up @@ -250,19 +270,36 @@ def preprocess(
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
Loading

0 comments on commit 6bca43b

Please sign in to comment.