From af5d2159a8a5fcf22491b12be593517eaf9726cc Mon Sep 17 00:00:00 2001 From: xmba15 Date: Tue, 9 Jul 2024 22:49:12 +0900 Subject: [PATCH] fix missing modality --- README.md | 1 + config/base.yaml | 25 +- config/base_full_modality_1.yaml | 17 +- config/base_full_modality_2.yaml | 19 +- config/base_full_modality_3.yaml | 17 +- config/base_full_modality_4.yaml | 49 +++ .../base_full_modality_input_dropout_1.yaml | 49 +++ .../base_full_modality_input_dropout_2.yaml | 49 +++ .../base_full_modality_input_dropout_3.yaml | 49 +++ .../base_full_modality_input_dropout_4.yaml | 49 +++ config/base_missing_modality_1.yaml | 48 +++ config/base_missing_modality_2.yaml | 49 +++ config/result_config.yaml | 29 ++ scripts/submit_result.py | 167 ++++++++++ scripts/test_inference_val.py | 167 ++++++++++ scripts/train.py | 94 ++++-- scripts/train_full_modality.py | 46 ++- scripts/train_full_modality_input_dropout.py | 304 ++++++++++++++++++ scripts/train_shared_street_ortho.py | 293 +++++++++++++++++ src/integrated/model.py | 207 +++++++++++- src/models/loss.py | 28 ++ src/models/model.py | 127 ++++++++ src/utils/utils.py | 9 +- 23 files changed, 1807 insertions(+), 85 deletions(-) create mode 100644 config/base_full_modality_4.yaml create mode 100644 config/base_full_modality_input_dropout_1.yaml create mode 100644 config/base_full_modality_input_dropout_2.yaml create mode 100644 config/base_full_modality_input_dropout_3.yaml create mode 100644 config/base_full_modality_input_dropout_4.yaml create mode 100644 config/base_missing_modality_1.yaml create mode 100644 config/base_missing_modality_2.yaml create mode 100644 config/result_config.yaml create mode 100644 scripts/submit_result.py create mode 100644 scripts/test_inference_val.py create mode 100644 scripts/train_full_modality_input_dropout.py create mode 100644 scripts/train_shared_street_ortho.py diff --git a/README.md b/README.md index 71ec321..4e90b5c 100644 --- a/README.md +++ b/README.md @@ -16,3 +16,4 @@ mamba activate ai4eo --- - [Model Fusion for Building Type Classification from Aerial and Street View Images](https://www.mdpi.com/2072-4292/11/11/1259#) +- [Multi-modal fusion of satellite and street-view images for urban village classification based on a dual-branch deep neural network](https://www.sciencedirect.com/science/article/pii/S0303243422001209) diff --git a/config/base.yaml b/config/base.yaml index 789b5b2..48ea3c4 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -1,31 +1,36 @@ --- -seed: 1984 +seed: 2024 num_workers: 4 -experiment_name: "2024-04-07" +experiment_name: "2024-07-09-missing-modality" dataset: - val_split: 0.1 + n_splits: 10 + fold_th: 3 train_dir: ~/publicWorkspace/data/building-age-dataset/train/data test_dir: ~/publicWorkspace/data/building-age-dataset/test/data train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv model: - encoder_name: efficientnet_b2 + encoder_name: tf_efficientnetv2_b3 num_classes: 7 optimizer: type: timm.optim.AdamP - lr: 0.0005 + lr: 0.00025 + weight_decay: 0.00001 scheduler: - type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts - T_0: 10 - T_mult: 2 + type: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.5 + patience: 5 + threshold: 0.00005 + verbose: True trainer: - devices: 1 + devices: [0] accelerator: "cuda" max_epochs: 50 gradient_clip_val: 5.0 @@ -36,7 +41,7 @@ train_parameters: batch_size: 3 val_parameters: - batch_size: 3 + batch_size: 4 output_root_dir: experiments 
image_size: 512 diff --git a/config/base_full_modality_1.yaml b/config/base_full_modality_1.yaml index 23928aa..eb15c57 100644 --- a/config/base_full_modality_1.yaml +++ b/config/base_full_modality_1.yaml @@ -17,23 +17,24 @@ model: encoder_name: tf_efficientnetv2_s num_classes: 7 +loss: + classification: + type: src.models.FocalLossLabelSmoothing + optimizer: type: timm.optim.AdamW - lr: 0.0005 + lr: 0.0002 weight_decay: 0.001 scheduler: - type: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.5 - patience: 10 - threshold: 0.00005 - verbose: True + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 trainer: devices: [0] accelerator: "cuda" - max_epochs: 50 + max_epochs: 30 gradient_clip_val: 5.0 accumulate_grad_batches: 8 resume_from_checkpoint: diff --git a/config/base_full_modality_2.yaml b/config/base_full_modality_2.yaml index add709c..bd9f32a 100644 --- a/config/base_full_modality_2.yaml +++ b/config/base_full_modality_2.yaml @@ -14,26 +14,27 @@ dataset: model: type: src.models.MultiModalNetFullModalityFeatureFusion - encoder_name: nextvit_base + encoder_name: mobilevitv2_150 num_classes: 7 +loss: + classification: + type: src.models.FocalLossLabelSmoothing + optimizer: type: timm.optim.AdamW - lr: 0.0005 + lr: 0.0002 weight_decay: 0.001 scheduler: - type: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.5 - patience: 10 - threshold: 0.00005 - verbose: True + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 trainer: devices: [1] accelerator: "cuda" - max_epochs: 50 + max_epochs: 30 gradient_clip_val: 5.0 accumulate_grad_batches: 8 resume_from_checkpoint: diff --git a/config/base_full_modality_3.yaml b/config/base_full_modality_3.yaml index 3673e71..2c3f715 100644 --- a/config/base_full_modality_3.yaml +++ b/config/base_full_modality_3.yaml @@ -17,23 +17,24 @@ model: encoder_name: tf_efficientnetv2_b3 num_classes: 7 +loss: + classification: + type: src.models.FocalLossLabelSmoothing + optimizer: type: timm.optim.AdamW - lr: 0.0005 + lr: 0.0002 weight_decay: 0.001 scheduler: - type: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.5 - patience: 10 - threshold: 0.00005 - verbose: True + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 trainer: devices: [0] accelerator: "cuda" - max_epochs: 50 + max_epochs: 30 gradient_clip_val: 5.0 accumulate_grad_batches: 8 resume_from_checkpoint: diff --git a/config/base_full_modality_4.yaml b/config/base_full_modality_4.yaml new file mode 100644 index 0000000..2c3f715 --- /dev/null +++ b/config/base_full_modality_4.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-08-f2" + +dataset: + n_splits: 10 + fold_th: 2 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + type: src.models.MultiModalNetFullModalityGeometricFusion + encoder_name: tf_efficientnetv2_b3 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 8 + 
resume_from_checkpoint: + +train_parameters: + batch_size: 6 + +val_parameters: + batch_size: 6 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_full_modality_input_dropout_1.yaml b/config/base_full_modality_input_dropout_1.yaml new file mode 100644 index 0000000..69a9e14 --- /dev/null +++ b/config/base_full_modality_input_dropout_1.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-08-input-dropout-f5" + +dataset: + n_splits: 10 + fold_th: 5 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + type: src.models.MultiModalNetFullModalityGeometricFusion + encoder_name: tf_efficientnetv2_b3 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 8 + resume_from_checkpoint: + +train_parameters: + batch_size: 6 + +val_parameters: + batch_size: 6 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_full_modality_input_dropout_2.yaml b/config/base_full_modality_input_dropout_2.yaml new file mode 100644 index 0000000..f29c628 --- /dev/null +++ b/config/base_full_modality_input_dropout_2.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-08-input-dropout-f6" + +dataset: + n_splits: 10 + fold_th: 6 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + type: src.models.MultiModalNetFullModalityFeatureFusion + encoder_name: fastvit_sa24 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 8 + resume_from_checkpoint: + +train_parameters: + batch_size: 4 + +val_parameters: + batch_size: 4 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_full_modality_input_dropout_3.yaml b/config/base_full_modality_input_dropout_3.yaml new file mode 100644 index 0000000..c536bd2 --- /dev/null +++ b/config/base_full_modality_input_dropout_3.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-08-input-dropout-f6" + +dataset: + n_splits: 10 + fold_th: 6 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + type: src.models.MultiModalNetFullModalityFeatureFusion + encoder_name: mobilevitv2_150 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: 
timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 10 + resume_from_checkpoint: + +train_parameters: + batch_size: 4 + +val_parameters: + batch_size: 4 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_full_modality_input_dropout_4.yaml b/config/base_full_modality_input_dropout_4.yaml new file mode 100644 index 0000000..c536bd2 --- /dev/null +++ b/config/base_full_modality_input_dropout_4.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-08-input-dropout-f6" + +dataset: + n_splits: 10 + fold_th: 6 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + type: src.models.MultiModalNetFullModalityFeatureFusion + encoder_name: mobilevitv2_150 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 10 + resume_from_checkpoint: + +train_parameters: + batch_size: 4 + +val_parameters: + batch_size: 4 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_missing_modality_1.yaml b/config/base_missing_modality_1.yaml new file mode 100644 index 0000000..5aab30d --- /dev/null +++ b/config/base_missing_modality_1.yaml @@ -0,0 +1,48 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-09-missing-modality-f3" + +dataset: + n_splits: 10 + fold_th: 3 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + encoder_name: tf_efficientnetv2_b1 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 50 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 50 + gradient_clip_val: 5.0 + accumulate_grad_batches: 8 + resume_from_checkpoint: + +train_parameters: + batch_size: 5 + +val_parameters: + batch_size: 5 + +output_root_dir: experiments +image_size: 512 diff --git a/config/base_missing_modality_2.yaml b/config/base_missing_modality_2.yaml new file mode 100644 index 0000000..37ec47b --- /dev/null +++ b/config/base_missing_modality_2.yaml @@ -0,0 +1,49 @@ +--- +seed: 2024 + +num_workers: 4 +experiment_name: "2024-07-09-missing-modality-shared-f4" + +dataset: + n_splits: 10 + fold_th: 4 + train_dir: ~/publicWorkspace/data/building-age-dataset/train/data + test_dir: ~/publicWorkspace/data/building-age-dataset/test/data + train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv + test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv + +model: + encoder_name: tf_efficientnetv2_b1 + 
s2_encoder_name: tf_efficientnetv2_b0 + num_classes: 7 + +loss: + classification: + type: src.models.FocalLossLabelSmoothing + +optimizer: + type: timm.optim.AdamW + lr: 0.0002 + weight_decay: 0.001 + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 30 + eta_min: 0.00005 + +trainer: + devices: [0] + accelerator: "cuda" + max_epochs: 30 + gradient_clip_val: 5.0 + accumulate_grad_batches: 8 + resume_from_checkpoint: + +train_parameters: + batch_size: 5 + +val_parameters: + batch_size: 5 + +output_root_dir: experiments +image_size: 512 diff --git a/config/result_config.yaml b/config/result_config.yaml new file mode 100644 index 0000000..b97c5f4 --- /dev/null +++ b/config/result_config.yaml @@ -0,0 +1,29 @@ +# tf_efficientnetv2_s +base_full_modality_1: + config_path: ./config/base_full_modality_1.yaml + weights_path: ./ + +# mobilevitv2_150 +base_full_modality_2: + config_path: ./config/base_full_modality_2.yaml + weights_path: ./ + +# tf_efficientnetv2_b3 +base_full_modality_3: + config_path: ./config/base_full_modality_3.yaml + weights_path: ./ + +# tf_efficientnetv2_b3 +base_full_modality_input_dropout_1: + config_path: ./config/base_full_modality_input_dropout_1.yaml + weights_path: ./ + +# fastvit_sa24 +base_full_modality_input_dropout_2: + config_path: ./config/base_full_modality_input_dropout_2.yaml + weights_path: ./ + +# mobilevitv2_150 +base_full_modality_input_dropout_3: + config_path: ./config/base_full_modality_input_dropout_3.yaml + weights_path: ./ diff --git a/scripts/submit_result.py b/scripts/submit_result.py new file mode 100644 index 0000000..1bdb390 --- /dev/null +++ b/scripts/submit_result.py @@ -0,0 +1,167 @@ +import argparse +import os +import sys + +import albumentations as alb +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import tqdm +import yaml +from albumentations.pytorch import ToTensorV2 +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Subset +from torchmetrics import Accuracy + +_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(_CURRENT_DIR, "../")) +from src.data import CountryCode, CustomSubset, MapYourCityDataset, S2RandomRotation +from src.integrated import MultiModalNetFullModalityPl +from src.models import MultiModalNet +from src.utils import fix_seed, worker_init_fn + + +def get_args(): + parser = argparse.ArgumentParser("test inference") + parser.add_argument("--config_path", type=str, default="./config/base_missing_modality_1.yaml") + parser.add_argument("--checkpoint_path", type=str, required=True) + + return parser.parse_args() + + +def get_transforms(hparams): + image_size = hparams["image_size"] + + all_transforms = {} + all_transforms["street"] = { + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + all_transforms["ortho"] = { + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + def clip_s2(image, **params): + return np.clip(image, 0, 10000) + + all_transforms["s2"] = { + "val": alb.Compose( + [ + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + ToTensorV2(), + ] + ), + } + + return all_transforms + + +def setup_train_val_split( + original_dataset, + 
hparams, +): + kf = StratifiedKFold( + n_splits=hparams["dataset"]["n_splits"], + shuffle=True, + random_state=hparams["seed"], + ) + + train_indices, val_indices = list( + kf.split( + range(len(original_dataset)), + original_dataset.labels, + ) + )[hparams["dataset"]["fold_th"]] + + return train_indices, val_indices + + +def main(): + args = get_args() + assert os.path.isfile(args.checkpoint_path) + with open(args.config_path, encoding="utf-8") as f: + hparams = yaml.load(f, Loader=yaml.SafeLoader) + os.makedirs(hparams["output_root_dir"], exist_ok=True) + fix_seed(hparams["seed"]) + pl.seed_everything(hparams["seed"]) + + dataset = MapYourCityDataset( + csv_path=hparams["dataset"]["train_csv"], + data_dir=hparams["dataset"]["train_dir"], + train=True, + ) + + _, val_indices = setup_train_val_split(dataset, hparams) + + transforms_dict = get_transforms(hparams) + val_dataset = CustomSubset( + Subset(dataset, val_indices), + transforms_dict={ + "street": transforms_dict["street"]["val"], + "ortho": transforms_dict["ortho"]["val"], + "s2": transforms_dict["s2"]["val"], + }, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=hparams["val_parameters"]["batch_size"], + num_workers=hparams["num_workers"], + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + accuracy = Accuracy(task="multiclass", num_classes=hparams["model"]["num_classes"]) + accuracy_missing_modality = Accuracy(task="multiclass", num_classes=hparams["model"]["num_classes"]) + model = MultiModalNetFullModalityPl.load_from_checkpoint( + args.checkpoint_path, + hparams=hparams, + map_location=device, + ) + model.eval() + + for batch in tqdm.tqdm(val_loader): + images, s2_data, country_id, label = batch + images = images.to(device) + s2_data = s2_data.to(device) + country_id = country_id.to(device) + + with torch.no_grad(): + logits = model(images, s2_data, country_id) + logits = logits.cpu() + preds = torch.argmax(logits.cpu(), dim=1) + + images[:, 3:, :, :] = 0 + with torch.no_grad(): + logits_missing_modality = model(images, s2_data, country_id) + + logits_missing_modality = logits_missing_modality.cpu() + preds_missing_modality = torch.argmax(logits_missing_modality.cpu(), dim=1) + + accuracy.update(preds, label) + accuracy_missing_modality.update(preds_missing_modality, batch[3]) + + final_accuracy = accuracy.compute() + final_accuracy_missing_modality = accuracy_missing_modality.compute() + print(f"final accuracy full modality: {final_accuracy}") + print(f"final accuracy missing modality: {final_accuracy_missing_modality}") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_inference_val.py b/scripts/test_inference_val.py new file mode 100644 index 0000000..d001b0b --- /dev/null +++ b/scripts/test_inference_val.py @@ -0,0 +1,167 @@ +import argparse +import os +import sys + +import albumentations as alb +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import tqdm +import yaml +from albumentations.pytorch import ToTensorV2 +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Subset +from torchmetrics import Accuracy + +_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(_CURRENT_DIR, "../")) +from src.data import CountryCode, CustomSubset, MapYourCityDataset, S2RandomRotation +from 
src.integrated import MultiModalNetFullModalityPl, MultiModalNetPl +from src.models import MultiModalNet +from src.utils import fix_seed, worker_init_fn + + +def get_args(): + parser = argparse.ArgumentParser("test inference") + parser.add_argument("--config_path", type=str, default="./config/base_missing_modality_1.yaml") + parser.add_argument("--checkpoint_path", type=str, required=True) + + return parser.parse_args() + + +def get_transforms(hparams): + image_size = hparams["image_size"] + + all_transforms = {} + all_transforms["street"] = { + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + all_transforms["ortho"] = { + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + def clip_s2(image, **params): + return np.clip(image, 0, 10000) + + all_transforms["s2"] = { + "val": alb.Compose( + [ + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + ToTensorV2(), + ] + ), + } + + return all_transforms + + +def setup_train_val_split( + original_dataset, + hparams, +): + kf = StratifiedKFold( + n_splits=hparams["dataset"]["n_splits"], + shuffle=True, + random_state=hparams["seed"], + ) + + train_indices, val_indices = list( + kf.split( + range(len(original_dataset)), + original_dataset.labels, + ) + )[hparams["dataset"]["fold_th"]] + + return train_indices, val_indices + + +def main(): + args = get_args() + assert os.path.isfile(args.checkpoint_path) + with open(args.config_path, encoding="utf-8") as f: + hparams = yaml.load(f, Loader=yaml.SafeLoader) + os.makedirs(hparams["output_root_dir"], exist_ok=True) + fix_seed(hparams["seed"]) + pl.seed_everything(hparams["seed"]) + + dataset = MapYourCityDataset( + csv_path=hparams["dataset"]["train_csv"], + data_dir=hparams["dataset"]["train_dir"], + train=True, + ) + + _, val_indices = setup_train_val_split(dataset, hparams) + + transforms_dict = get_transforms(hparams) + val_dataset = CustomSubset( + Subset(dataset, val_indices), + transforms_dict={ + "street": transforms_dict["street"]["val"], + "ortho": transforms_dict["ortho"]["val"], + "s2": transforms_dict["s2"]["val"], + }, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=hparams["val_parameters"]["batch_size"], + num_workers=hparams["num_workers"], + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + accuracy = Accuracy(task="multiclass", num_classes=hparams["model"]["num_classes"]) + accuracy_missing_modality = Accuracy(task="multiclass", num_classes=hparams["model"]["num_classes"]) + model = MultiModalNetFullModalityPl.load_from_checkpoint( + args.checkpoint_path, + hparams=hparams, + map_location=device, + ) + model.eval() + + for batch in tqdm.tqdm(val_loader): + images, s2_data, country_id, label = batch + images = images.to(device) + s2_data = s2_data.to(device) + country_id = country_id.to(device) + + with torch.no_grad(): + logits = model(images, s2_data, country_id) + logits = logits.cpu() + preds = torch.argmax(logits.cpu(), dim=1) + + images[:, 3:, :, :] = 0 + with torch.no_grad(): + logits_missing_modality = model(images, s2_data, country_id) + + logits_missing_modality = logits_missing_modality.cpu() + preds_missing_modality = torch.argmax(logits_missing_modality.cpu(), dim=1) + + accuracy.update(preds, label) + accuracy_missing_modality.update(preds_missing_modality, batch[3]) + + final_accuracy = accuracy.compute() + 
final_accuracy_missing_modality = accuracy_missing_modality.compute() + print(f"final accuracy full modality: {final_accuracy}") + print(f"final accuracy missing modality: {final_accuracy_missing_modality}") + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py index 574dd09..0e75f89 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -3,6 +3,7 @@ import sys import albumentations as alb +import cv2 import numpy as np import pytorch_lightning as pl import yaml @@ -10,7 +11,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from sklearn.model_selection import train_test_split +from sklearn.model_selection import StratifiedKFold from torch.utils.data import DataLoader, Subset _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -22,8 +23,8 @@ def get_args(): - parser = argparse.ArgumentParser("train multimodal") - parser.add_argument("--config_path", type=str, default="./config/base.yaml") + parser = argparse.ArgumentParser("train for missing-modality inference") + parser.add_argument("--config_path", type=str, default="./config/base_missing_modality_1.yaml") return parser.parse_args() @@ -35,18 +36,29 @@ def get_transforms(hparams): all_transforms["street"] = { "train": alb.Compose( [ - alb.RandomCropFromBorders(crop_left=0.05, crop_right=0.05, crop_top=0.05, crop_bottom=0.05, p=0.5), alb.OneOf( [ alb.Compose( [ alb.Resize(height=image_size, width=image_size, p=1.0), - alb.Rotate(limit=(-5, 5), p=0.7), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), ] ), alb.Compose( [ - alb.Rotate(limit=(-5, 5), p=0.7), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), alb.Resize(height=image_size, width=image_size, p=1.0), ] ), @@ -80,14 +92,31 @@ def get_transforms(hparams): all_transforms["ortho"] = { "train": alb.Compose( [ - alb.RandomCropFromBorders(crop_left=0.01, crop_right=0.01, crop_top=0.01, crop_bottom=0.01, p=0.6), alb.OneOf( [ alb.Compose( - [alb.Resize(height=image_size, width=image_size), alb.Rotate(limit=(0, 360), p=0.7)] + [ + alb.Resize(height=image_size, width=image_size), + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + ] ), alb.Compose( - [alb.Rotate(limit=(0, 360), p=0.7), alb.Resize(height=image_size, width=image_size)] + [ + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + alb.Resize(height=image_size, width=image_size), + ] ), ], p=1, @@ -95,6 +124,12 @@ def get_transforms(hparams): alb.ColorJitter(p=0.5), alb.AdvancedBlur(p=0.5), alb.Flip(p=0.7), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=100, max_holes=200), + ], + p=0.6, + ), alb.ToFloat(max_value=255.0), ToTensorV2(), ] @@ -112,14 +147,14 @@ def clip_s2(image, **params): return np.clip(image, 0, 10000) def extract_rgb(image, **params): - return image[:, :, [0, 1, 2]] + return image[:, :, [3, 2, 1]] all_transforms["s2"] = { "train": alb.Compose( [ alb.Lambda(image=extract_rgb), - S2RandomRotation(limits=(0, 360), always_apply=False, p=0.7), - alb.Flip(p=0.7), + S2RandomRotation(limits=(0, 360), always_apply=False, p=0.9), + alb.Flip(p=0.9), alb.Lambda(image=clip_s2), 
alb.ToFloat(max_value=10000.0), ToTensorV2(), @@ -142,12 +177,19 @@ def setup_train_val_split( original_dataset, hparams, ): - train_indices, val_indices = train_test_split( - range(len(original_dataset)), - stratify=original_dataset.labels, - test_size=hparams["dataset"]["val_split"], + kf = StratifiedKFold( + n_splits=hparams["dataset"]["n_splits"], + shuffle=True, + random_state=hparams["seed"], ) + train_indices, val_indices = list( + kf.split( + range(len(original_dataset)), + original_dataset.labels, + ) + )[hparams["dataset"]["fold_th"]] + return train_indices, val_indices @@ -229,11 +271,21 @@ def main(): ], ) - trainer.fit( - model, - train_loader, - val_loader, - ) + if hparams["trainer"]["resume_from_checkpoint"] is not None and os.path.isfile( + hparams["trainer"]["resume_from_checkpoint"] + ): + trainer.fit( + model, + train_loader, + val_loader, + ckpt_path=hparams["trainer"]["resume_from_checkpoint"], + ) + else: + trainer.fit( + model, + train_loader, + val_loader, + ) if __name__ == "__main__": diff --git a/scripts/train_full_modality.py b/scripts/train_full_modality.py index b321e88..323b20a 100644 --- a/scripts/train_full_modality.py +++ b/scripts/train_full_modality.py @@ -35,18 +35,29 @@ def get_transforms(hparams): all_transforms["street"] = { "train": alb.Compose( [ - alb.RandomCropFromBorders(crop_left=0.05, crop_right=0.05, crop_top=0.05, crop_bottom=0.05, p=0.5), alb.OneOf( [ alb.Compose( [ alb.Resize(height=image_size, width=image_size, p=1.0), - alb.Rotate(limit=(-5, 5), p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), ] ), alb.Compose( [ - alb.Rotate(limit=(-5, 5), p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), alb.Resize(height=image_size, width=image_size, p=1.0), ] ), @@ -62,7 +73,7 @@ def get_transforms(hparams): alb.GridDropout(), alb.Spatter(), ], - p=0.5, + p=0.7, ), alb.ToFloat(max_value=255.0), ToTensorV2(), @@ -80,18 +91,29 @@ def get_transforms(hparams): all_transforms["ortho"] = { "train": alb.Compose( [ - alb.RandomCropFromBorders(crop_left=0.01, crop_right=0.01, crop_top=0.01, crop_bottom=0.01, p=0.6), alb.OneOf( [ alb.Compose( [ alb.Resize(height=image_size, width=image_size), - alb.Rotate(limit=180, p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0), + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), ] ), alb.Compose( [ - alb.Rotate(limit=180, p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0), + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), alb.Resize(height=image_size, width=image_size), ] ), @@ -101,6 +123,12 @@ def get_transforms(hparams): alb.ColorJitter(p=0.5), alb.AdvancedBlur(p=0.5), alb.Flip(p=0.7), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=100, max_holes=200), + ], + p=0.6, + ), alb.ToFloat(max_value=255.0), ToTensorV2(), ] @@ -120,8 +148,8 @@ def clip_s2(image, **params): all_transforms["s2"] = { "train": alb.Compose( [ - S2RandomRotation(limits=(0, 360), always_apply=False, p=0.7), - alb.Flip(p=0.7), + S2RandomRotation(limits=(0, 360), always_apply=False, p=0.9), + alb.Flip(p=0.9), alb.Lambda(image=clip_s2), alb.ToFloat(max_value=10000.0), ToTensorV2(), diff --git 
a/scripts/train_full_modality_input_dropout.py b/scripts/train_full_modality_input_dropout.py new file mode 100644 index 0000000..7cdbdf4 --- /dev/null +++ b/scripts/train_full_modality_input_dropout.py @@ -0,0 +1,304 @@ +import argparse +import os +import sys + +import albumentations as alb +import cv2 +import numpy as np +import pytorch_lightning as pl +import yaml +from albumentations.pytorch import ToTensorV2 +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Subset + +_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(_CURRENT_DIR, "../")) +from src.data import CountryCode, CustomSubset, MapYourCityDataset, S2RandomRotation +from src.integrated import MultiModalNetFullModalityPl, MultiModalNetPl +from src.utils import fix_seed, worker_init_fn + + +def get_args(): + parser = argparse.ArgumentParser("train multimodal") + parser.add_argument("--config_path", type=str, default="./config/base_full_modality_1.yaml") + + return parser.parse_args() + + +def get_transforms(hparams): + image_size = hparams["image_size"] + + all_transforms = {} + + def create_input_drop_out(image, **params): + image_size = hparams["image_size"] + return np.zeros((int(image_size), int(image_size), 3), dtype=np.float32) + + all_transforms["street"] = { + "train": alb.Compose( + [ + alb.OneOf( + [ + alb.Compose( + [ + alb.Resize(height=image_size, width=image_size, p=1.0), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), + ] + ), + alb.Compose( + [ + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), + alb.Resize(height=image_size, width=image_size, p=1.0), + ] + ), + ], + p=1, + ), + alb.ColorJitter(p=0.5), + alb.AdvancedBlur(p=0.5), + alb.HorizontalFlip(p=0.5), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=200, max_holes=400), + alb.GridDropout(), + alb.Spatter(), + ], + p=0.7, + ), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + all_transforms["street"]["train"] = alb.OneOf( + [ + all_transforms["street"]["train"], + alb.Compose( + [ + alb.Lambda(image=create_input_drop_out), + ToTensorV2(), + ] + ), + ], + p=1.0, + ) + + all_transforms["ortho"] = { + "train": alb.Compose( + [ + alb.OneOf( + [ + alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + ] + ), + alb.Compose( + [ + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + alb.Resize(height=image_size, width=image_size), + ] + ), + ], + p=1, + ), + alb.ColorJitter(p=0.5), + alb.AdvancedBlur(p=0.5), + alb.Flip(p=0.7), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=100, max_holes=200), + ], + p=0.6, + ), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + def clip_s2(image, **params): + return 
np.clip(image, 0, 10000) + + all_transforms["s2"] = { + "train": alb.Compose( + [ + S2RandomRotation(limits=(0, 360), always_apply=False, p=0.9), + alb.Flip(p=0.9), + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + ToTensorV2(), + ] + ), + } + + return all_transforms + + +def setup_train_val_split( + original_dataset, + hparams, +): + kf = StratifiedKFold( + n_splits=hparams["dataset"]["n_splits"], + shuffle=True, + random_state=hparams["seed"], + ) + + train_indices, val_indices = list( + kf.split( + range(len(original_dataset)), + original_dataset.labels, + ) + )[hparams["dataset"]["fold_th"]] + + return train_indices, val_indices + + +def main(): + args = get_args() + with open(args.config_path, encoding="utf-8") as f: + hparams = yaml.load(f, Loader=yaml.SafeLoader) + os.makedirs(hparams["output_root_dir"], exist_ok=True) + fix_seed(hparams["seed"]) + pl.seed_everything(hparams["seed"]) + + dataset = MapYourCityDataset( + csv_path=hparams["dataset"]["train_csv"], + data_dir=hparams["dataset"]["train_dir"], + train=True, + ) + + train_indices, val_indices = setup_train_val_split(dataset, hparams) + + transforms_dict = get_transforms(hparams) + train_dataset = CustomSubset( + Subset(dataset, train_indices), + transforms_dict={ + "street": transforms_dict["street"]["train"], + "ortho": transforms_dict["ortho"]["train"], + "s2": transforms_dict["s2"]["train"], + }, + ) + + val_dataset = CustomSubset( + Subset(dataset, val_indices), + transforms_dict={ + "street": transforms_dict["street"]["val"], + "ortho": transforms_dict["ortho"]["val"], + "s2": transforms_dict["s2"]["val"], + }, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=hparams["train_parameters"]["batch_size"], + shuffle=True, + drop_last=True, + num_workers=hparams["num_workers"], + worker_init_fn=worker_init_fn, + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=hparams["val_parameters"]["batch_size"], + num_workers=hparams["num_workers"], + ) + + model = MultiModalNetFullModalityPl(hparams) + trainer = Trainer( + default_root_dir=hparams["output_root_dir"], + max_epochs=hparams["trainer"]["max_epochs"], + devices=hparams["trainer"]["devices"], + accelerator=hparams["trainer"]["accelerator"], + gradient_clip_val=hparams["trainer"]["gradient_clip_val"], + accumulate_grad_batches=hparams["trainer"]["accumulate_grad_batches"], + deterministic=True, + logger=TensorBoardLogger( + save_dir=hparams["output_root_dir"], + version=f"{hparams['experiment_name']}_{hparams['model']['encoder_name']}_" + f"{hparams['train_parameters']['batch_size']*hparams['trainer']['accumulate_grad_batches']}_" + f"{hparams['optimizer']['lr']}", + name=f"{hparams['experiment_name']}_{hparams['model']['encoder_name']}", + ), + callbacks=[ + ModelCheckpoint( + monitor="val_acc", + mode="max", + save_top_k=1, + verbose=True, + ), + LearningRateMonitor(logging_interval="step"), + ], + ) + + if hparams["trainer"]["resume_from_checkpoint"] is not None and os.path.isfile( + hparams["trainer"]["resume_from_checkpoint"] + ): + trainer.fit( + model, + train_loader, + val_loader, + ckpt_path=hparams["trainer"]["resume_from_checkpoint"], + ) + else: + trainer.fit( + model, + train_loader, + val_loader, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_shared_street_ortho.py b/scripts/train_shared_street_ortho.py new file mode 100644 index 0000000..8b3465b --- /dev/null +++ 
b/scripts/train_shared_street_ortho.py @@ -0,0 +1,293 @@ +import argparse +import os +import sys + +import albumentations as alb +import cv2 +import numpy as np +import pytorch_lightning as pl +import yaml +from albumentations.pytorch import ToTensorV2 +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Subset + +_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(_CURRENT_DIR, "../")) +from src.data import CountryCode, CustomSubset, MapYourCityDataset, S2RandomRotation +from src.integrated import MultiModalNetSharedStreetOrthoPl +from src.models import MultiModalNet +from src.utils import fix_seed, worker_init_fn + + +def get_args(): + parser = argparse.ArgumentParser("train for missing-modality inference") + parser.add_argument("--config_path", type=str, default="./config/base_missing_modality_2.yaml") + + return parser.parse_args() + + +def get_transforms(hparams): + image_size = hparams["image_size"] + + all_transforms = {} + all_transforms["street"] = { + "train": alb.Compose( + [ + alb.OneOf( + [ + alb.Compose( + [ + alb.Resize(height=image_size, width=image_size, p=1.0), + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), + ] + ), + alb.Compose( + [ + alb.ShiftScaleRotate( + shift_limit=(-0.05, 0.05), + rotate_limit=(-5, 5), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.7, + ), + alb.Resize(height=image_size, width=image_size, p=1.0), + ] + ), + ], + p=1, + ), + alb.ColorJitter(p=0.5), + alb.AdvancedBlur(p=0.5), + alb.HorizontalFlip(p=0.5), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=200, max_holes=400), + alb.GridDropout(), + alb.Spatter(), + ], + p=0.7, + ), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + all_transforms["ortho"] = { + "train": alb.Compose( + [ + alb.OneOf( + [ + alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + ] + ), + alb.Compose( + [ + alb.ShiftScaleRotate( + shift_limit=(-0.005, 0.005), + rotate_limit=(-180, 180), + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.9, + ), + alb.Resize(height=image_size, width=image_size), + ] + ), + ], + p=1, + ), + alb.ColorJitter(p=0.5), + alb.AdvancedBlur(p=0.5), + alb.Flip(p=0.7), + alb.OneOf( + [ + alb.CoarseDropout(min_holes=100, max_holes=200), + ], + p=0.6, + ), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Resize(height=image_size, width=image_size), + alb.ToFloat(max_value=255.0), + ToTensorV2(), + ] + ), + } + + def clip_s2(image, **params): + return np.clip(image, 0, 10000) + + def extract_rgb(image, **params): + return image[:, :, [3, 2, 1]] + + all_transforms["s2"] = { + "train": alb.Compose( + [ + alb.Lambda(image=extract_rgb), + S2RandomRotation(limits=(0, 360), always_apply=False, p=0.9), + alb.Flip(p=0.9), + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + ToTensorV2(), + ] + ), + "val": alb.Compose( + [ + alb.Lambda(image=extract_rgb), + alb.Lambda(image=clip_s2), + alb.ToFloat(max_value=10000.0), + 
ToTensorV2(), + ] + ), + } + + return all_transforms + + +def setup_train_val_split( + original_dataset, + hparams, +): + kf = StratifiedKFold( + n_splits=hparams["dataset"]["n_splits"], + shuffle=True, + random_state=hparams["seed"], + ) + + train_indices, val_indices = list( + kf.split( + range(len(original_dataset)), + original_dataset.labels, + ) + )[hparams["dataset"]["fold_th"]] + + return train_indices, val_indices + + +def main(): + args = get_args() + with open(args.config_path, encoding="utf-8") as f: + hparams = yaml.load(f, Loader=yaml.SafeLoader) + os.makedirs(hparams["output_root_dir"], exist_ok=True) + fix_seed(hparams["seed"]) + pl.seed_everything(hparams["seed"]) + + dataset = MapYourCityDataset( + csv_path=hparams["dataset"]["train_csv"], + data_dir=hparams["dataset"]["train_dir"], + train=True, + ) + + train_indices, val_indices = setup_train_val_split(dataset, hparams) + + transforms_dict = get_transforms(hparams) + train_dataset = CustomSubset( + Subset(dataset, train_indices), + transforms_dict={ + "street": transforms_dict["street"]["train"], + "ortho": transforms_dict["ortho"]["train"], + "s2": transforms_dict["s2"]["train"], + }, + ) + + val_dataset = CustomSubset( + Subset(dataset, val_indices), + transforms_dict={ + "street": transforms_dict["street"]["val"], + "ortho": transforms_dict["ortho"]["val"], + "s2": transforms_dict["s2"]["val"], + }, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=hparams["train_parameters"]["batch_size"], + shuffle=True, + drop_last=True, + num_workers=hparams["num_workers"], + worker_init_fn=worker_init_fn, + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=hparams["val_parameters"]["batch_size"], + num_workers=hparams["num_workers"], + ) + + model = MultiModalNetSharedStreetOrthoPl(hparams) + + trainer = Trainer( + default_root_dir=hparams["output_root_dir"], + max_epochs=hparams["trainer"]["max_epochs"], + devices=hparams["trainer"]["devices"], + accelerator=hparams["trainer"]["accelerator"], + gradient_clip_val=hparams["trainer"]["gradient_clip_val"], + accumulate_grad_batches=hparams["trainer"]["accumulate_grad_batches"], + deterministic=True, + logger=TensorBoardLogger( + save_dir=hparams["output_root_dir"], + version=f"{hparams['experiment_name']}_{hparams['model']['encoder_name']}_" + f"{hparams['train_parameters']['batch_size']*hparams['trainer']['accumulate_grad_batches']}_" + f"{hparams['optimizer']['lr']}", + name=f"{hparams['experiment_name']}_{hparams['model']['encoder_name']}", + ), + callbacks=[ + ModelCheckpoint( + monitor="val_acc", + mode="max", + save_top_k=1, + verbose=True, + ), + LearningRateMonitor(logging_interval="step"), + ], + ) + + if hparams["trainer"]["resume_from_checkpoint"] is not None and os.path.isfile( + hparams["trainer"]["resume_from_checkpoint"] + ): + trainer.fit( + model, + train_loader, + val_loader, + ckpt_path=hparams["trainer"]["resume_from_checkpoint"], + ) + else: + trainer.fit( + model, + train_loader, + val_loader, + ) + + +if __name__ == "__main__": + main() diff --git a/src/integrated/model.py b/src/integrated/model.py index 17c51ea..e16a4f1 100644 --- a/src/integrated/model.py +++ b/src/integrated/model.py @@ -4,11 +4,18 @@ from torchmetrics import Accuracy from torchvision.utils import make_grid -from src.models import DomainClsLoss, FocalLoss, MultiModalNet +from src.models import ( + DomainClsLoss, + FocalLoss, + FocalLossLabelSmoothing, + MultiModalNet, + MultiModalNetSharedStreetOrtho, +) from src.utils import get_object_from_dict 
__all__ = ( "MultiModalNetPl", + "MultiModalNetSharedStreetOrthoPl", "MultiModalNetFullModalityPl", ) @@ -23,18 +30,25 @@ def __init__(self, hparams): self.hparams["model"]["num_classes"], ) self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams["model"]["num_classes"]) + self.losses = [ - ("focal", 1.0, FocalLoss()), + ( + "focal", + 1.0, + get_object_from_dict( + self.hparams["loss"]["classification"], + ), + ), ("domain_cls", 0.02, DomainClsLoss()), ("distribution", 0.1, nn.L1Loss()), ] - def forward(self, batch): - return self.model(batch[0], batch[1], batch[2]) + def forward(self, images, s2_data, country_id): + return self.model(images, s2_data, country_id) def common_step(self, batch, batch_idx, is_val: bool = False): - _, _, _, label = batch - logits, spec_logits, shared_feats = self.forward(batch) + images, s2_data, country_id, label = batch + logits, spec_logits, shared_feats = self.forward(images, s2_data, country_id) batch_size = logits.shape[0] num_modal = spec_logits.shape[1] @@ -66,7 +80,7 @@ def common_step(self, batch, batch_idx, is_val: bool = False): return total_loss, losses_dict, acc def training_step(self, batch, batch_idx): - if batch_idx % 100 == 0: + if batch_idx % 4000 == 0: self.logger.experiment.add_image( "train_ortho", make_grid( @@ -85,6 +99,15 @@ def training_step(self, batch, batch_idx): global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx, ) + self.logger.experiment.add_image( + "train_s2", + make_grid( + batch[1], + nrow=batch[0].shape[0], + ), + global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx, + ) + total_loss, losses_dict, _ = self.common_step(batch, batch_idx, is_val=False) self.log( @@ -144,10 +167,168 @@ def configure_optimizers(self): params=[x for x in self.parameters() if x.requires_grad], ) - scheduler = get_object_from_dict( - self.hparams["scheduler"], - optimizer=optimizer, + scheduler = { + "scheduler": get_object_from_dict( + self.hparams["scheduler"], + optimizer=optimizer, + ), + "monitor": "val_loss", + } + + return [optimizer], [scheduler] + + +class MultiModalNetSharedStreetOrthoPl(pl.LightningModule): + def __init__(self, hparams): + super().__init__() + self.hparams.update(hparams) + + self.model = MultiModalNetSharedStreetOrtho( + self.hparams["model"]["encoder_name"], + self.hparams["model"]["s2_encoder_name"], + self.hparams["model"]["num_classes"], ) + self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams["model"]["num_classes"]) + + self.losses = [ + ( + "focal", + 1.0, + get_object_from_dict( + self.hparams["loss"]["classification"], + ), + ), + ("domain_cls", 0.02, DomainClsLoss()), + ("distribution", 0.1, nn.L1Loss()), + ] + + def forward(self, images, s2_data, country_id): + return self.model(images, s2_data, country_id) + + def common_step(self, batch, batch_idx, is_val: bool = False): + images, s2_data, country_id, label = batch + logits, spec_logits, shared_feats = self.forward(images, s2_data, country_id) + batch_size = logits.shape[0] + num_modal = spec_logits.shape[1] + + total_loss = 0.0 + losses_dict = {} + for loss_name, weight, loss_class in self.losses: + if loss_name == "focal": + cur_loss = loss_class(logits, label) + elif loss_name == "domain_cls": + spec_labels = torch.arange(num_modal).repeat_interleave(batch_size).to(spec_logits.device) + cur_loss = loss_class(spec_logits, spec_labels) + else: + assert loss_name == "distribution" + cur_loss = loss_class(shared_feats[0], shared_feats[1]) + + total_loss += weight * cur_loss + 
+ losses_dict[loss_name] = cur_loss + + acc = None + if is_val: + _, pred = logits.max(1) + acc = self.accuracy(pred, label) + + return total_loss, losses_dict, acc + + def training_step(self, batch, batch_idx): + if batch_idx % 4000 == 0: + self.logger.experiment.add_image( + "train_ortho", + make_grid( + batch[0][:, :3, :, :], + nrow=batch[0].shape[0], + ), + global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx, + ) + + self.logger.experiment.add_image( + "train_street", + make_grid( + batch[0][:, 3:, :, :], + nrow=batch[0].shape[0], + ), + global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx, + ) + + self.logger.experiment.add_image( + "train_s2", + make_grid( + batch[1], + nrow=batch[0].shape[0], + ), + global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx, + ) + + total_loss, losses_dict, _ = self.common_step(batch, batch_idx, is_val=False) + + self.log( + "train_loss", + total_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + for loss_name in losses_dict: + self.log( + f"train_loss_{loss_name}", + losses_dict[loss_name], + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + return total_loss + + def validation_step(self, batch, batch_idx): + total_loss, losses_dict, acc = self.common_step(batch, batch_idx, is_val=True) + + self.log( + "val_loss", + total_loss, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + for loss_name in losses_dict: + self.log( + f"val_loss_{loss_name}", + losses_dict[loss_name], + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + self.log( + "val_acc", + acc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + return acc + + def configure_optimizers(self): + optimizer = get_object_from_dict( + self.hparams["optimizer"], + params=[x for x in self.parameters() if x.requires_grad], + ) + + scheduler = { + "scheduler": get_object_from_dict( + self.hparams["scheduler"], + optimizer=optimizer, + ), + "monitor": "val_loss", + } return [optimizer], [scheduler] @@ -161,7 +342,9 @@ def __init__(self, hparams): self.hparams["model"], ) self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams["model"]["num_classes"]) - self.loss = FocalLoss() + self.loss = get_object_from_dict( + self.hparams["loss"]["classification"], + ) def forward(self, images, s2_data, country_id): return self.model(images, s2_data, country_id) @@ -181,7 +364,7 @@ def common_step(self, batch, batch_idx, is_val: bool = False): return loss, acc def training_step(self, batch, batch_idx): - if batch_idx % 1000 == 0: + if batch_idx % 4000 == 0: self.logger.experiment.add_image( "train_ortho", make_grid( diff --git a/src/models/loss.py b/src/models/loss.py index 9e3b6ca..ff0f81d 100644 --- a/src/models/loss.py +++ b/src/models/loss.py @@ -5,6 +5,7 @@ __all__ = ( "FocalLoss", "DomainClsLoss", + "FocalLossLabelSmoothing", ) @@ -30,6 +31,33 @@ def forward(self, input, target): return loss +class FocalLossLabelSmoothing(nn.Module): + def __init__(self, smoothing=0.1, gamma=2, weight=None): + super(FocalLossLabelSmoothing, self).__init__() + self.smoothing = smoothing + self.gamma = gamma + self.weight = weight + + def forward(self, input, target): + """ + input: [N, C] + target: [N] + """ + num_classes = input.size(1) + target_one_hot = torch.zeros_like(input).scatter(1, target.unsqueeze(1), 1) + smooth_targets = (1 - self.smoothing) * target_one_hot + self.smoothing / num_classes + + logpt = F.log_softmax(input, dim=1) + pt = torch.exp(logpt) 
+ focal_weight = (1 - pt) ** self.gamma + + loss = -focal_weight * logpt * smooth_targets + if self.weight is not None: + loss = loss * self.weight.unsqueeze(0) + + return loss.sum(dim=1).mean() + + class DomainClsLoss(nn.Module): def __init__(self): super(DomainClsLoss, self).__init__() diff --git a/src/models/model.py b/src/models/model.py index 64d0f02..5413aef 100644 --- a/src/models/model.py +++ b/src/models/model.py @@ -5,6 +5,7 @@ __all__ = ( "MultiModalNet", + "MultiModalNetSharedStreetOrtho", "MultiModalNetFullModalityFeatureFusion", "MultiModalNetFullModalityGeometricFusion", ) @@ -124,6 +125,132 @@ def forward( return logits, spec_logits, shared_feats +class CompositionalLayerStreetOrtho(nn.Module): + def __init__( + self, + in_features, + ): + super().__init__() + self.conv = nn.Conv2d(in_features * 2, in_features, kernel_size=1) + + def forward(self, f1, f2): + """ + :param f1: shared-modality fts + :param f2: specific-modality fts + :return: + """ + residual = torch.cat((f1, f2), 1) + residual = self.conv(residual) + features = f1 + residual + + return features + + +class MultiModalNetSharedStreetOrtho(nn.Module): + def __init__( + self, + encoder_name, + s2_encoder_name, + num_classes, + ): + super().__init__() + self.num_classes = num_classes + self.in_features = _get_extractor_in_features(encoder_name) + self.s2_in_features = _get_extractor_in_features(s2_encoder_name) + + self.shared_enc = timm.create_model( + encoder_name, + pretrained=True, + global_pool="", + num_classes=0, + ) + + self.s2_enc = timm.create_model( + s2_encoder_name, + pretrained=True, + num_classes=0, + ) + + self.ortho_enc = timm.create_model( + encoder_name, + pretrained=True, + num_classes=0, + global_pool="", + ) + + self.street_enc = timm.create_model( + encoder_name, + pretrained=True, + num_classes=0, + global_pool="", + ) + + self.compos_layer = CompositionalLayerStreetOrtho(self.in_features) + self.domain_classfier = nn.Sequential( + SelectAdaptivePool2d(pool_type="avg", flatten=nn.Flatten(start_dim=1, end_dim=-1)), + nn.Linear(in_features=self.in_features, out_features=2, bias=True), + ) + + self.fused_layer = nn.Sequential( + nn.Conv2d(self.in_features * 2, self.in_features, kernel_size=1), + nn.BatchNorm2d(self.in_features), + nn.ReLU(inplace=True), + SelectAdaptivePool2d(pool_type="avg", flatten=nn.Flatten(start_dim=1, end_dim=-1)), + ) + + self.fc = nn.Linear(self.in_features + self.s2_in_features + self.num_classes, self.num_classes) + + def forward( + self, + images, + s2_data, + country_id, + ): + num_channel = images.shape[1] + assert num_channel in [3, 6] + num_modal = num_channel // 3 + + spec_feats = self.ortho_enc(images[:, :3, ...])[None, ...] + if num_channel == 6: + spec_feats = torch.cat((spec_feats, self.street_enc(images[:, 3:, ...])[None, ...]), axis=0) + + shared_feats = self.shared_enc(images[:, :3, ...])[None, ...] + if num_channel == 6: + shared_feats = torch.cat((shared_feats, self.shared_enc(images[:, 3:, ...])[None, ...]), axis=0) + + fused_feats = self.compos_layer(shared_feats[0], spec_feats[0])[None, ...] 
+ for i in range(1, num_modal): + fused_feats = torch.cat((fused_feats, self.compos_layer(shared_feats[i], spec_feats[i])[None, ...]), axis=0) + + if num_modal == 1: + fused_feats = torch.cat((fused_feats, shared_feats), axis=0) + + fused_feats = fused_feats.transpose(0, 1).reshape( + fused_feats.shape[1], + fused_feats.shape[0] * fused_feats.shape[2], + fused_feats.shape[3], + fused_feats.shape[4], + ) + fused_feats = self.fused_layer(fused_feats) + + fused_feats = torch.cat( + ( + fused_feats, + self.s2_enc(s2_data), + nn.functional.one_hot(country_id, num_classes=self.num_classes).float(), + ), + axis=1, + ) + + logits = self.fc(fused_feats) + + spec_logits = self.domain_classfier(spec_feats[0]) + for i in range(1, num_modal): + spec_logits = torch.cat((spec_logits, self.domain_classfier(spec_feats[i])), axis=0) + + return logits, spec_logits, shared_feats + + class MultiModalNetFullModalityFeatureFusion(nn.Module): def __init__( self, diff --git a/src/utils/utils.py b/src/utils/utils.py index a096782..fbe0dcd 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -33,11 +33,4 @@ def fix_seed(seed): def worker_init_fn(worker_id): - seed = 0 - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) + np.random.seed(np.random.get_state()[1][0] + worker_id)
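
---
Editorial note (not part of the patch): the evaluation scripts added above
(scripts/submit_result.py, scripts/test_inference_val.py) simulate the missing
street-view modality at inference time by zeroing channels 3..5 of the stacked
ortho+street tensor and running a second forward pass. The sketch below is a
minimal, self-contained illustration of that pattern combined with the
FocalLossLabelSmoothing loss added in src/models/loss.py. TinyNet is a
hypothetical stand-in for the repository's MultiModalNet classes; only the
6-channel input layout and the 7-class output follow the configs above.

# Editorial sketch: missing-modality evaluation + focal loss with label smoothing.
# FocalLossLabelSmoothing mirrors the class added in src/models/loss.py;
# TinyNet is a toy stand-in model, not part of the repository.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossLabelSmoothing(nn.Module):
    def __init__(self, smoothing=0.1, gamma=2, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        # input: [N, C] logits, target: [N] class indices
        num_classes = input.size(1)
        target_one_hot = torch.zeros_like(input).scatter(1, target.unsqueeze(1), 1)
        smooth_targets = (1 - self.smoothing) * target_one_hot + self.smoothing / num_classes
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        focal_weight = (1 - pt) ** self.gamma
        loss = -focal_weight * logpt * smooth_targets
        if self.weight is not None:
            loss = loss * self.weight.unsqueeze(0)
        return loss.sum(dim=1).mean()


class TinyNet(nn.Module):
    # Hypothetical 6-channel (ortho RGB + street RGB) classifier with 7 classes.
    def __init__(self, num_classes=7):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(6, 8, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.backbone(x).flatten(1))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyNet().eval()
    criterion = FocalLossLabelSmoothing(smoothing=0.1, gamma=2)

    images = torch.rand(4, 6, 64, 64)   # channels 0..2 = ortho, 3..5 = street view
    labels = torch.randint(0, 7, (4,))

    with torch.no_grad():
        logits_full = model(images)              # full-modality pass

        images_missing = images.clone()
        images_missing[:, 3:, :, :] = 0          # simulate a missing street-view image
        logits_missing = model(images_missing)   # missing-modality pass

    print("loss (full modality):   ", criterion(logits_full, labels).item())
    print("loss (missing modality):", criterion(logits_missing, labels).item())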
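
A second sketch, on the cross-validation change: setup_train_val_split replaces
the old train_test_split with a seeded StratifiedKFold, and dataset.fold_th picks
one of dataset.n_splits folds, so the different configs above train on different
folds of the same fixed split. Minimal, self-contained version under the
assumption of synthetic labels standing in for MapYourCityDataset.labels:

# Editorial sketch of the fold-selection pattern used in the training/inference
# scripts; the labels array is synthetic, the real one comes from the dataset.
import numpy as np
from sklearn.model_selection import StratifiedKFold


def setup_train_val_split(labels, n_splits=10, fold_th=3, seed=2024):
    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # split() only uses the first argument for its length; labels drive stratification
    train_indices, val_indices = list(kf.split(range(len(labels)), labels))[fold_th]
    return train_indices, val_indices


if __name__ == "__main__":
    labels = np.random.default_rng(0).integers(0, 7, size=100)  # 7 classes, as in the configs
    train_idx, val_idx = setup_train_val_split(labels)
    print(len(train_idx), len(val_idx))  # roughly 90 / 10 with n_splits=10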