diff --git a/data/xView.yaml b/data/xView.yaml
index fabcdb0bdd13..b5af470058e0 100644
--- a/data/xView.yaml
+++ b/data/xView.yaml
@@ -36,7 +36,7 @@ download: |
   from PIL import Image
   from tqdm import tqdm
 
-  from utils.datasets import autosplit
+  from utils.datasets_old import autosplit
   from utils.general import download, xyxy2xywhn
 
diff --git a/detect.py b/detect.py
index 0b1d93897d4c..07c57f42d47a 100644
--- a/detect.py
+++ b/detect.py
@@ -20,7 +20,7 @@
 sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
 
 from models.experimental import attempt_load
-from utils.datasets import LoadStreams, LoadImages
+from utils.datasets_old import LoadStreams, LoadImages
 from utils.general import check_img_size, check_requirements, check_imshow, colorstr, is_ascii, non_max_suppression, \
     apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import Annotator, colors
diff --git a/models/common.py b/models/common.py
index 90bfef5124b3..457c91107c32 100644
--- a/models/common.py
+++ b/models/common.py
@@ -17,7 +17,7 @@
 from PIL import Image
 from torch.cuda import amp
 
-from utils.datasets import exif_transpose, letterbox
+from utils.datasets_old import exif_transpose, letterbox
 from utils.general import colorstr, increment_path, is_ascii, make_divisible, non_max_suppression, save_one_box, \
     scale_coords, xyxy2xywh
 from utils.plots import Annotator, colors
diff --git a/models/tf.py b/models/tf.py
index 40e7d20a9d84..ca19710b2830 100644
--- a/models/tf.py
+++ b/models/tf.py
@@ -52,7 +52,7 @@
 from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
 from models.experimental import MixConv2d, CrossConv, attempt_load
 from models.yolo import Detect
-from utils.datasets import LoadImages
+from utils.datasets_old import LoadImages
 from utils.general import make_divisible, check_file, check_dataset
 
 logger = logging.getLogger(__name__)
diff --git a/requirements.txt b/requirements.txt
index 2ad65ba53e29..dc57f617f9eb 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,3 +30,4 @@ pandas
 # pycocotools>=2.0  # COCO mAP
 # albumentations>=1.0.3
 thop  # FLOPs computation
+# pytest
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/utils/datasets/__init__.py b/tests/utils/datasets/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/utils/datasets/test_coco.py b/tests/utils/datasets/test_coco.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/utils/test_file.py b/tests/utils/test_file.py
new file mode 100644
index 000000000000..6b0e5bf8f9e6
--- /dev/null
+++ b/tests/utils/test_file.py
@@ -0,0 +1,62 @@
+import shutil
+from pathlib import Path
+from typing import Generator, Optional, Callable
+
+import pytest
+
+from tests.utils.test_utils import prepare_temporary_dir
+from utils.file import dump_text_file, get_directory_content
+
+
+@pytest.fixture
+def mock_directory_path() -> Generator[str, None, None]:
+    output_path = prepare_temporary_dir(directory_name="mock_directory_path")
+    yield output_path
+    shutil.rmtree(output_path)
+
+
+def mock_directory_content(directory_path: str) -> None:
+    dump_text_file(Path(directory_path).joinpath('file_1.json').as_posix(), '')
+    dump_text_file(Path(directory_path).joinpath('file_2.txt').as_posix(), '')
+    dump_text_file(Path(directory_path).joinpath('file_3.txt').as_posix(), '')
+
+
+@pytest.mark.parametrize(
+    "extension, mock_callback, expected_result",
+    [
+        (
+            None,
+            lambda x: None,
+            0
+        ),  # empty directory
+        (
+            None,
+            mock_directory_content,
+            3
+        ),  # directory contains 3 files
+        (
+            'json',
+            mock_directory_content,
+            1
+        ),  # directory contains 1 .json file
+        (
+            'txt',
+            mock_directory_content,
+            2
+        ),  # directory contains 2 .txt files
+        (
+            'avi',
+            mock_directory_content,
+            0
+        ),  # directory contains 0 .avi files
+    ]
+)
+def test_get_directory_content(
+    mock_directory_path: str,
+    extension: Optional[str],
+    mock_callback: Callable[[str], None],
+    expected_result: int
+) -> None:
+    mock_callback(mock_directory_path)
+    result = get_directory_content(directory_path=mock_directory_path, extension=extension)
+    assert len(result) == expected_result
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
new file mode 100644
index 000000000000..989bd2173ec4
--- /dev/null
+++ b/tests/utils/test_utils.py
@@ -0,0 +1,7 @@
+from pathlib import Path
+
+
+def prepare_temporary_dir(directory_name: str) -> str:
+    directory_path = Path(__file__).parent.joinpath(directory_name)
+    directory_path.mkdir(parents=True, exist_ok=True)
+    return directory_path.as_posix()
diff --git a/train.py b/train.py
index 2fe38ef043d0..5c936ef707c4 100644
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@
 from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
-from utils.datasets import create_dataloader
+from utils.datasets_old import create_dataloader
 from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
     strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
     check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
diff --git a/utils/autoanchor.py b/utils/autoanchor.py
index 66a2712dfd5d..5699726c7ece 100644
--- a/utils/autoanchor.py
+++ b/utils/autoanchor.py
@@ -109,7 +109,7 @@ def print_results(k):
     if isinstance(dataset, str):  # *.yaml file
         with open(dataset, errors='ignore') as f:
             data_dict = yaml.safe_load(f)  # model dict
-        from utils.datasets import LoadImagesAndLabels
+        from utils.datasets_old import LoadImagesAndLabels
         dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
 
     # Get label wh
diff --git a/utils/datasets/__init__.py b/utils/datasets/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/utils/datasets/coco.py b/utils/datasets/coco.py
new file mode 100644
index 000000000000..a1d87d1f9fcb
--- /dev/null
+++ b/utils/datasets/coco.py
@@ -0,0 +1,74 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+COCO dataset loading utils
+"""
+
+import json
+from typing import Dict, Union, List
+from collections import defaultdict
+
+import torch
+import numpy as np
+
+
+IMAGE_KEY = "images"
+IMAGE_FILE_NAME_KEY = "file_name"
+IMAGE_ID_KEY = "id"
+IMAGE_WIDTH_KEY = "width"
+IMAGE_HEIGHT_KEY = "height"
+ANNOTATION_KEY = "annotations"
+ANNOTATION_IMAGE_ID_KEY = "image_id"
+ANNOTATION_BBOX_KEY = "bbox"
+ANNOTATION_CATEGORY_ID = "category_id"
+
+
+def read_json_file(file_path: str, **kwargs) -> Union[list, dict]:
+    with open(file_path, 'r') as file:
+        return json.load(file, **kwargs)
+
+
+def load_coco_annotations(coco_data: dict) -> Dict[str, torch.Tensor]:
+    coco_image_entries_map = map_coco_image_entries(coco_image_entries=coco_data[IMAGE_KEY])
+    coco_annotation_entries_map = map_coco_annotation_entries(coco_annotation_entries=coco_data[ANNOTATION_KEY])
+    return {
+        coco_image_entries_map[image_id][IMAGE_FILE_NAME_KEY]: process_coco_annotation(
+            coco_annotation_entries=coco_annotation_entries_map[image_id],
+            coco_image_data=coco_image_entries_map[image_id]
+        )
+        for image_id
+        in sorted(coco_image_entries_map.keys())
+    }
+
+
+def map_coco_image_entries(coco_image_entries: List[dict]) -> Dict[int, dict]:
+    return {
+        image_data[IMAGE_ID_KEY]: image_data
+        for image_data
+        in coco_image_entries
+    }
+
+
+def map_coco_annotation_entries(coco_annotation_entries: List[dict]) -> Dict[int, List[dict]]:
+    result = defaultdict(list)
+    for coco_annotation_entry in coco_annotation_entries:
+        image_id = coco_annotation_entry[ANNOTATION_IMAGE_ID_KEY]
+        result[image_id].append(coco_annotation_entry)
+    return result
+
+
+def process_coco_annotation(coco_annotation_entries: List[dict], coco_image_data: dict) -> torch.Tensor:
+    image_width = coco_image_data[IMAGE_WIDTH_KEY]
+    image_height = coco_image_data[IMAGE_HEIGHT_KEY]
+    annotations = []
+    for coco_annotation_entry in coco_annotation_entries:
+        category_id = coco_annotation_entry[ANNOTATION_CATEGORY_ID]
+        x_min, y_min, width, height = coco_annotation_entry[ANNOTATION_BBOX_KEY]
+        annotations.append([
+            0,
+            category_id,
+            (x_min + width / 2) / image_width,
+            (y_min + height / 2) / image_height,
+            width / image_width,
+            height / image_height
+        ])
+    return torch.as_tensor(np.array(annotations))
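+
+
+# Illustrative sketch only (not used by the loader): shows how a single COCO
+# annotation maps to the normalized [image_index, class, x_center, y_center,
+# width, height] row that process_coco_annotation builds above. All sample
+# values below are assumed.
+def _example_annotation_conversion() -> torch.Tensor:
+    coco_image_data = {IMAGE_WIDTH_KEY: 100, IMAGE_HEIGHT_KEY: 200}
+    coco_annotation_entries = [{
+        ANNOTATION_CATEGORY_ID: 1,
+        ANNOTATION_BBOX_KEY: [10.0, 20.0, 30.0, 40.0]  # COCO bbox: [x_min, y_min, width, height]
+    }]
+    result = process_coco_annotation(coco_annotation_entries, coco_image_data)
+    # x_center = (10 + 30 / 2) / 100 = 0.25, y_center = (20 + 40 / 2) / 200 = 0.2
+    assert result.tolist() == [[0.0, 1.0, 0.25, 0.2, 0.3, 0.2]]
+    return result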
+ """ + self._validate_dataset_path(path=path) + self.path = path + self.cache_images = cache_images + self.image_paths, self.labels = self._load_image_paths_and_labels(path=path) + self.image_provider = ImageProvider(cache_images=cache_images, paths=self.image_paths) + + def __len__(self) -> int: + return len(self.image_paths) + + def __getitem__(self, index: int) -> DatasetEntry: + image_path = self.image_paths[index] + labels = self.labels[index] + image = self.image_provider.get_image(path=image_path) + return torch.from_numpy(image), labels, image_path + + @staticmethod + def collate_fn(batch: List[DatasetEntry]) -> torch.Tensor: + pass # TODO: + + @staticmethod + def _load_image_paths_and_labels(path: str) -> Tuple[List[str], List[torch.Tensor]]: + images_path = os.path.join(path, COCODataset.IMAGES_DIRECTORY_NAME) + annotations_path = os.path.join(path, COCODataset.ANNOTATIONS_FILE_NAME) + coco_data = read_json_file(file_path=annotations_path) + coco_annotations = load_coco_annotations(coco_data=coco_data) + image_paths = [ + os.path.join(images_path, image_name) + for image_name + in coco_annotations.keys() + ] + return image_paths, list(coco_annotations.values()) + + @staticmethod + def _validate_dataset_path(path: str) -> None: + images_path = os.path.join(path, COCODataset.IMAGES_DIRECTORY_NAME) + annotations_path = os.path.join(path, COCODataset.ANNOTATIONS_FILE_NAME) + if not os.path.isfile(annotations_path) or not os.path.isdir(images_path): + raise COCODatasetError("Given path does not point to COCO dataset.") + + @staticmethod + def resolve_cache_path() -> Path: + pass # TODO: + + +class YOLODataset(Dataset): + """ + dataset + ├── image_names.txt [optional] + ├── image_names.cache [optional] + ├── images + │ ├── image-1.jpg + │ ├── image-2.jpg + │ └── ... + └── labels + ├── image-1.txt + ├── image-2.txt + └── ... + """ + + def __init__(self, path: str, cache_images: Optional[str] = None) -> None: + """ + Load YOLO labels along with images from provided path. + + Args: + path: `str` - path to `images` directory or to `image_names.txt` file. + cache_images: `Optional[str]` - flag enabling image caching. Can be equal to one of three values: `"ram"`, + `"disc"` or `None`. `"ram"` - all images are stored in memory to enable fastest access. This may however + result in exceeding the limit of available memory. `"disc"` - all images are stored on hard drive but in + raw, uncompressed form. This prevents memory overflow, and offers faster access to data then regular + image read. `None` - image caching is turned of. 
+ """ + self.path = path + self.cache_images = cache_images + self.image_paths, self.labels = self._load_image_paths_and_labels(path=path) + self.image_provider = ImageProvider(cache_images=cache_images, paths=self.image_paths) + + def __len__(self) -> int: + return len(self.image_paths) + + def __getitem__(self, index: int) -> DatasetEntry: + image_path = self.image_paths[index] + labels = self.labels[index] + image = self.image_provider.get_image(path=image_path) + return torch.from_numpy(image), labels, image_path + + @staticmethod + def collate_fn(batch: List[DatasetEntry]) -> torch.Tensor: + pass + + @staticmethod + def _load_image_paths_and_labels(path: str) -> Tuple[List[str], List[torch.Tensor]]: + image_paths = load_image_names_from_paths(paths=path) + label_paths = img2label_paths(image_paths=image_paths) + + # TODO: finalize yolo labels cache plugin + cache_path = YOLODataset.resolve_cache_path(path=path, label_paths=label_paths) + label_cache = LabelCache.load( + path=cache_path, + hash=get_hash(label_paths + image_paths) + ) + labels = [ + label_cache[image_path] + for image_path + in image_paths + ] + + return image_paths, labels + + @staticmethod + def resolve_cache_path(path: str, label_paths: List[str]) -> Path: + path = Path(path) + return (path if path.is_file() else Path(label_paths[0]).parent).with_suffix('.cache') + + +class TransformedDataset(Dataset): + + def __init__( + self, + source_dataset: Dataset, + img_size: int = 640, + batch_size: int = 16, + augment: bool = False, + hyp=None, + rect=False, + single_cls: bool = False, + stride: int = 32, + pad: float = 0.0 + ) -> None: + self.source_dataset = source_dataset + self.img_size = img_size + self.batch_size = batch_size + self.augment = augment + self.hyp = hyp + self.rect = rect + self.stride = stride + self.single_cls = single_cls + self.pad = pad + + def __len__(self) -> int: + return len(self.source_dataset) + + def __getitem__(self, index: int) -> DatasetEntry: + image, labels, image_path = self.source_dataset[index] + + if self.single_cls: + labels[:, 0] = 0 + + return image, labels, image_path diff --git a/utils/datasets/error.py b/utils/datasets/error.py new file mode 100644 index 000000000000..9d7341073faf --- /dev/null +++ b/utils/datasets/error.py @@ -0,0 +1,7 @@ + +class CacheError(Exception): + pass + + +class COCODatasetError(Exception): + pass diff --git a/utils/datasets/image_cache.py b/utils/datasets/image_cache.py new file mode 100644 index 000000000000..9ba07a54eb3b --- /dev/null +++ b/utils/datasets/image_cache.py @@ -0,0 +1,162 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Image loading and caching helpers +""" + +import os +from abc import ABC, abstractmethod +from multiprocessing.pool import ThreadPool +from pathlib import Path +from typing import Optional, List, Dict + +import cv2 +import numpy as np +from tqdm import tqdm + +from utils.datasets.error import CacheError + + +NUM_THREADS = min(8, os.cpu_count()) + + +class BaseImageCache(ABC): + + _cache_size = 0 + _loading_completed = False + + def __init__(self, cache_type: str, thread_count: int = 8) -> None: + self._thread_count = min(thread_count, os.cpu_count()) + self._cache_type = cache_type + + @property + def cache_size(self) -> float: + return self._cache_size + + def load_images(self, paths: List[str]) -> None: + if self._loading_completed: + raise CacheError(f"load_images method can only be called once.") + self._load_images(paths=paths) + self._loading_completed = True + print(f"Image caching completed. 
diff --git a/utils/datasets/error.py b/utils/datasets/error.py
new file mode 100644
index 000000000000..9d7341073faf
--- /dev/null
+++ b/utils/datasets/error.py
@@ -0,0 +1,7 @@
+
+class CacheError(Exception):
+    pass
+
+
+class COCODatasetError(Exception):
+    pass
diff --git a/utils/datasets/image_cache.py b/utils/datasets/image_cache.py
new file mode 100644
index 000000000000..9ba07a54eb3b
--- /dev/null
+++ b/utils/datasets/image_cache.py
@@ -0,0 +1,162 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Image loading and caching helpers
+"""
+
+import os
+from abc import ABC, abstractmethod
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from typing import Optional, List, Dict
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from utils.datasets.error import CacheError
+
+
+NUM_THREADS = min(8, os.cpu_count())
+
+
+class BaseImageCache(ABC):
+
+    def __init__(self, cache_type: str, thread_count: int = 8) -> None:
+        self._thread_count = min(thread_count, os.cpu_count())
+        self._cache_type = cache_type
+        self._cache_size = 0
+        self._loading_completed = False
+
+    @property
+    def cache_size(self) -> float:
+        return self._cache_size
+
+    def load_images(self, paths: List[str]) -> None:
+        if self._loading_completed:
+            raise CacheError("load_images method can only be called once.")
+        self._load_images(paths=paths)
+        self._loading_completed = True
+        print(f"Image caching completed. ({self._cache_size / 1E9:.1f}GB {self._cache_type})")
+
+    def get_image(self, path: str) -> np.ndarray:
+        if not self._loading_completed:
+            raise CacheError("Could not obtain the image. Image cache is not yet initialized.")
+        image = self._get_image(path=path)
+        if image is None:
+            raise CacheError(f"Image with {path} path could not be found in cache.")
+        return image
+
+    @abstractmethod
+    def _load_images(self, paths: List[str]) -> None:
+        pass
+
+    @abstractmethod
+    def _get_image(self, path: str) -> Optional[np.ndarray]:
+        pass
+
+
+class DiscImageCache(BaseImageCache):
+
+    def __init__(self, thread_count: int = 8) -> None:
+        super().__init__(cache_type="disc", thread_count=thread_count)
+        self._image_paths: Dict[str, Path] = {}
+        self._cache_path: Optional[str] = None
+
+    def _load_images(self, paths: List[str]) -> None:
+        self._cache_path = self._init_cache(paths=paths)
+        self._image_paths = {
+            path: Path(self._cache_path) / Path(path).with_suffix('.npy').name
+            for path
+            in paths
+        }
+        results = ThreadPool(self._thread_count).imap(lambda x: self._load_image(x), paths)
+        bar = tqdm(results, total=len(paths))
+        for _ in bar:
+            bar.desc = f"Caching images ({self._cache_size / 1E9:.1f}GB {self._cache_type})"
+        bar.close()
+
+    def _get_image(self, path: str) -> Optional[np.ndarray]:
+        target_path = self._image_paths.get(path)
+        if target_path is None:
+            return None
+        return np.load(target_path)
+
+    def _load_image(self, path: str) -> None:
+        image = cv2.imread(path)
+        if image is None:
+            raise CacheError(f"Image with {path} path could not be found.")
+        target_path = self._image_paths[path]
+        np.save(target_path, image)
+        self._cache_size += image.nbytes
+
+    def _init_cache(self, paths: List[str]) -> str:
+        cache_path = Path(paths[0]).parent.as_posix() + '_npy'
+        Path(cache_path).mkdir(parents=True, exist_ok=True)
+        return cache_path
+
+
+class RAMImageCache(BaseImageCache):
+
+    def __init__(self, thread_count: int = 8) -> None:
+        super().__init__(cache_type="ram", thread_count=thread_count)
+        self._images: Dict[str, np.ndarray] = {}
+
+    def _load_images(self, paths: List[str]) -> None:
+        results = ThreadPool(self._thread_count).imap(lambda x: self._load_image(x), paths)
+        bar = tqdm(results, total=len(paths))
+        for _ in bar:
+            bar.desc = f"Caching images ({self._cache_size / 1E9:.1f}GB {self._cache_type})"
+        bar.close()
+
+    def _get_image(self, path: str) -> Optional[np.ndarray]:
+        return self._images.get(path)
+
+    def _load_image(self, path: str) -> None:
+        image = cv2.imread(path)
+        if image is None:
+            raise CacheError(f"Image with {path} path could not be found.")
+        self._images[path] = image
+        self._cache_size += image.nbytes
+
+
+class ImageProvider:
+
+    def __init__(self, cache_images: Optional[str], paths: List[str]) -> None:
+        """
+        High-level class responsible for loading images. ImageProvider can cache images on disk or in
+        memory to speed up the loading process.
+
+        Args:
+            cache_images: `Optional[str]` - flag enabling image caching. Can be equal to one of three values:
+                `"ram"`, `"disc"` or `None`. `"ram"` - all images are stored in memory for the fastest access.
+                This may, however, exceed the limit of available memory. `"disc"` - all images are stored on the
+                hard drive, but in raw, uncompressed form. This prevents memory overflow and offers faster access
+                to data than a regular image read. `None` - image caching is turned off.
+            paths: `List[str]` - list of image paths to cache.
+        """
+        self._cache_images = cache_images
+        self._cache = ImageProvider._init_cache(cache_images=cache_images, paths=paths)
+
+    def get_image(self, path: str) -> np.ndarray:
+        if self._cache_images:
+            return self._cache.get_image(path=path)
+        image = cv2.imread(path)
+        if image is None:
+            raise CacheError(f"Image with {path} path could not be found.")
+        return image
+
+    @staticmethod
+    def _init_cache(cache_images: Optional[str], paths: List[str]) -> Optional[BaseImageCache]:
+        if cache_images == "disc":
+            cache = DiscImageCache(thread_count=NUM_THREADS)
+            cache.load_images(paths=paths)
+            return cache
+        if cache_images == "ram":
+            cache = RAMImageCache(thread_count=NUM_THREADS)
+            cache.load_images(paths=paths)
+            return cache
+        if cache_images is None:
+            return None
+        raise CacheError(f"Unsupported cache type. Expected disc, ram or None. {cache_images} given.")
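+
+
+# Minimal usage sketch (assumed paths): the provider hides whether images come from
+# RAM, the on-disk .npy cache, or a plain cv2.imread call.
+#
+#     provider = ImageProvider(cache_images="disc", paths=["images/image-1.jpg"])
+#     image = provider.get_image(path="images/image-1.jpg")  # np.ndarray (HWC, BGR)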
+ """ + self._cache_images = cache_images + self._cache = ImageProvider._init_cache(cache_images=cache_images, paths=paths) + + def get_image(self, path: str) -> np.ndarray: + if self._cache_images: + return self._cache.get_image(path=path) + else: + image = cv2.imread(path) + if image is None: + raise CacheError(f"Image with {path} path could not be found.") + return image + + @staticmethod + def _init_cache(cache_images: Optional[str], paths: List[str]) -> Optional[BaseImageCache]: + if cache_images == "disc": + cache = DiscImageCache(thread_count=NUM_THREADS) + cache.load_images(paths=paths) + return cache + if cache_images == 'ram': + cache = RAMImageCache(thread_count=NUM_THREADS) + cache.load_images(paths=paths) + return cache + if cache_images is None: + return None + raise CacheError(f"Unsupported cache type. Expected disc, ram or None. {cache_images} given.") diff --git a/utils/datasets/label_cache.py b/utils/datasets/label_cache.py new file mode 100644 index 000000000000..c54b5c5f6df6 --- /dev/null +++ b/utils/datasets/label_cache.py @@ -0,0 +1,52 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Labels caching helpers +""" + +import hashlib +import os +from pathlib import Path +from typing import List, Union, Optional + +import numpy as np + + +def get_hash(paths: List[str]) -> str: + """ + Returns a single hash value of a list of paths (files or dirs) + """ + size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes + h = hashlib.md5(str(size).encode()) # hash sizes + h.update(''.join(paths).encode()) # hash paths + return h.hexdigest() # return hash + + +class LabelCache: + + VERSION = 0.4 + VERSION_KEY = "version" + HASH_KEY = "hash" + RESULTS_KEY = "results" + + @staticmethod + def load(path: Union[str, Path], hash: str) -> Optional[dict]: + cache = LabelCache._safe_load(path=path) + if all([ + cache, + cache[LabelCache.VERSION_KEY] == LabelCache.VERSION, + cache[LabelCache.HASH_KEY] == hash + ]): + return cache + else: + return None + + @staticmethod + def save(path: Union[str, Path], hash: str, data: dict) -> None: + pass # TODO + + @staticmethod + def _safe_load(path: Union[str, Path]) -> Optional[dict]: + try: + return np.load(path, allow_pickle=True).item() + except: + return None diff --git a/utils/datasets/todo.txt b/utils/datasets/todo.txt new file mode 100644 index 000000000000..d2a55949f8c1 --- /dev/null +++ b/utils/datasets/todo.txt @@ -0,0 +1,6 @@ +# handle corrupted images +# handle prefix, most likely by using proper logging +# coco label caching +# yolo label loading + +# why we need information about batch_size if we only return one image \ No newline at end of file diff --git a/utils/datasets/yolo.py b/utils/datasets/yolo.py new file mode 100644 index 000000000000..d3a0de921133 --- /dev/null +++ b/utils/datasets/yolo.py @@ -0,0 +1,51 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +YOLO dataset loading utils +""" + +import os +import glob +from pathlib import Path + +from typing import List, Union + +from utils.datasets_old import IMG_FORMATS +from utils.file import read_text_file_lines + + +def load_image_names_from_paths(paths: Union[str, List[str]]) -> List[str]: + image_paths = [] + for path in paths if isinstance(paths, list) else [paths]: + path = Path(path) # os-agnostic + if path.is_dir(): # dir + image_paths += glob.glob(str(path / '**' / '*.*'), recursive=True) + elif path.is_file(): # file + local_paths = read_text_file_lines(path) + parent = str(path.parent) + os.sep + image_paths += [ + local_path.replace('./', parent) if 
diff --git a/utils/datasets/todo.txt b/utils/datasets/todo.txt
new file mode 100644
index 000000000000..d2a55949f8c1
--- /dev/null
+++ b/utils/datasets/todo.txt
@@ -0,0 +1,6 @@
+# handle corrupted images
+# handle prefix, most likely by using proper logging
+# coco label caching
+# yolo label loading
+
+# why do we need information about batch_size if we only return one image?
\ No newline at end of file
diff --git a/utils/datasets/yolo.py b/utils/datasets/yolo.py
new file mode 100644
index 000000000000..d3a0de921133
--- /dev/null
+++ b/utils/datasets/yolo.py
@@ -0,0 +1,51 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+YOLO dataset loading utils
+"""
+
+import os
+import glob
+from pathlib import Path
+from typing import List, Union
+
+from utils.datasets_old import IMG_FORMATS
+from utils.file import read_text_file_lines
+
+
+def load_image_names_from_paths(paths: Union[str, List[str]]) -> List[str]:
+    image_paths = []
+    for path in paths if isinstance(paths, list) else [paths]:
+        path = Path(path)  # os-agnostic
+        if path.is_dir():  # dir
+            image_paths += glob.glob(str(path / '**' / '*.*'), recursive=True)
+        elif path.is_file():  # file
+            local_paths = read_text_file_lines(path)
+            parent = str(path.parent) + os.sep
+            image_paths += [
+                local_path.replace('./', parent) if local_path.startswith('./') else local_path
+                for local_path
+                in local_paths
+            ]
+        else:
+            raise Exception(f'{path} does not exist')
+    return sorted([x.replace('/', os.sep) for x in image_paths if x.split('.')[-1].lower() in IMG_FORMATS])
+
+
+def img2label_paths(image_paths: List[str]) -> List[str]:
+    """
+    Define label paths as a function of image paths.
+    """
+    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
+    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in image_paths]
+
+
+class YOLOLabelsLoader:
+
+    def __init__(self, image_paths: List[str], labels_paths: List[str]) -> None:
+        self.image_paths = image_paths
+        self.labels_paths = labels_paths
+        self.missing_labels = 0
+
+    def load_label(self) -> None:
+        pass  # TODO
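+
+
+# Illustrative sketch only (assumed paths): label files mirror image files with the
+# last /images/ path segment swapped for /labels/ and the extension swapped for .txt.
+def _example_img2label_paths() -> None:
+    image_paths = [os.sep.join(['dataset', 'images', 'image-1.jpg'])]
+    label_paths = img2label_paths(image_paths=image_paths)
+    assert label_paths == [os.sep.join(['dataset', 'labels', 'image-1.txt'])]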
diff --git a/utils/datasets.py b/utils/datasets_old.py
similarity index 99%
rename from utils/datasets.py
rename to utils/datasets_old.py
index 852bb7c04aa8..3a789c07f7e0 100755
--- a/utils/datasets.py
+++ b/utils/datasets_old.py
@@ -91,7 +91,7 @@ def exif_transpose(image):
     return image
 
 
-def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
+def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=None, pad=0.0,
                       rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
@@ -103,7 +103,6 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
                                       single_cls=single_cls,
                                       stride=int(stride),
                                       pad=pad,
-                                      image_weights=image_weights,
                                       prefix=prefix)
 
     batch_size = min(batch_size, len(dataset))
@@ -365,13 +364,13 @@ def img2label_paths(img_paths):
 
 
 class LoadImagesAndLabels(Dataset):  # for training/testing
-    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
+
+    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
         self.img_size = img_size
         self.augment = augment
         self.hyp = hyp
-        self.image_weights = image_weights
-        self.rect = False if image_weights else rect
+        self.rect = rect
         self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
         self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride
@@ -402,6 +401,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
+
         try:
             cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
             assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
@@ -600,6 +600,7 @@ def __getitem__(self, index):
 
     @staticmethod
     def collate_fn(batch):
+        # print('batch', type(batch), batch)
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
@@ -800,7 +801,7 @@ def flatten_recursive(path='../datasets/coco128'):
             shutil.copyfile(file, new_path / Path(file).name)
 
 
-def extract_boxes(path='../datasets/coco128'):  # from utils.datasets import *; extract_boxes()
+def extract_boxes(path='../datasets/coco128'):  # from utils.datasets_old import *; extract_boxes()
     # Convert detection dataset into classification dataset, with one directory per class
     path = Path(path)  # images dir
     shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
diff --git a/utils/file.py b/utils/file.py
new file mode 100644
index 000000000000..e0c3066df7ec
--- /dev/null
+++ b/utils/file.py
@@ -0,0 +1,24 @@
+from glob import glob
+from pathlib import Path
+from typing import List, Union, Optional
+
+
+def read_text_file_lines(file_path: Union[str, Path], remove_blank: bool = True) -> List[str]:
+    with open(file_path, "r") as file:
+        lines = [line.strip(' \n') for line in file.readlines()]
+    if remove_blank:
+        return list(filter(lambda line: len(line) > 0, lines))
+    else:
+        return lines
+
+
+def get_directory_content(directory_path: str, extension: Optional[str] = None) -> List[str]:
+    wild_card = '*' if extension is None else f'*.{extension}'
+    pattern = Path(directory_path).joinpath(wild_card).as_posix()
+    return glob(pattern)
+
+
+def dump_text_file(file_path: str, content: str) -> None:
+    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
+    with open(file_path, 'w', encoding='utf-8') as file:
+        file.write(content)
diff --git a/utils/loggers/wandb/wandb_utils.py b/utils/loggers/wandb/wandb_utils.py
index 8b2095afcb8b..d3f33c70415e 100644
--- a/utils/loggers/wandb/wandb_utils.py
+++ b/utils/loggers/wandb/wandb_utils.py
@@ -12,8 +12,8 @@
 FILE = Path(__file__).absolute()
 sys.path.append(FILE.parents[3].as_posix())  # add yolov5/ to path
 
-from utils.datasets import LoadImagesAndLabels
-from utils.datasets import img2label_paths
+from utils.datasets_old import LoadImagesAndLabels
+from utils.datasets_old import img2label_paths
 from utils.general import check_dataset, check_file
 
 try:
diff --git a/val.py b/val.py
index cbee8cf1c026..98a4db80dbc1 100644
--- a/val.py
+++ b/val.py
@@ -21,7 +21,7 @@
 sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
 
 from models.experimental import attempt_load
-from utils.datasets import create_dataloader
+from utils.datasets_old import create_dataloader
 from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
     box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
 from utils.metrics import ap_per_class, ConfusionMatrix