Commit 913f0f8

full modality

xmba15 committed Jul 8, 2024
1 parent 07d8ec7
Showing 5 changed files with 493 additions and 18 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -14,3 +14,5 @@ mamba activate ai4eo
## :gem: References

---

- [Model Fusion for Building Type Classification from Aerial and Street View Images](https://www.mdpi.com/2072-4292/11/11/1259#)
48 changes: 48 additions & 0 deletions config/base_full_modality.yaml
@@ -0,0 +1,48 @@
---
seed: 2024

num_workers: 4
experiment_name: "2024-07-08"

dataset:
n_splits: 10
fold_th: 0
train_dir: ~/publicWorkspace/data/building-age-dataset/train/data
test_dir: ~/publicWorkspace/data/building-age-dataset/test/data
train_csv: ~/publicWorkspace/data/building-age-dataset/train/train-set.csv
test_csv: ~/publicWorkspace/data/building-age-dataset/test/test-set.csv

model:
type: src.models.MultiModalNetFullModalityFeatureFusion
encoder_name: efficientnet_b3
num_classes: 7

optimizer:
type: timm.optim.AdamW
lr: 0.0005
weight_decay: 0.01

scheduler:
type: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: min
factor: 0.5
patience: 10
threshold: 0.00005
verbose: True

trainer:
devices: 1
accelerator: "cuda"
max_epochs: 50
gradient_clip_val: 5.0
accumulate_grad_batches: 16
resume_from_checkpoint:

train_parameters:
batch_size: 3

val_parameters:
batch_size: 3

output_root_dir: experiments
image_size: 512
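
For reference, configs in this shape are typically consumed by a dotted-path factory: each "type:" value is imported and the sibling keys are passed as keyword arguments. get_object_from_dict lives in src.utils and is not shown in this diff, so the sketch below is an assumption (a common pydoc.locate-based pattern), not the repository's actual implementation:

import pydoc

def get_object_from_dict_sketch(d, **overrides):
    d = dict(d)                         # avoid mutating the hparams dict
    cls = pydoc.locate(d.pop("type"))   # e.g. "timm.optim.AdamW" -> class
    d.update(overrides)                 # runtime args, e.g. params=..., optimizer=...
    return cls(**d)                     # remaining config keys become kwargs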
221 changes: 221 additions & 0 deletions scripts/train_full_modality.py
@@ -0,0 +1,221 @@
import argparse
import os
import sys

import albumentations as alb
import numpy as np
import pytorch_lightning as pl
import yaml
from albumentations.pytorch import ToTensorV2
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset

_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(_CURRENT_DIR, "../"))
from src.data import CountryCode, CustomSubset, MapYourCityDataset, S2RandomRotation
from src.integrated import MultiModalNetFullModalityPl, MultiModalNetPl
from src.utils import fix_seed, worker_init_fn


def get_args():
parser = argparse.ArgumentParser("train multimodal")
parser.add_argument("--config_path", type=str, default="./config/base_full_modality.yaml")

return parser.parse_args()


def get_transforms(hparams):
image_size = hparams["image_size"]

all_transforms = {}
all_transforms["street"] = {
"train": alb.Compose(
[
alb.RandomCropFromBorders(crop_left=0.05, crop_right=0.05, crop_top=0.05, crop_bottom=0.05, p=0.5),
alb.OneOf(
[
alb.Compose(
[
alb.Resize(height=image_size, width=image_size, p=1.0),
alb.Rotate(limit=(-5, 5), p=0.7),
]
),
alb.Compose(
[
alb.Rotate(limit=(-5, 5), p=0.7),
alb.Resize(height=image_size, width=image_size, p=1.0),
]
),
],
p=1,
),
alb.ColorJitter(p=0.5),
alb.AdvancedBlur(p=0.5),
alb.HorizontalFlip(p=0.5),
alb.OneOf(
[
alb.CoarseDropout(min_holes=200, max_holes=400),
alb.GridDropout(),
alb.Spatter(),
],
p=0.5,
),
alb.ToFloat(max_value=255.0),
ToTensorV2(),
]
),
"val": alb.Compose(
[
alb.Resize(height=image_size, width=image_size),
alb.ToFloat(max_value=255.0),
ToTensorV2(),
]
),
}

all_transforms["ortho"] = {
"train": alb.Compose(
[
alb.RandomCropFromBorders(crop_left=0.01, crop_right=0.01, crop_top=0.01, crop_bottom=0.01, p=0.6),
alb.OneOf(
[
alb.Compose(
[alb.Resize(height=image_size, width=image_size), alb.Rotate(limit=(0, 360), p=0.7)]
),
alb.Compose(
[alb.Rotate(limit=(0, 360), p=0.7), alb.Resize(height=image_size, width=image_size)]
),
],
p=1,
),
alb.ColorJitter(p=0.5),
alb.AdvancedBlur(p=0.5),
alb.Flip(p=0.7),
alb.ToFloat(max_value=255.0),
ToTensorV2(),
]
),
"val": alb.Compose(
[
alb.Resize(height=image_size, width=image_size),
alb.ToFloat(max_value=255.0),
ToTensorV2(),
]
),
}

def clip_s2(image, **params):
return np.clip(image, 0, 10000)

all_transforms["s2"] = {
"train": alb.Compose(
[
S2RandomRotation(limits=(0, 360), always_apply=False, p=0.7),
alb.Flip(p=0.7),
alb.Lambda(image=clip_s2),
alb.ToFloat(max_value=10000.0),
ToTensorV2(),
]
),
"val": alb.Compose(
[
alb.Lambda(image=clip_s2),
alb.ToFloat(max_value=10000.0),
ToTensorV2(),
]
),
}

return all_transforms
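

# Each pipeline above ends with ToFloat(max_value=...) followed by ToTensorV2():
# pixels are rescaled to [0, 1] (RGB divided by 255; Sentinel-2 clipped to
# [0, 10000], then divided by 10000) and HWC numpy arrays become CHW tensors.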


def setup_train_val_split(
original_dataset,
hparams,
):
kf = StratifiedKFold(
n_splits=hparams["dataset"]["n_splits"],
shuffle=True,
random_state=hparams["seed"],
)

train_indices, val_indices = list(
kf.split(
range(len(original_dataset)),
original_dataset.labels,
)
)[hparams["dataset"]["fold_th"]]

return train_indices, val_indices
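

# StratifiedKFold.split() yields one (train_indices, val_indices) pair per fold
# with class proportions preserved; materialising the generator with list(...)
# and indexing by fold_th picks a single reproducible fold (random_state=seed).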


def main():
args = get_args()
with open(args.config_path, encoding="utf-8") as f:
hparams = yaml.load(f, Loader=yaml.SafeLoader)
os.makedirs(hparams["output_root_dir"], exist_ok=True)
fix_seed(hparams["seed"])
pl.seed_everything(hparams["seed"])

dataset = MapYourCityDataset(
csv_path=hparams["dataset"]["train_csv"],
data_dir=hparams["dataset"]["train_dir"],
train=True,
)

train_indices, val_indices = setup_train_val_split(dataset, hparams)

transforms_dict = get_transforms(hparams)
train_dataset = CustomSubset(
Subset(dataset, train_indices),
transforms_dict={
"street": transforms_dict["street"]["train"],
"ortho": transforms_dict["ortho"]["train"],
"s2": transforms_dict["s2"]["train"],
},
)

val_dataset = CustomSubset(
Subset(dataset, val_indices),
transforms_dict={
"street": transforms_dict["street"]["val"],
"ortho": transforms_dict["ortho"]["val"],
"s2": transforms_dict["s2"]["val"],
},
)
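
    # CustomSubset (src.data, not shown in this diff) presumably applies the
    # per-modality transforms on top of the raw Subset samples, so the train
    # and val folds of the same base dataset get different pipelines.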

train_loader = DataLoader(
train_dataset,
batch_size=hparams["train_parameters"]["batch_size"],
shuffle=True,
drop_last=True,
num_workers=hparams["num_workers"],
worker_init_fn=worker_init_fn,
pin_memory=True,
)

val_loader = DataLoader(
val_dataset,
batch_size=hparams["val_parameters"]["batch_size"],
num_workers=hparams["num_workers"],
)

model = MultiModalNetFullModalityPl(hparams)

# model = MultiModalNetFullModalityFeatureFusion(
# hparams["model"]["encoder_name"],
# hparams["model"]["num_classes"],
# )

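    # Smoke test: run a single forward pass to verify the tensor plumbing;
    # the pl.Trainer fit loop is not wired up in this script yet.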
for batch in train_loader:
images, s2_data, country_id, label = batch
output = model(images, s2_data, country_id)
print(output.shape)
break


if __name__ == "__main__":
main()
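
The Lightning imports at the top of the script (Trainer, ModelCheckpoint, LearningRateMonitor, TensorBoardLogger) are unused so far. A minimal sketch of how the smoke-test loop could later be replaced, assuming the standard PyTorch Lightning API and the trainer keys from the config above (the monitored metric and callback settings are illustrative choices, not taken from the repository):

logger = TensorBoardLogger(hparams["output_root_dir"], name=hparams["experiment_name"])
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(
    devices=hparams["trainer"]["devices"],
    accelerator=hparams["trainer"]["accelerator"],
    max_epochs=hparams["trainer"]["max_epochs"],
    gradient_clip_val=hparams["trainer"]["gradient_clip_val"],
    accumulate_grad_batches=hparams["trainer"]["accumulate_grad_batches"],
    logger=logger,
    callbacks=[checkpoint_cb, LearningRateMonitor(logging_interval="epoch")],
)
trainer.fit(model, train_loader, val_loader)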
105 changes: 104 additions & 1 deletion src/integrated/model.py
@@ -7,7 +7,10 @@
from src.models import DomainClsLoss, FocalLoss, MultiModalNet
from src.utils import get_object_from_dict

__all__ = ("MultiModalNetPl", "DomainClsLoss")
__all__ = (
"MultiModalNetPl",
"MultiModalNetFullModalityPl",
)


class MultiModalNetPl(pl.LightningModule):
@@ -147,3 +150,103 @@ def configure_optimizers(self):
)

return [optimizer], [scheduler]


class MultiModalNetFullModalityPl(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams.update(hparams)

self.model = get_object_from_dict(
self.hparams["model"],
)
self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams["model"]["num_classes"])
self.loss = FocalLoss()

def forward(self, images, s2_data, country_id):
return self.model(images, s2_data, country_id)

def common_step(self, batch, batch_idx, is_val: bool = False):
images, s2_data, country_id, label = batch
logits = self.forward(images, s2_data, country_id)

loss = self.loss(logits, label)

acc = None
if is_val:
_, pred = logits.max(1)
acc = self.accuracy(pred, label)

return loss, acc

def training_step(self, batch, batch_idx):
if batch_idx % 1000 == 0:
self.logger.experiment.add_image(
"train_ortho",
make_grid(
batch[0][:, :3, :, :],
nrow=batch[0].shape[0],
),
global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx,
)

self.logger.experiment.add_image(
"train_street",
make_grid(
batch[0][:, 3:, :, :],
nrow=batch[0].shape[0],
),
global_step=self.current_epoch * self.trainer.num_training_batches + batch_idx,
)
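
            # The [:, :3] / [:, 3:] channel split implies that `images` stacks the
            # ortho and street RGB crops along the channel axis (6 channels total).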

loss, _ = self.common_step(batch, batch_idx, is_val=False)

self.log(
"train_loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)

return loss

def validation_step(self, batch, batch_idx):
loss, acc = self.common_step(batch, batch_idx, is_val=True)

self.log(
"val_loss",
loss,
on_step=False,
on_epoch=True,
sync_dist=True,
)

self.log(
"val_acc",
acc,
on_step=False,
on_epoch=True,
sync_dist=True,
)

return acc

def configure_optimizers(self):
optimizer = get_object_from_dict(
self.hparams["optimizer"],
params=[x for x in self.parameters() if x.requires_grad],
)

scheduler = {
"scheduler": get_object_from_dict(
self.hparams["scheduler"],
optimizer=optimizer,
),
"monitor": "val_loss",
}

return [optimizer], [scheduler]
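
FocalLoss is imported from src.models and its definition is outside this diff; for orientation, a standard multi-class focal loss (Lin et al., 2017) looks like the sketch below, where the class name and the default gamma of 2.0 are assumptions rather than the repository's code:

import torch
import torch.nn.functional as F

class FocalLossSketch(torch.nn.Module):
    """Multi-class focal loss: mean over (1 - p_t)^gamma * CE."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, reduction="none")  # per-sample -log p_t
        p_t = torch.exp(-ce)                                    # probability of the true class
        return ((1.0 - p_t) ** self.gamma * ce).mean()

Returning the scheduler as a dict with a "monitor" key in configure_optimizers is the Lightning convention that lets ReduceLROnPlateau read val_loss at the end of each validation epoch.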