diff --git a/tests/data/reforestree/data.py b/tests/data/reforestree/data.py index 6479535cacb..41d210a31b8 100644 --- a/tests/data/reforestree/data.py +++ b/tests/data/reforestree/data.py @@ -51,7 +51,6 @@ def create_img(path: str) -> None: if __name__ == "__main__": data_root = "data" # remove old data - # remove old data if os.path.isdir(data_root): shutil.rmtree(data_root) else: diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index a94ea7620dd..6121ca4af29 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -21,7 +21,7 @@ class ReforesTree(VisionDataset): """ReforesTree dataset. - The `ReforesTree `_ + The `ReforesTree `__ dataset contains drone imagery that can be used for tree crown detection, tree species classification and Aboveground Biomass (AGB) estimation. @@ -48,21 +48,13 @@ class ReforesTree(VisionDataset): If you use this dataset in your research, please cite the following paper: - * Reiersen, Gyri, et al. "ReforesTree: A Dataset for Estimating Tropical - Forest Carbon Stock with Deep Learning and Aerial Imagery." - arXiv preprint arXiv:2201.11192 (2022). + * https://arxiv.org/abs/2201.11192 .. versionadded:: 0.3 """ classes = ["other", "banana", "cacao", "citrus", "fruit", "timber"] - url = ( - "https://uc62e9acc213a8f56919a7a7add8.dl.dropboxusercontent.com/" - "zip_download_get/BKn4_obEfo0iR6UUpOqBMp9K0ohkldQtJTUN802va4SgJvd" - "OsN36ap5fcbbxFLbhPMHoI1oJBrkjwJWyYkcBE-SxCeFPTZ0a-gOBozXL7CfaKg?" - "_download_id=600502027618175772326304665489490007078769290252973" - "08895167120694292&_notify_domain=www.dropbox.com&dl=1" - ) + url = "https://uc62e9acc213a8f56919a7a7add8.dl.dropboxusercontent.com/zip_download_get/BKn4_obEfo0iR6UUpOqBMp9K0ohkldQtJTUN802va4SgJvdOsN36ap5fcbbxFLbhPMHoI1oJBrkjwJWyYkcBE-SxCeFPTZ0a-gOBozXL7CfaKg?_download_id=60050202761817577232630466548949000707876929025297308895167120694292&_notify_domain=www.dropbox.com&dl=1" md5 = "f6a4a1d8207aeaa5fbab7b21b683a302" zipfilename = "data.zip" @@ -161,7 +153,7 @@ def _load_image(self, path: str) -> Tensor: """ with Image.open(path) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor: Tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -177,13 +169,11 @@ def _load_target(self, filepath: str) -> Tuple[Tensor, ...]: """ tile_df = self.annot_df[self.annot_df["img_path"] == os.path.basename(filepath)] - boxes: Tensor = torch.Tensor( - tile_df[["xmin", "ymin", "xmax", "ymax"]].values.tolist() - ) - labels: Tensor = torch.Tensor( + boxes = torch.Tensor(tile_df[["xmin", "ymin", "xmax", "ymax"]].values.tolist()) + labels = torch.Tensor( [self.class2idx[label] for label in tile_df["group"].tolist()] ) - agb: Tensor = torch.Tensor(tile_df["AGB"].tolist()) + agb = torch.Tensor(tile_df["AGB"].tolist()) return boxes, labels, agb @@ -192,6 +182,9 @@ def _verify(self) -> None: Returns: True if the dataset directories are found, else False + + Raises: + Runtime Error if dataset is not found in root or is corrupted """ filepaths = [os.path.join(self.root, dir) for dir in ["tiles", "mapping"]] if all([os.path.exists(filepath) for filepath in filepaths]): @@ -207,8 +200,8 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download " + f"Dataset not found in {self.root} directory, either specify a" + + " different `root` directory or manually download " + "the dataset to this directory." )