Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Enhancement/1314 allow sizing show sample (#1381)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicola Occelli <nicola.occelli@ulb.be>
Co-authored-by: Nicola Occelli <nocc0001@hpda.ulb.ac.be>
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
  • Loading branch information
4 people authored Jul 21, 2022
1 parent 4f6fe93 commit 804ca86
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 48 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `figsize` and `limit_nb_samples` for showing batch images ([#1381](https://github.com/Lightning-AI/lightning-flash/pull/1381))

- Added support for `from_lists` for Tabular Classification and Regression ([#1337](https://github.com/PyTorchLightning/lightning-flash/pull/1337))

- Added support for `from_dicts` for Tabular Classification and Regression ([#1331](https://github.com/PyTorchLightning/lightning-flash/pull/1331))
Expand Down
75 changes: 62 additions & 13 deletions flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple

from pytorch_lightning.utilities.exceptions import MisconfigurationException

Expand Down Expand Up @@ -99,10 +99,23 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
the data won't be accessible when using ``num_workers > 0``.
"""

def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None:
self.show(self.batches[running_stage], running_stage, func_names_list)

def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None:
def _show(
self,
running_stage: RunningStage,
func_names_list: List[str],
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
self.show(self.batches[running_stage], running_stage, func_names_list, limit_nb_samples, figsize)

def show(
self,
batch: Dict[str, Any],
running_stage: RunningStage,
func_names_list: List[str],
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""Override this function when you want to visualize a composition."""
# filter out the functions to visualise
func_names_set: Set[str] = set(func_names_list) & set(_CALLBACK_FUNCS)
Expand All @@ -112,22 +125,58 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li
for func_name in func_names_set:
hook_name = f"show_{func_name}"
if _is_overridden(hook_name, self, BaseVisualization):
getattr(self, hook_name)(batch[func_name], running_stage)

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
getattr(self, hook_name)(batch[func_name], running_stage, limit_nb_samples, figsize)

def show_load_sample(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
"""Override to visualize ``load_sample`` output data."""

def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
def show_per_sample_transform(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
"""Override to visualize ``per_sample_transform`` output data."""

def show_collate(self, batch: List[Any], running_stage: RunningStage) -> None:
def show_collate(
self,
batch: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""Override to visualize ``collate`` output data."""

def show_per_batch_transform(self, batch: List[Any], running_stage: RunningStage) -> None:
def show_per_batch_transform(
self,
batch: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""Override to visualize ``per_batch_transform`` output data."""

def show_per_sample_transform_on_device(self, samples: List[Any], running_stage: RunningStage) -> None:
def show_per_sample_transform_on_device(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""Override to visualize ``per_sample_transform_on_device`` output data."""

def show_per_batch_transform_on_device(self, batch: List[Any], running_stage: RunningStage) -> None:
def show_per_batch_transform_on_device(
self,
batch: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""Override to visualize ``per_batch_transform_on_device`` output data."""
54 changes: 44 additions & 10 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,14 @@ def _reset_iterator(self, stage: str) -> Iterable[Any]:
setattr(self, iter_name, iterator)
return iterator

def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None:
def _show_batch(
self,
stage: str,
func_names: Union[str, List[str]],
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
reset: bool = True,
) -> None:
"""This function is used to handle transforms profiling for batch visualization."""
# don't show in CI
if os.getenv("FLASH_TESTING", "0") == "1":
Expand All @@ -469,6 +476,9 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool
if isinstance(func_names, str):
func_names = [func_names]

if not limit_nb_samples:
limit_nb_samples = self.batch_size

iter_dataloader = getattr(self, iter_name)
with self.data_fetcher.enable():
if reset:
Expand All @@ -479,29 +489,53 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
data_fetcher._show(stage, func_names, limit_nb_samples, figsize)
if reset:
self.data_fetcher.batches[stage] = {}

def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
def show_train_batch(
self,
hooks_names: Union[str, List[str]] = "load_sample",
reset: bool = True,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""This function is used to visualize a batch from the train dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
self._show_batch(stage_name, hooks_names, reset=reset)
self._show_batch(stage_name, hooks_names, limit_nb_samples, figsize, reset=reset)

def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
def show_val_batch(
self,
hooks_names: Union[str, List[str]] = "load_sample",
reset: bool = True,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""This function is used to visualize a batch from the validation dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING]
self._show_batch(stage_name, hooks_names, reset=reset)
self._show_batch(stage_name, hooks_names, limit_nb_samples, figsize, reset=reset)

def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
def show_test_batch(
self,
hooks_names: Union[str, List[str]] = "load_sample",
reset: bool = True,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""This function is used to visualize a batch from the test dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.TESTING]
self._show_batch(stage_name, hooks_names, reset=reset)
self._show_batch(stage_name, hooks_names, limit_nb_samples, figsize, reset=reset)

def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
def show_predict_batch(
self,
hooks_names: Union[str, List[str]] = "load_sample",
reset: bool = True,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
) -> None:
"""This function is used to visualize a batch from the prediction dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)
self._show_batch(stage_name, hooks_names, limit_nb_samples, figsize, reset=reset)

def _get_property(self, property_name: str) -> Optional[Any]:
train = getattr(self.train_dataset, property_name, None)
Expand Down
43 changes: 33 additions & 10 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1217,13 +1217,22 @@ def _to_numpy(img: Union[np.ndarray, torch.Tensor, Image.Image]) -> np.ndarray:
return out

@requires("matplotlib")
def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
def _show_images_and_labels(
self,
data: List[Any],
num_samples: int,
title: str,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
num_samples = max(1, min(num_samples, limit_nb_samples))

# define the image grid
cols: int = min(num_samples, self.max_cols)
rows: int = num_samples // cols

# create figure and set title
fig, axs = plt.subplots(rows, cols)
fig, axs = plt.subplots(rows, cols, figsize=figsize)
fig.suptitle(title)

if not isinstance(axs, np.ndarray):
Expand All @@ -1248,14 +1257,28 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
ax.axis("off")
plt.show(block=self.block_viz_window)

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
def show_load_sample(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
win_title: str = f"{running_stage} - show_load_sample"
self._show_images_and_labels(samples, len(samples), win_title)

def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
self._show_images_and_labels(samples, len(samples), win_title, limit_nb_samples, figsize)

def show_per_sample_transform(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
win_title: str = f"{running_stage} - show_per_sample_transform"
self._show_images_and_labels(samples, len(samples), win_title)
self._show_images_and_labels(samples, len(samples), win_title, limit_nb_samples, figsize)

def show_per_batch_transform(self, batch: List[Any], running_stage):
def show_per_batch_transform(
self, batch: List[Any], running_stage, limit_nb_samples: int = None, figsize: Tuple[int, int] = (6.4, 4.8)
):
win_title: str = f"{running_stage} - show_per_batch_transform"
self._show_images_and_labels(batch[0], batch[0][DataKeys.INPUT].shape[0], win_title)
self._show_images_and_labels(batch[0], batch[0][DataKeys.INPUT].shape[0], win_title, limit_nb_samples, figsize)
39 changes: 32 additions & 7 deletions flash/image/segmentation/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,29 @@ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
return out

@requires("matplotlib")
def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
def _show_images_and_labels(
self,
data: List[Any],
num_samples: int,
title: str,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
num_samples = max(1, min(num_samples, limit_nb_samples))

# define the image grid
cols: int = min(num_samples, self.max_cols)
rows: int = num_samples // cols

# create figure and set title
fig, axs = plt.subplots(rows, cols)
fig, axs = plt.subplots(rows, cols, figsize=figsize)
fig.suptitle(title)

for i, ax in enumerate(axs.ravel()):
if not isinstance(axs, np.ndarray):
axs = np.array(axs)
axs = axs.flatten()

for i, ax in enumerate(axs):
# unpack images and labels
sample = data[i]
if isinstance(sample, dict):
Expand All @@ -81,10 +94,22 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
ax.axis("off")
plt.show(block=self.block_viz_window)

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
def show_load_sample(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int,
figsize: Tuple[int, int] = (6.4, 4.8),
):
win_title: str = f"{running_stage} - show_load_sample"
self._show_images_and_labels(samples, len(samples), win_title)
self._show_images_and_labels(samples, len(samples), win_title, limit_nb_samples, figsize)

def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
def show_per_sample_transform(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int,
figsize: Tuple[int, int] = (6.4, 4.8),
):
win_title: str = f"{running_stage} - show_per_sample_transform"
self._show_images_and_labels(samples, len(samples), win_title)
self._show_images_and_labels(samples, len(samples), win_title, limit_nb_samples, figsize)
18 changes: 15 additions & 3 deletions flash/template/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Tuple, Type

import numpy as np
import torch
Expand Down Expand Up @@ -240,8 +240,20 @@ class TemplateVisualization(BaseVisualization):
If you want to provide a visualization with your task, you can override these hooks.
"""

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
def show_load_sample(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
print(samples)

def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
def show_per_sample_transform(
self,
samples: List[Any],
running_stage: RunningStage,
limit_nb_samples: int = None,
figsize: Tuple[int, int] = (6.4, 4.8),
):
print(samples)
Loading

0 comments on commit 804ca86

Please sign in to comment.