Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CROMA pretrained model #2370

Merged
merged 9 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ Change Star
.. autoclass:: ChangeStarFarSeg
.. autoclass:: ChangeMixin

CROMA
^^^^^

.. autoclass:: CROMA
.. autofunction:: croma_base
.. autofunction:: croma_large
.. autoclass:: CROMABase_Weights
.. autoclass:: CROMALarge_Weights

DOFA
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/weights/agnostic.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Weight,Source,Citation,License,Spatial,Temporal,Spectral,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake
CROMA,`link <https://github.com/antofuller/CROMA>`__,`link <https://arxiv.org/abs/2311.00566>`__,CC-BY-4.0,implicit,-,implicit,,,,,,,,,,,,
DOFABase16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,implicit,-,explicit,65.7,50.9,95.8,96.9,55.1,93.9,94.5,81.4,58.8,51.5,33.0,65.3
DOFALarge16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,implicit,-,explicit,67.5,54.6,96.9,97.3,60.1,97.1,95.0,81.8,59.4,56.9,32.1,66.3
ResNet50_Weights.FMOW_RGB_GASSL,`link <https://github.com/sustainlab-group/geography-aware-ssl>`__,`link <https://arxiv.org/abs/2011.09980>`__,-,implicit,-,-,,,,,,,,,,,,
Expand Down
119 changes: 119 additions & 0 deletions tests/models/test_croma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from pathlib import Path

import pytest
import torch
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import (
CROMA,
CROMABase_Weights,
CROMALarge_Weights,
croma_base,
croma_large,
)


def save_model(model: torch.nn.Module, path: Path) -> None:
state_dict = {
's1_encoder': model.s1_encoder.state_dict(),
's1_GAP_FFN': model.s1_GAP_FFN.state_dict(),
's2_encoder': model.s2_encoder.state_dict(),
's2_GAP_FFN': model.s2_GAP_FFN.state_dict(),
'joint_encoder': model.joint_encoder.state_dict(),
}
torch.save(state_dict, path)


class TestCROMA:
@pytest.mark.parametrize('modalities', [['sar'], ['optical'], ['sar', 'optical']])
def test_croma(self, modalities: list[str]) -> None:
batch_size = 2
model = CROMA(modalities=modalities)
if 'sar' in modalities:
sar_images = torch.randn(
[batch_size, 2, model.image_size, model.image_size]
)
else:
sar_images = None
if 'optical' in modalities:
optical_images = torch.randn(
[batch_size, 12, model.image_size, model.image_size]
)
else:
optical_images = None
out = model(sar_images, optical_images)
for modality in modalities:
assert f'{modality}_encodings' in out
if set(modalities) == {'sar', 'optical'}:
assert 'joint_encodings' in out


class TestCROMABase:
@pytest.fixture(params=[*CROMABase_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = croma_base()
save_model(model, path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_croma(self) -> None:
croma_base()

def test_croma_weights(self, mocked_weights: WeightsEnum) -> None:
croma_base(weights=mocked_weights)

@pytest.mark.slow
def test_croma_download(self, weights: WeightsEnum) -> None:
croma_base(weights=weights)


class TestCROMALarge:
@pytest.fixture(params=[*CROMALarge_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = croma_large()
save_model(model, path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_croma(self) -> None:
croma_large()

def test_croma_weights(self, mocked_weights: WeightsEnum) -> None:
croma_large(weights=mocked_weights)

@pytest.mark.slow
def test_croma_download(self, weights: WeightsEnum) -> None:
croma_large(weights=weights)
6 changes: 6 additions & 0 deletions torchgeo/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .api import get_model, get_model_weights, get_weight, list_models
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
from .croma import CROMA, CROMABase_Weights, CROMALarge_Weights, croma_base, croma_large
from .dofa import (
DOFA,
DOFABase16_Weights,
Expand Down Expand Up @@ -35,6 +36,11 @@
'ChangeMixin',
'ChangeStar',
'ChangeStarFarSeg',
'CROMA',
'CROMABase_Weights',
'CROMALarge_Weights',
'croma_base',
'croma_large',
'DOFA',
'dofa_small_patch16_224',
'dofa_base_patch16_224',
Expand Down
Loading