diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 9c3e701c9b7..43756df71a5 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -289,7 +289,8 @@ def _load_image(self, id_: int) -> Tensor: ) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array).float() + tensor = torch.from_numpy(array) + tensor = tensor.float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor