diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 5e1b4e4df5..9324620b40 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -117,7 +117,7 @@ def from_files( >>> from PIL import Image >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) - >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64, 1), dtype="uint8")) + >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] >>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)] >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] @@ -262,7 +262,7 @@ def from_folders( >>> import os >>> from PIL import Image >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) - >>> rand_mask = Image.fromarray(np.random.randint(0, 10, (64, 64, 1), dtype="uint8")) + >>> rand_mask = Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) >>> os.makedirs("train_images", exist_ok=True) >>> os.makedirs("train_masks", exist_ok=True) >>> os.makedirs("predict_folder", exist_ok=True)