Skip to content

Commit

Permalink
Fixed small issues and added image reading method
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh7joshi committed Jul 14, 2023
1 parent da42150 commit b217352
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
12 changes: 12 additions & 0 deletions agml/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import random
import inspect

import cv2

from agml.utils.io import (
get_file_list as _get_file_list,
get_dir_list as _get_dir_list,
Expand Down Expand Up @@ -112,5 +114,15 @@ def random_file(path, **kwargs):
return random.choice(get_file_list(path, **kwargs))


def read_image(path, **kwargs):
"""Reads an image from a file.
Args:
path (str): The path to the image file.
**kwargs: Keyword arguments to pass to `cv2.imread`.
Returns:
numpy.ndarray: The image.
"""
return cv2.imread(path, **kwargs)

4 changes: 2 additions & 2 deletions agml/models/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from agml.data.public import source
from agml.utils.general import resolve_list_value
from agml.utils.image import resolve_image_size
from agml.viz.masks import show_image_with_overlaid_mask, show_image_and_mask
from agml.viz.masks import show_image_and_overlaid_mask, show_image_and_mask

# This is last since `agml.models.base` will check for PyTorch Lightning,
# and PyTorch Lightning automatically installed torchmetrics with it.
Expand Down Expand Up @@ -250,7 +250,7 @@ def show_prediction(self, image, overlay = False, **kwargs):
image = self._expand_input_images(image)[0]
mask = self.predict(image, **kwargs)
if overlay:
return show_image_with_overlaid_mask(image, mask, **kwargs)
return show_image_and_overlaid_mask(image, mask, **kwargs)
return show_image_and_mask(image, mask, **kwargs)

def load_benchmark(self, dataset):
Expand Down
2 changes: 2 additions & 0 deletions agml/viz/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def annotate_object_detection(image,
"either `bbox` or `bboxes` for bounding boxes.")
if bbox_format is not None:
bboxes = convert_bbox_format(bboxes, bbox_format)
if labels is None:
labels = [0] * len(bboxes)

# Run a few final checks in order to ensure data is formatted properly.
image = format_image(image, mask = False)
Expand Down

0 comments on commit b217352

Please sign in to comment.