Input data format #25464

Merged
merged 9 commits on Aug 16, 2023
43 changes: 38 additions & 5 deletions src/transformers/image_processing_utils.py
@@ -521,7 +521,12 @@ def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")

def rescale(
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
self,
image: np.ndarray,
scale: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
@@ -536,18 +541,24 @@ def rescale(
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

Returns:
`np.ndarray`: The rescaled image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@@ -565,17 +576,25 @@ def normalize(
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

Returns:
`np.ndarray`: The normalized image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
return normalize(
image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
)

def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@@ -588,12 +607,26 @@ def center_crop(
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
return center_crop(
image,
size=(size["height"], size["width"]),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)


VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
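For context, a minimal sketch of how the new `input_data_format` argument on the base `rescale`, `normalize`, and `center_crop` methods above could be used; the processor choice and values here are illustrative assumptions, not part of this diff:

```python
import numpy as np
from transformers import ViTImageProcessor  # any processor inheriting the base methods above
from transformers.image_utils import ChannelDimension

processor = ViTImageProcessor()

# A channels-last uint8 image; passing input_data_format states the layout explicitly
# instead of relying on shape-based inference.
image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

image = processor.rescale(image, scale=1 / 255, input_data_format=ChannelDimension.LAST)
image = processor.normalize(
    image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], input_data_format=ChannelDimension.LAST
)
image = processor.center_crop(
    image, size={"height": 200, "width": 200}, input_data_format=ChannelDimension.LAST
)
print(image.shape)  # (200, 200, 3): output layout is unchanged unless data_format is passed
```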
44 changes: 36 additions & 8 deletions src/transformers/models/beit/image_processing_beit.py
@@ -27,6 +27,7 @@
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@@ -145,6 +146,7 @@ def resize(
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@@ -159,12 +161,19 @@
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=True, param_name="size")
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
image,
size=(size["height"], size["width"]),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def reduce_label(self, label: ImageInput) -> np.ndarray:
@@ -189,21 +198,22 @@ def _preprocess(
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)

if do_resize:
image = self.resize(image=image, size=size, resample=resample)
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

if do_center_crop:
image = self.center_crop(image=image, size=crop_size)
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

return image

@@ -221,10 +231,13 @@ def _preprocess_image(
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image,
do_reduce_labels=False,
@@ -238,9 +251,10 @@
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
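The pattern above (infer the input layout once, thread it through every transform, and only convert at the very end) can be sketched roughly as follows, assuming the documented behaviour of the two utilities:

```python
import numpy as np
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension, infer_channel_dimension_format

image = np.zeros((224, 224, 3), dtype=np.float32)  # channels-last input

# Infer once per image instead of re-inferring inside every transform.
input_data_format = infer_channel_dimension_format(image)  # ChannelDimension.LAST here

# ... resize / center_crop / rescale / normalize would run here, each taking input_data_format ...

# Convert to the requested output layout only at the end.
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
print(image.shape)  # (3, 224, 224)
```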

def _preprocess_segmentation_map(
@@ -252,6 +266,7 @@ def _preprocess_segmentation_map(
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_reduce_labels: bool = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""Preprocesses a single segmentation map."""
# All transformations expect numpy arrays.
@@ -260,8 +275,11 @@
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
added_dimension = True
input_data_format = ChannelDimension.FIRST
else:
added_dimension = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
@@ -272,6 +290,7 @@
crop_size=crop_size,
do_normalize=False,
do_rescale=False,
input_data_format=ChannelDimension.FIRST,
)
# Remove extra axis if added
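As an aside, the dimension bookkeeping above amounts to the following numpy sketch (illustrative only, not code from the PR): a 2D map temporarily gains a channel axis so the channels-first transforms apply, and the axis is dropped again afterwards.

```python
import numpy as np

segmentation_map = np.zeros((480, 640), dtype=np.uint8)  # (height, width), no channel axis

if segmentation_map.ndim == 2:
    # Add a leading axis so the map can be treated as a channels-first, single-channel image.
    segmentation_map = segmentation_map[None, ...]  # -> (1, 480, 640)
    added_dimension = True
else:
    added_dimension = False

# ... resize / center_crop run here with input_data_format=ChannelDimension.FIRST ...

if added_dimension:
    segmentation_map = np.squeeze(segmentation_map, axis=0)  # back to (480, 640)
```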
if added_dimension:
Expand Down Expand Up @@ -301,6 +320,7 @@ def preprocess(
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@@ -344,8 +364,15 @@ def preprocess(
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@@ -403,6 +430,7 @@ def preprocess(
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
]
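End to end, the new argument could be exercised on this processor roughly as follows; the shapes and defaults are assumptions for illustration, not asserted by the diff:

```python
import numpy as np
from transformers import BeitImageProcessor

processor = BeitImageProcessor()

# State the layout of the incoming array explicitly rather than letting it be inferred.
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
batch = processor.preprocess(image, input_data_format="channels_last", return_tensors="np")

# data_format defaults to channels-first for the output batch; the spatial size
# depends on the processor's configured size / crop_size.
print(batch["pixel_values"].shape)  # e.g. (1, 3, H, W)
```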
57 changes: 47 additions & 10 deletions src/transformers/models/bit/image_processing_bit.py
@@ -31,6 +31,7 @@
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@@ -125,6 +126,7 @@ def resize(
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@@ -140,12 +142,23 @@
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
output_size = get_resize_output_image_size(
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
self,
@@ -163,6 +176,7 @@ def preprocess(
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@@ -205,9 +219,15 @@ def preprocess(
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: defaults to the channel dimension format of the input image.
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@@ -250,19 +270,36 @@ def preprocess(
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
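Since the format is inferred from the first image only (see the comment above), passing `input_data_format` explicitly is the safe way to handle batches whose layout could be misread; a small sketch under those assumptions:

```python
import numpy as np
from transformers import BitImageProcessor

processor = BitImageProcessor()

# Two channels-last frames; inference would only look at frames[0], so we state the
# layout for the whole batch explicitly.
frames = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(2)]
batch = processor.preprocess(frames, input_data_format="channels_last", return_tensors="np")
print(batch["pixel_values"].shape)  # (2, 3, H, W) after the processor's resize / crop defaults
```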