Commit 5a9b70f
Fix some issues, fix defaults
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice committed Jun 18, 2024
1 parent c9fe73a commit 5a9b70f
Showing 3 changed files with 72 additions and 40 deletions.
2 changes: 1 addition & 1 deletion project/configs/datamodule/imagenet.yaml
@@ -1,4 +1,4 @@
 defaults:
-- vision
+- vision
 _target_: project.datamodules.ImageNetDataModule
 # todo: add good configuration options here.
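This config relies on Hydra's `_target_` convention: the `vision` entry in the defaults list pulls in the shared vision datamodule config, and `_target_` names the class to instantiate. As a rough sketch of how such a config is typically consumed, assuming Hydra's standard `instantiate` API (the exact entry point and config names in this project may differ):

# Hypothetical usage sketch, not taken from this repo:
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="project/configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Builds project.datamodules.ImageNetDataModule from the `_target_` entry.
    datamodule = hydra.utils.instantiate(cfg.datamodule)
    datamodule.prepare_data()
    datamodule.setup("fit")

if __name__ == "__main__":
    main()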
64 changes: 37 additions & 27 deletions project/datamodules/image_classification/imagenet.py
@@ -22,14 +22,17 @@
 from torchvision.transforms import v2 as transform_lib
 
 from project.datamodules.vision.base import VisionDataModule
+from project.utils.env_vars import NUM_WORKERS, DATA_DIR
 from project.utils.types import C, H, StageStr, W
 from project.utils.types.protocols import Module
 
 logger = get_logger(__name__)
 
 
 def imagenet_normalization():
-    return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    return transform_lib.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    )
 
 
 type ClassIndex = int
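The mean/std constants above are the standard ImageNet channel statistics. A minimal sketch of how this helper is typically composed into a training pipeline with the torchvision v2 transforms API (illustrative only; not this module's actual default transforms):

import torch
from torchvision.transforms import v2 as transform_lib

train_transform = transform_lib.Compose(
    [
        transform_lib.RandomResizedCrop(224),
        transform_lib.ToImage(),
        transform_lib.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
        imagenet_normalization(),  # the helper defined above
    ]
)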
@@ -64,11 +67,10 @@ class ImageNetDataModule(VisionDataModule):

     def __init__(
         self,
-        data_dir: str | Path | None = None,
+        data_dir: str | Path = DATA_DIR,
         *,
-        val_split: int
-        | float = 0.01,  # save `val_split`% of the training data *of each class* for validation.
-        num_workers: int | None = None,
+        val_split: int | float = 0.01,
+        num_workers: int = NUM_WORKERS,
         normalize: bool = False,
         image_size: int = 224,
         batch_size: int = 32,
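`DATA_DIR` and `NUM_WORKERS` replace the previous `None` defaults, so a bare `ImageNetDataModule()` is usable without extra wiring. A plausible reconstruction of what `project.utils.env_vars` provides (an assumption for illustration; the real definitions live in `project/utils/env_vars.py`):

# Hypothetical sketch of project/utils/env_vars.py:
import os
from pathlib import Path

# Fall back to a local ./data directory when no environment variable is set.
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# On a SLURM cluster, the job's CPU allocation is a better default than os.cpu_count().
NUM_WORKERS = int(os.environ.get("SLURM_CPUS_PER_TASK", os.cpu_count() or 1))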
@@ -138,7 +140,9 @@ def setup(self, stage: StageStr | None = None) -> None:
logger.debug(f"Setup ImageNet datamodule for {stage=}")
super().setup(stage)

def _split_dataset(self, dataset: ImageNet, train: bool = True) -> torch.utils.data.Dataset:
def _split_dataset(
self, dataset: ImageNet, train: bool = True
) -> torch.utils.data.Dataset:
class_item_indices: dict[ClassIndex, list[ImageIndex]] = defaultdict(list)
for dataset_index, y in enumerate(dataset.targets):
class_item_indices[y].append(dataset_index)
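Grouping indices by class enables a stratified split: `val_split` is applied within each class, so the validation set keeps the same class balance as the training set. The rest of the method is truncated in this diff; a hedged sketch of how it plausibly proceeds:

# Hypothetical continuation, assuming a seeded per-class shuffle:
import math
import torch
from torch.utils.data import Subset

def stratified_indices(class_item_indices, val_split: float, seed: int, train: bool) -> list[int]:
    generator = torch.Generator().manual_seed(seed)
    train_idx: list[int] = []
    val_idx: list[int] = []
    for indices in class_item_indices.values():
        perm = torch.randperm(len(indices), generator=generator).tolist()
        n_val = math.ceil(val_split * len(indices))  # val_split=0.01 -> 1% of each class
        val_idx += [indices[i] for i in perm[:n_val]]
        train_idx += [indices[i] for i in perm[n_val:]]
    return train_idx if train else val_idx

# e.g.: Subset(dataset, stratified_indices(class_item_indices, 0.01, seed=42, train=True))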
@@ -261,6 +265,8 @@ def prepare_imagenet(
     train_archive_file_name = "ILSVRC2012_img_train.tar"
     devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz"
     md5sums_file_name = "md5sums"
+    if not root.exists():
+        root.mkdir(parents=True)
 
     def _symlink_if_needed(filename: str, network_imagenet_dir: Path):
         if not (symlink := root / filename).exists():
@@ -283,11 +289,14 @@ def _symlink_if_needed(filename: str, network_imagenet_dir: Path):
     train_dir = root / "train"
     train_dir.mkdir(exist_ok=True, parents=True)
     train_archive = network_imagenet_dir / train_archive_file_name
+    previously_extracted_dirs_file = train_dir / ".previously_extracted_dirs.txt"
     _extract_train_archive(
         train_archive=train_archive,
         train_dir=train_dir,
-        previously_extracted_dirs_file=root / "previously_extracted_dirs.txt",
+        previously_extracted_dirs_file=previously_extracted_dirs_file,
     )
+    if previously_extracted_dirs_file.exists():
+        previously_extracted_dirs_file.unlink()
 
     # OR: could just reuse the equivalent-ish from torchvision, but which doesn't support
     # resuming after an interrupt.
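The bookkeeping file is a simple resume protocol: each fully extracted class directory is recorded during extraction, and the file is removed once the archive is fully processed, so later runs fall through to the `len(list(train_dir.iterdir())) == 1000` check instead. The torchvision equivalent alluded to above is presumably something like the following (a hedged sketch; these helpers re-extract from scratch if interrupted):

# Hypothetical alternative using torchvision's ImageNet helpers:
from torchvision.datasets.imagenet import parse_devkit_archive, parse_train_archive

parse_devkit_archive(str(root), file=devkit_file_name)
parse_train_archive(str(root), file=train_archive_file_name)  # no resume support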
@@ -336,37 +345,42 @@ def _extract_train_archive(
     *, train_archive: Path, train_dir: Path, previously_extracted_dirs_file: Path
 ) -> None:
     # The ImageNet train archive is a tarfile of tarfiles (one for each class).
-    logger.debug("Extracting the ImageNet train archive using Olexa's tar magic in python form...")
+    logger.debug(
+        "Extracting the ImageNet train archive using Olexa's tar magic in python form..."
+    )
     train_dir.mkdir(exist_ok=True, parents=True)
 
     # Save a small text file or something that tells us which subdirs are
     # done extracting so we can just skip ahead to the right directory?
     previously_extracted_dirs: set[str] = set()
 
     if previously_extracted_dirs_file.exists():
         previously_extracted_dirs = set(
             stripped_line
             for line in previously_extracted_dirs_file.read_text().splitlines()
             if (stripped_line := line.strip())
         )
+        if len(previously_extracted_dirs) == 1000:
+            logger.info("Train archive already fully extracted. Skipping.")
+            return
+        logger.debug(
+            f"{len(previously_extracted_dirs)} directories have already been fully extracted."
+        )
+        previously_extracted_dirs_file.write_text(
+            "\n".join(sorted(previously_extracted_dirs)) + "\n"
+        )
 
-    if len(previously_extracted_dirs) == 1000:
+    elif len(list(train_dir.iterdir())) == 1000:
         logger.info("Train archive already fully extracted. Skipping.")
         return
 
     with tarfile.open(train_archive, mode="r") as train_tarfile:
-        for class_id, member in enumerate(
-            tqdm.tqdm(
-                train_tarfile,
-                total=1000,  # hard-coded here, since we know there are 1000 folders.
-                desc="Extracting train archive",
-                unit="Directories",
-                position=0,
-            )
+        for member in tqdm.tqdm(
+            train_tarfile,
+            total=1000,  # hard-coded here, since we know there are 1000 folders.
+            desc="Extracting train archive",
+            unit="Directories",
+            position=0,
         ):
             if member.name in previously_extracted_dirs:
                 continue
@@ -377,18 +391,14 @@ def _extract_train_archive(
             class_subdir = train_dir / member.name.replace(".tar", "")
             class_subdir_existed = class_subdir.exists()
             if class_subdir_existed:
-                files_in_subdir = set(p.name for p in class_subdir.iterdir())
-            else:
-                class_subdir.mkdir(parents=True, exist_ok=True)
-                files_in_subdir = set()
+                # Remove all the (potentially partially constructed) files in the directory.
+                logger.debug(f"Removing partially-constructed dir {class_subdir}")
+                shutil.rmtree(class_subdir, ignore_errors=False)
 
-            with tarfile.open(fileobj=buffer, mode="r|*") as sub_tarfile:
-                for tarinfo in sub_tarfile:
-                    image_file_path = class_subdir / tarinfo.name
-                    if files_in_subdir and image_file_path.name in files_in_subdir:
-                        # Image file is already in the directory.
-                        continue
-                    sub_tarfile.extract(tarinfo, class_subdir)
+            with tarfile.open(fileobj=buffer, mode="r|*") as class_tarfile:
+                class_tarfile.extractall(class_subdir, filter="data")
 
             # Alternative: .extractall with a list of members to extract:
             # members = sub_tarfile.getmembers()  # note: loads the full archive.
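Taken together, this pattern streams each per-class inner tarfile straight out of the outer archive, with `buffer` presumably obtained from `train_tarfile.extractfile(member)` in the truncated lines above. A condensed, self-contained sketch of the technique (illustrative; error handling and the resume bookkeeping are omitted):

# Hypothetical condensed version of the nested-archive extraction:
import tarfile
from pathlib import Path

def extract_nested(train_archive: Path, train_dir: Path) -> None:
    with tarfile.open(train_archive, mode="r") as train_tarfile:
        for member in train_tarfile:  # one inner tar per class, e.g. n01440764.tar
            buffer = train_tarfile.extractfile(member)
            assert buffer is not None
            class_subdir = train_dir / member.name.replace(".tar", "")
            # mode="r|*" treats the inner tar as a non-seekable stream, so nothing
            # is written to disk except the final image files.
            with tarfile.open(fileobj=buffer, mode="r|*") as class_tarfile:
                class_tarfile.extractall(class_subdir, filter="data")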
46 changes: 34 additions & 12 deletions project/datamodules/vision/base.py
@@ -40,7 +40,7 @@ def __init__(
         self,
         data_dir: str | Path = DATA_DIR,
         val_split: int | float = 0.2,
-        num_workers: int | None = NUM_WORKERS,
+        num_workers: int = NUM_WORKERS,
         normalize: bool = False,
         batch_size: int = 32,
         seed: int = 42,
@@ -100,9 +100,15 @@ def __init__(

         # todo: what about the shuffling at each epoch?
         _rng = torch.Generator(device="cpu").manual_seed(self.seed)
-        self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
-        self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
-        self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
+        self.train_dl_rng_seed = int(
+            torch.randint(0, int(1e6), (1,), generator=_rng).item()
+        )
+        self.val_dl_rng_seed = int(
+            torch.randint(0, int(1e6), (1,), generator=_rng).item()
+        )
+        self.test_dl_rng_seed = int(
+            torch.randint(0, int(1e6), (1,), generator=_rng).item()
+        )
 
         self.test_dataset_cls = self.dataset_cls
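Deriving three distinct dataloader seeds from the single `self.seed` keeps runs reproducible while giving the train/val/test loaders independent shuffling streams. A small illustration of the effect (assuming PyTorch's documented `DataLoader(generator=...)` behaviour):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
dl_a = DataLoader(dataset, batch_size=2, shuffle=True, generator=torch.Generator().manual_seed(123))
dl_b = DataLoader(dataset, batch_size=2, shuffle=True, generator=torch.Generator().manual_seed(123))
# Same seed -> same shuffling order across independent loaders (and across runs).
assert [batch[0].tolist() for batch in dl_a] == [batch[0].tolist() for batch in dl_b]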

@@ -150,7 +156,9 @@ def setup(self, stage: StageStr | None = None) -> None:

if stage == "test" or stage is None:
logger.debug(f"creating test dataset with kwargs {self.train_kwargs}")
self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs)
self.dataset_test = self.test_dataset_cls(
str(self.data_dir), **self.test_kwargs
)

def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset:
"""Splits the dataset into train and validation set."""
@@ -182,7 +190,9 @@ def _get_splits(self, len_dataset: int) -> list[int]:
     def default_transforms(self) -> Callable:
         """Default transform for the dataset."""
 
-    def train_dataloader[**P](
+    def train_dataloader[
+        **P
+    ](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -202,7 +212,9 @@ def train_dataloader[**P](
             ),
         )
 
-    def val_dataloader[**P](
+    def val_dataloader[
+        **P
+    ](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
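The `[**P]` in these signatures is PEP 695 (Python 3.12+) syntax for a method generic over a `ParamSpec`: whatever extra arguments the caller passes are type-checked against `_dataloader_fn`'s remaining parameters (the `Dataset` argument is already bound via `Concatenate`) and forwarded along. A minimal standalone illustration of the pattern (not project code):

from collections.abc import Callable
from typing import Concatenate

def call_with_zero[**P](fn: Callable[Concatenate[int, P], str], *args: P.args, **kwargs: P.kwargs) -> str:
    # `fn`'s first (int) parameter is bound here; the rest are forwarded as-is.
    return fn(0, *args, **kwargs)

def describe(x: int, name: str, *, upper: bool = False) -> str:
    return f"{name.upper() if upper else name}={x}"

print(call_with_zero(describe, "count", upper=True))  # COUNT=0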
@@ -214,10 +226,15 @@ def val_dataloader[**P](
             self.dataset_val,
             _dataloader_fn=_dataloader_fn,
             *args,
-            **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs),
+            **(
+                dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed))
+                | kwargs
+            ),
         )

-    def test_dataloader[**P](
+    def test_dataloader[
+        **P
+    ](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -231,10 +248,15 @@ def test_dataloader[**P](
             self.dataset_test,
             _dataloader_fn=_dataloader_fn,
             *args,
-            **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs),
+            **(
+                dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed))
+                | kwargs
+            ),
         )

-    def _data_loader[**P](
+    def _data_loader[
+        **P
+    ](
         self,
         dataset: Dataset,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
@@ -247,7 +269,7 @@ def _data_loader[**P](
                 num_workers=self.num_workers,
                 drop_last=self.drop_last,
                 pin_memory=self.pin_memory,
-                persistent_workers=True if self.num_workers > 0 else False,
+                persistent_workers=(self.num_workers or 0) > 0,
             )
             | dataloader_kwargs
         )
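Two details here: `(self.num_workers or 0) > 0` treats a `None` worker count as 0 rather than raising a `TypeError` on the comparison, and the `dict(...) | dataloader_kwargs` union lets caller-supplied kwargs override the defaults, since the right-hand operand of `|` wins on key collisions. A tiny illustration of the merge semantics (plain Python, not project code):

defaults = dict(batch_size=32, shuffle=True, persistent_workers=False)
overrides = dict(shuffle=False)
print(defaults | overrides)
# {'batch_size': 32, 'shuffle': False, 'persistent_workers': False}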
