Front page example for VHR10 dataset does not work #1686
Comments
I'm not sure if you want it, but after staring at this for a while to figure out what I'm looking at (I've never used any of this stuff before), this is what I came up with:

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchgeo.datasets import VHR10

    # Define the resize transform
    resize_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((1024, 1024)),
        transforms.ToTensor()
    ])

    # Custom collate function
    def custom_collate(batch):
        images = [item["image"] for item in batch]
        labels = [item["labels"] for item in batch]
        resized_images = [resize_transform(img) for img in images]
        resized_images = torch.stack(resized_images)
        # Since labels can have different lengths, we keep them as a list instead of stacking
        return {"image": resized_images, "labels": labels}

    # Initialize the dataset
    dataset = VHR10(root="./raw_data", download=True, checksum=True)

    # Initialize the dataloader with the custom collate function
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, collate_fn=custom_collate)

    # Training loop
    for batch in dataloader:
        image = batch["image"]
        labels = batch["labels"]

I can PR it with an explanation for newbies if you want, but like you said, I'm not sure if it's what you want.
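For context, here is a minimal sketch of why a custom collate function comes up at all, assuming (as the resize step above implies) that VHR10 images differ in height and width. The tensors are random stand-ins rather than real VHR10 samples, and `resize_and_stack` is a hypothetical helper name. It uses `torch.nn.functional.interpolate` in place of the torchvision transforms above.

```python
import torch
import torch.nn.functional as F

# Stand-in "images" with different spatial sizes (not real VHR10 data).
images = [torch.rand(3, 80, 100), torch.rand(3, 60, 90)]

# Stacking tensors of different shapes raises an error, which is what the
# default collate function effectively attempts when building a batch.
try:
    torch.stack(images)
    stack_failed = False
except RuntimeError:
    stack_failed = True


def resize_and_stack(imgs, size=(64, 64)):
    """Resize each (C, H, W) tensor to `size`, then stack into (B, C, H, W)."""
    resized = [
        F.interpolate(
            img.unsqueeze(0), size=size, mode="bilinear", align_corners=False
        ).squeeze(0)
        for img in imgs
    ]
    return torch.stack(resized)


batch = resize_and_stack(images)
print(stack_failed, batch.shape)
```

Note that this resizes only the image. Any boxes or masks in the same sample would need the same scaling to stay aligned.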
Our datasets aren't really compatible with torchvision transforms; you'll have much better luck with kornia transforms. Something like:

    from kornia.augmentation import Resize
    from torchgeo.transforms import AugmentationSequential

    transforms = AugmentationSequential(
        Resize(..., ...), data_keys=["image"]
    )

See https://torchgeo.readthedocs.io/en/stable/tutorials/transforms.html for more examples.

I tried this and I don't think it's compatible with our current design of VHR10. This should be reworked in #1082. In the meantime, it's probably easier to give an example using a different dataset where images don't require resizing.
@grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.
Apologies for my delayed response. I've already done all my other modeling with the VHR10 dataset, so that's what I'll probably stick with. If I get the chance, though, I'll write something up for EuroSAT.
@adamjstewart Can you explain why it's complicated? I'm facing a similar issue with the Chesapeake dataset. I read the corresponding datamodule code, but it's unclear why resizing the image before applying a crop is the ideal solution, especially for applications where the pixel resolution matters. Furthermore, each tile in the dataset is quite large, so I'm also not sure why this resizing is even necessary.
VHR-10 is complicated because it has images, masks, and bounding boxes. Chesapeake only has images and masks, so it's much easier. I think the problem with Chesapeake is slightly different since it's a GeoDataset. In theory, resize/crop shouldn't be needed, but it's needed right now because it's not using the RasterDataset base class. If you open a separate issue, maybe @calebrob6 can take a look at fixing this.
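To illustrate why bounding boxes add work: when an image is resized, its boxes have to be rescaled by the same factors, or they no longer line up with the objects. A minimal sketch, assuming boxes are stored as `(xmin, ymin, xmax, ymax)` pixel coordinates. `scale_boxes` is a hypothetical helper for illustration, not a torchgeo or kornia function.

```python
def scale_boxes(boxes, old_size, new_size):
    """Scale (xmin, ymin, xmax, ymax) boxes from old_size to new_size.

    Sizes are (height, width). Hypothetical helper, not a torchgeo function.
    """
    old_h, old_w = old_size
    new_h, new_w = new_size
    sx = new_w / old_w  # horizontal scale factor
    sy = new_h / old_h  # vertical scale factor
    return [(x0 * sx, y0 * sy, x1 * sx, y1 * sy) for x0, y0, x1, y1 in boxes]


# A 400x200 (height x width) image resized to 800x800:
# x coordinates scale by 4, y coordinates by 2.
boxes = [(10, 20, 110, 220)]
scaled = scale_boxes(boxes, old_size=(400, 200), new_size=(800, 800))
print(scaled)  # [(40.0, 40.0, 440.0, 440.0)]
```

Masks need the equivalent treatment (resizing with the same target size), which is why a transform that only touches the `"image"` key is not enough for a dataset like this.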
The following example will work, though a simpler dataset might be better suited for the README.

    import kornia.augmentation as K
    import torch
    from torch.utils.data import DataLoader
    from torchgeo.datamodules.utils import AugPipe, collate_fn_detection
    from torchgeo.datasets import VHR10
    from torchgeo.transforms import AugmentationSequential

    batch_size = 2

    # Initialize the dataset
    dataset = VHR10(root="./raw_data/", download=True, checksum=True)

    # Initialize the dataloader with the custom collate function
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn_detection,
    )

    # Initialize augs to normalize and resize images to size (512, 512)
    aug = AugPipe(
        augs=AugmentationSequential(
            K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
            K.Resize((512, 512)),
            data_keys=["image", "boxes", "masks"],
        ),
        batch_size=batch_size,
    )

    # Training loop
    for batch in dataloader:
        batch = aug(batch)
        images = batch["image"]  # List of images
        boxes = batch["boxes"]  # List of boxes
        labels = batch["labels"]  # List of labels
        masks = batch["masks"]  # List of masks
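The comments in the loop above describe each batch entry as a list, which suggests the collate function gathers samples instead of stacking them. A minimal sketch of that idea in plain Python follows. `collate_as_lists` is a hypothetical name, the strings stand in for image tensors, and this is not a copy of torchgeo's `collate_fn_detection`.

```python
def collate_as_lists(batch):
    """Group a list of sample dicts into one dict of lists, key by key."""
    keys = batch[0].keys()
    return {key: [sample[key] for sample in batch] for key in keys}


# Two stand-in samples with a different number of labels each.
samples = [
    {"image": "img_a", "labels": [1, 2]},
    {"image": "img_b", "labels": [3]},
]
collated = collate_as_lists(samples)
print(collated)  # {'image': ['img_a', 'img_b'], 'labels': [[1, 2], [3]]}
```

Keeping lists avoids any requirement that samples share a shape, at the cost of pushing resizing and stacking into a later step such as the augmentation pipeline.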
I do really like the VHR-10 pic we use in the README though... Want to submit a PR to use that code to fix the README example? I would also accept a PR that uses a different dataset like EuroSAT instead. We're trying to release 0.5.2 tomorrow or Saturday, so it kinda needs to happen fast if we want to get this fixed before the next release.
Ok, let's go with VHR-10 (#1920) for now. We can always switch later.
Description
I started by copying and pasting the example as-is from the front page:
This produces:
Steps to reproduce
Version
0.5.0