Skip to content

Commit

Permalink
fix missing modality
Browse files Browse the repository at this point in the history
  • Loading branch information
xmba15 committed Jul 9, 2024
1 parent 1df33f0 commit b76d17d
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ mamba activate ai4eo
---

- [Model Fusion for Building Type Classification from Aerial and Street View Images](https://www.mdpi.com/2072-4292/11/11/1259#)
- [Multi-modal fusion of satellite and street-view images for urban village classification based on a dual-branch deep neural network](https://www.sciencedirect.com/science/article/pii/S0303243422001209)
25 changes: 15 additions & 10 deletions config/base.yaml
Original file line number Diff line number Diff line change
@@ -1,42 +1,47 @@
---
seed: 1984
seed: 2024

num_workers: 4
experiment_name: "2024-04-07"
experiment_name: "2024-07-09-missing-modality"

dataset:
val_split: 0.1
n_splits: 10
fold_th: 3
train_dir: ~/publicWorkspace/data/building-age-dataset/train/data
test_dir: ~/publicWorkspace/data/building-age-dataset/test/data
train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv
test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv

model:
encoder_name: efficientnet_b2
encoder_name: tf_efficientnetv2_b3
num_classes: 7

optimizer:
type: timm.optim.AdamP
lr: 0.0005
weight_decay: 0.00001

scheduler:
type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
T_0: 10
T_mult: 2
type: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: min
factor: 0.5
patience: 10
threshold: 0.00005
verbose: True

trainer:
devices: 1
devices: [0]
accelerator: "cuda"
max_epochs: 50
gradient_clip_val: 5.0
accumulate_grad_batches: 16
resume_from_checkpoint:

train_parameters:
batch_size: 3
batch_size: 4

val_parameters:
batch_size: 3
batch_size: 4

output_root_dir: experiments
image_size: 512
56 changes: 40 additions & 16 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import sys

import albumentations as alb
import cv2
import numpy as np
import pytorch_lightning as pl
import yaml
from albumentations.pytorch import ToTensorV2
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset

_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -22,7 +23,7 @@


def get_args():
    """Parse command-line arguments for the training script.

    Returns:
        argparse.Namespace with:
            config_path (str): path to the YAML training config,
                defaults to ``./config/base.yaml``.
    """
    parser = argparse.ArgumentParser("train for missing-modality inference")
    parser.add_argument("--config_path", type=str, default="./config/base.yaml")

    return parser.parse_args()
Expand All @@ -41,12 +42,12 @@ def get_transforms(hparams):
alb.Compose(
[
alb.Resize(height=image_size, width=image_size, p=1.0),
alb.Rotate(limit=(-5, 5), p=0.7),
alb.Rotate(limit=(-5, 5), p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0),
]
),
alb.Compose(
[
alb.Rotate(limit=(-5, 5), p=0.7),
alb.Rotate(limit=(-5, 5), p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0),
alb.Resize(height=image_size, width=image_size, p=1.0),
]
),
Expand Down Expand Up @@ -84,10 +85,16 @@ def get_transforms(hparams):
alb.OneOf(
[
alb.Compose(
[alb.Resize(height=image_size, width=image_size), alb.Rotate(limit=(0, 360), p=0.7)]
[
alb.Resize(height=image_size, width=image_size),
alb.Rotate(limit=180, p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0),
]
),
alb.Compose(
[alb.Rotate(limit=(0, 360), p=0.7), alb.Resize(height=image_size, width=image_size)]
[
alb.Rotate(limit=180, p=0.7, border_mode=cv2.BORDER_CONSTANT, value=0),
alb.Resize(height=image_size, width=image_size),
]
),
],
p=1,
Expand All @@ -112,7 +119,7 @@ def clip_s2(image, **params):
return np.clip(image, 0, 10000)

def extract_rgb(image, **params):
    """Select the RGB channels from a multi-band Sentinel-2 image.

    Channel indices 3, 2, 1 are picked in that order so the output is
    (red, green, blue).  # assumes bands are stacked B1, B2(blue),
    # B3(green), B4(red), ... — TODO confirm against the dataset loader.

    Args:
        image: array of shape [H, W, C] with C >= 4.
        **params: ignored; present to satisfy the albumentations
            ``Lambda`` callback signature.

    Returns:
        [H, W, 3] array ordered as (R, G, B).
    """
    return image[:, :, [3, 2, 1]]

all_transforms["s2"] = {
"train": alb.Compose(
Expand Down Expand Up @@ -142,12 +149,19 @@ def setup_train_val_split(
original_dataset,
hparams,
):
train_indices, val_indices = train_test_split(
range(len(original_dataset)),
stratify=original_dataset.labels,
test_size=hparams["dataset"]["val_split"],
kf = StratifiedKFold(
n_splits=hparams["dataset"]["n_splits"],
shuffle=True,
random_state=hparams["seed"],
)

train_indices, val_indices = list(
kf.split(
range(len(original_dataset)),
original_dataset.labels,
)
)[hparams["dataset"]["fold_th"]]

return train_indices, val_indices


Expand Down Expand Up @@ -229,11 +243,21 @@ def main():
],
)

trainer.fit(
model,
train_loader,
val_loader,
)
if hparams["trainer"]["resume_from_checkpoint"] is not None and os.path.isfile(
hparams["trainer"]["resume_from_checkpoint"]
):
trainer.fit(
model,
train_loader,
val_loader,
ckpt_path=hparams["trainer"]["resume_from_checkpoint"],
)
else:
trainer.fit(
model,
train_loader,
val_loader,
)


if __name__ == "__main__":
Expand Down
34 changes: 23 additions & 11 deletions src/integrated/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torchmetrics import Accuracy
from torchvision.utils import make_grid

from src.models import DomainClsLoss, FocalLoss, MultiModalNet
from src.models import DomainClsLoss, FocalLoss, FocalLossLabelSmoothing, MultiModalNet
from src.utils import get_object_from_dict

__all__ = (
Expand All @@ -24,17 +24,17 @@ def __init__(self, hparams):
)
self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams["model"]["num_classes"])
self.losses = [
("focal", 1.0, FocalLoss()),
("focal", 1.0, FocalLossLabelSmoothing()),
("domain_cls", 0.02, DomainClsLoss()),
("distribution", 0.1, nn.L1Loss()),
]

def forward(self, images, s2_data, country_id):
    """Forward the three modalities through the wrapped multimodal model.

    Args:
        images: street-view/aerial image batch.
        s2_data: Sentinel-2 input batch.
        country_id: per-sample country identifier batch.

    Returns:
        Whatever ``self.model`` returns — unpacked elsewhere as
        ``(logits, spec_logits, shared_feats)``.
    """
    return self.model(images, s2_data, country_id)

def common_step(self, batch, batch_idx, is_val: bool = False):
_, _, _, label = batch
logits, spec_logits, shared_feats = self.forward(batch)
images, s2_data, country_id, label = batch
logits, spec_logits, shared_feats = self.forward(images, s2_data, country_id)
batch_size = logits.shape[0]
num_modal = spec_logits.shape[1]

Expand Down Expand Up @@ -66,7 +66,7 @@ def common_step(self, batch, batch_idx, is_val: bool = False):
return total_loss, losses_dict, acc

def training_step(self, batch, batch_idx):
if batch_idx % 100 == 0:
if batch_idx % 1000 == 0:
self.logger.experiment.add_image(
"train_ortho",
make_grid(
Expand All @@ -85,6 +85,15 @@ def training_step(self, batch, batch_idx):
global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx,
)

self.logger.experiment.add_image(
"train_s2",
make_grid(
batch[1],
nrow=batch[0].shape[0],
),
global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx,
)

total_loss, losses_dict, _ = self.common_step(batch, batch_idx, is_val=False)

self.log(
Expand Down Expand Up @@ -144,10 +153,13 @@ def configure_optimizers(self):
params=[x for x in self.parameters() if x.requires_grad],
)

scheduler = get_object_from_dict(
self.hparams["scheduler"],
optimizer=optimizer,
)
scheduler = {
"scheduler": get_object_from_dict(
self.hparams["scheduler"],
optimizer=optimizer,
),
"monitor": "val_loss",
}

return [optimizer], [scheduler]

Expand Down
28 changes: 28 additions & 0 deletions src/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
__all__ = (
"FocalLoss",
"DomainClsLoss",
"FocalLossLabelSmoothing",
)


Expand All @@ -30,6 +31,33 @@ def forward(self, input, target):
return loss


class FocalLossLabelSmoothing(nn.Module):
    """Focal loss combined with label smoothing for multi-class classification.

    The hard one-hot target is mixed with a uniform distribution
    (label smoothing), then each class term is down-weighted by
    ``(1 - p)**gamma`` so that well-classified examples contribute less
    (the focal term).

    Args:
        smoothing: fraction of probability mass spread uniformly over
            all classes (0 disables smoothing).
        gamma: focal exponent; 0 reduces the focal term to 1.
        weight: optional per-class rescaling tensor of shape [C].
    """

    def __init__(self, smoothing=0.1, gamma=2, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """Compute the smoothed focal loss.

        Args:
            input: raw (unnormalized) logits of shape [N, C].
            target: integer class indices of shape [N].

        Returns:
            Scalar tensor: per-sample losses summed over classes,
            averaged over the batch.
        """
        num_classes = input.size(1)
        # One-hot encode the targets, then blend with the uniform distribution.
        target_one_hot = torch.zeros_like(input).scatter(1, target.unsqueeze(1), 1)
        smooth_targets = (1 - self.smoothing) * target_one_hot + self.smoothing / num_classes

        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Focal modulation: down-weight classes the model already predicts well.
        focal_weight = (1 - pt) ** self.gamma

        loss = -focal_weight * logpt * smooth_targets
        if self.weight is not None:
            # Move the class weights to the loss's device/dtype so a CPU-built
            # weight tensor does not break CUDA training.
            loss = loss * self.weight.to(loss).unsqueeze(0)

        return loss.sum(dim=1).mean()


class DomainClsLoss(nn.Module):
def __init__(self):
super(DomainClsLoss, self).__init__()
Expand Down

0 comments on commit b76d17d

Please sign in to comment.