Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix CI and improve flaky tests (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
krshrimali authored Jul 18, 2022
1 parent 140c5f6 commit 04f1b8f
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where grayscale images were not properly converted to RGB when loaded. ([#1394](https://github.com/PyTorchLightning/lightning-flash/pull/1394))

- Fixed a bug where size of mask for instance segmentation doesn't match to size of original image. ([#1353](https://github.com/PyTorchLightning/lightning-flash/pull/1353))

- Fixed image classification data `show_train_batch` for subplots with rows > 1. ([#1339](https://github.com/PyTorchLightning/lightning-flash/pull/1315))
Expand Down
7 changes: 3 additions & 4 deletions flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,11 @@
TSV_EXTENSIONS = (".tsv",)


def _load_image_from_image(file, drop_alpha: bool = True):
def _load_image_from_image(file):
img = Image.open(file)
img.load()

if img.mode == "RGBA" and drop_alpha:
img = img.convert("RGB")
img = img.convert("RGB")
return img


Expand All @@ -74,7 +73,7 @@ def _load_image_from_numpy(file):


def _load_spectrogram_from_image(file):
img = _load_image_from_image(file, drop_alpha=False)
img = _load_image_from_image(file)
return np.array(img).astype("float32")


Expand Down
4 changes: 2 additions & 2 deletions flash_examples/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
train_dataset=CIFAR10(".", download=True),
batch_size=4,
batch_size=8,
)

# 2. Build the task
Expand Down Expand Up @@ -49,7 +49,7 @@
"data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
"data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
],
batch_size=3,
batch_size=2,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

Expand Down
2 changes: 1 addition & 1 deletion flash_examples/style_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
Expand Down
4 changes: 2 additions & 2 deletions tests/core/data/utilities/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def test_speed(case):
formatter = get_target_formatter(targets)
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target
assert (end - start) / len(targets) < 1e-4 # 0.1ms per target

start = time.perf_counter()
_ = [formatter(t) for t in targets]
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target
assert (end - start) / len(targets) < 1e-4 # 0.1ms per target
2 changes: 1 addition & 1 deletion tests/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def call_script(
filepath: str,
args: Optional[List[str]] = None,
timeout: Optional[int] = 60 * 10,
timeout: Optional[int] = 60 * 20, # (20 minutes)
) -> Tuple[int, str, str]:
with open(filepath) as original:
data = original.readlines()
Expand Down

0 comments on commit 04f1b8f

Please sign in to comment.