diff --git a/agml/io.py b/agml/io.py
index 0cf7dce2..e5bd8b9d 100644
--- a/agml/io.py
+++ b/agml/io.py
@@ -15,6 +15,8 @@
 import random
 import inspect
 
+import cv2
+
 from agml.utils.io import (
     get_file_list as _get_file_list,
     get_dir_list as _get_dir_list,
@@ -112,5 +114,15 @@ def random_file(path, **kwargs):
     return random.choice(get_file_list(path, **kwargs))
 
 
+def read_image(path, **kwargs):
+    """Reads an image from a file.
+
+    Args:
+        path (str): The path to the image file.
+        **kwargs: Keyword arguments to pass to `cv2.imread`.
+
+    Returns:
+        numpy.ndarray: The image.
+    """
+    return cv2.imread(path, **kwargs)
diff --git a/agml/models/segmentation.py b/agml/models/segmentation.py
index 8cce07d1..14088721 100644
--- a/agml/models/segmentation.py
+++ b/agml/models/segmentation.py
@@ -32,7 +32,7 @@
 from agml.data.public import source
 from agml.utils.general import resolve_list_value
 from agml.utils.image import resolve_image_size
-from agml.viz.masks import show_image_with_overlaid_mask, show_image_and_mask
+from agml.viz.masks import show_image_and_overlaid_mask, show_image_and_mask
 
 # This is last since `agml.models.base` will check for PyTorch Lightning,
 # and PyTorch Lightning automatically installed torchmetrics with it.
@@ -250,7 +250,7 @@ def show_prediction(self, image, overlay = False, **kwargs):
         image = self._expand_input_images(image)[0]
         mask = self.predict(image, **kwargs)
         if overlay:
-            return show_image_with_overlaid_mask(image, mask, **kwargs)
+            return show_image_and_overlaid_mask(image, mask, **kwargs)
         return show_image_and_mask(image, mask, **kwargs)
 
     def load_benchmark(self, dataset):
diff --git a/agml/viz/boxes.py b/agml/viz/boxes.py
index 9a51279d..096cbe94 100644
--- a/agml/viz/boxes.py
+++ b/agml/viz/boxes.py
@@ -110,6 +110,8 @@ def annotate_object_detection(image,
                          "either `bbox` or `bboxes` for bounding boxes.")
     if bbox_format is not None:
         bboxes = convert_bbox_format(bboxes, bbox_format)
+    if labels is None:
+        labels = [0] * len(bboxes)
 
     # Run a few final checks in order to ensure data is formatted properly.
     image = format_image(image, mask = False)
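
A quick usage sketch of the new `read_image` helper added above. The directory path is a placeholder, and the sketch assumes `read_image` and `random_file` are importable from `agml.io` exactly as in the hunk; since `read_image` wraps `cv2.imread`, a failed read yields `None` rather than raising.

```python
from agml.io import random_file, read_image

# Pick a random image file from a directory, then load it with the new helper.
# Because `read_image` wraps `cv2.imread`, the result is a BGR numpy.ndarray,
# or None when the file cannot be read.
path = random_file("/path/to/images")
image = read_image(path)
if image is None:
    raise FileNotFoundError(f"cv2 could not read the image at {path}")
```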