Skip to content

Commit

Permalink
Add CROMA pretrained model (#2370)
Browse files Browse the repository at this point in the history
* add croma

* coverage

* single line test

* Fix type hints

* review

* Update croma.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* review

* typo

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
nilsleh and adamjstewart authored Oct 29, 2024
1 parent 0e82cc7 commit b2b6516
Show file tree
Hide file tree
Showing 5 changed files with 779 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ Change Star
.. autoclass:: ChangeStarFarSeg
.. autoclass:: ChangeMixin

CROMA
^^^^^

.. autoclass:: CROMA
.. autofunction:: croma_base
.. autofunction:: croma_large
.. autoclass:: CROMABase_Weights
.. autoclass:: CROMALarge_Weights

DOFA
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/weights/agnostic.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Weight,Source,Citation,License,Spatial,Temporal,Spectral,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake
CROMA,`link <https://github.com/antofuller/CROMA>`__,`link <https://arxiv.org/abs/2311.00566>`__,CC-BY-4.0,implicit,-,implicit,,,,,,,,,,,,
DOFABase16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,implicit,-,explicit,65.7,50.9,95.8,96.9,55.1,93.9,94.5,81.4,58.8,51.5,33.0,65.3
DOFALarge16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,implicit,-,explicit,67.5,54.6,96.9,97.3,60.1,97.1,95.0,81.8,59.4,56.9,32.1,66.3
ResNet50_Weights.FMOW_RGB_GASSL,`link <https://github.com/sustainlab-group/geography-aware-ssl>`__,`link <https://arxiv.org/abs/2011.09980>`__,-,implicit,-,-,,,,,,,,,,,,
Expand Down
119 changes: 119 additions & 0 deletions tests/models/test_croma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from pathlib import Path

import pytest
import torch
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import (
CROMA,
CROMABase_Weights,
CROMALarge_Weights,
croma_base,
croma_large,
)


def save_model(model: torch.nn.Module, path: Path) -> None:
    """Serialize the state dicts of every CROMA sub-module to *path*.

    Args:
        model: A CROMA model exposing the five named sub-modules below.
        path: Destination file for the combined checkpoint.
    """
    # The pretrained checkpoint format stores one state dict per sub-module,
    # keyed by the sub-module's attribute name.
    components = (
        's1_encoder',
        's1_GAP_FFN',
        's2_encoder',
        's2_GAP_FFN',
        'joint_encoder',
    )
    state_dict = {name: getattr(model, name).state_dict() for name in components}
    torch.save(state_dict, path)


class TestCROMA:
    @pytest.mark.parametrize('modalities', [['sar'], ['optical'], ['sar', 'optical']])
    def test_croma(self, modalities: list[str]) -> None:
        """A forward pass yields one encoding per requested modality."""
        batch_size = 2
        model = CROMA(modalities=modalities)
        size = model.image_size
        # SAR inputs have 2 channels, optical inputs have 12; a modality that
        # was not requested is passed as ``None``.
        sar_images = (
            torch.randn([batch_size, 2, size, size]) if 'sar' in modalities else None
        )
        optical_images = (
            torch.randn([batch_size, 12, size, size])
            if 'optical' in modalities
            else None
        )
        out = model(sar_images, optical_images)
        for modality in modalities:
            assert f'{modality}_encodings' in out
        # Joint encodings only exist when both modalities are present.
        if set(modalities) == {'sar', 'optical'}:
            assert 'joint_encodings' in out


class TestCROMABase:
    @pytest.fixture(params=[*CROMABase_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        """Yield each member of the base-model weights enum."""
        return request.param

    @pytest.fixture
    def mocked_weights(
        self,
        tmp_path: Path,
        monkeypatch: MonkeyPatch,
        weights: WeightsEnum,
        load_state_dict_from_url: None,
    ) -> WeightsEnum:
        """Redirect the weight URL to a locally generated checkpoint."""
        path = tmp_path / f'{weights}.pth'
        save_model(croma_base(), path)
        try:
            monkeypatch.setattr(weights.value, 'url', str(path))
        except AttributeError:
            # Older torchvision stores ``url`` on the enum member itself.
            monkeypatch.setattr(weights, 'url', str(path))
        return weights

    def test_croma(self) -> None:
        """Constructing the base model without weights succeeds."""
        croma_base()

    def test_croma_weights(self, mocked_weights: WeightsEnum) -> None:
        """Loading the (mocked) pretrained weights succeeds."""
        croma_base(weights=mocked_weights)

    @pytest.mark.slow
    def test_croma_download(self, weights: WeightsEnum) -> None:
        """Downloading the real pretrained weights succeeds (slow)."""
        croma_base(weights=weights)


class TestCROMALarge:
    @pytest.fixture(params=[*CROMALarge_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        """Yield each member of the large-model weights enum."""
        return request.param

    @pytest.fixture
    def mocked_weights(
        self,
        tmp_path: Path,
        monkeypatch: MonkeyPatch,
        weights: WeightsEnum,
        load_state_dict_from_url: None,
    ) -> WeightsEnum:
        """Redirect the weight URL to a locally generated checkpoint."""
        path = tmp_path / f'{weights}.pth'
        save_model(croma_large(), path)
        try:
            monkeypatch.setattr(weights.value, 'url', str(path))
        except AttributeError:
            # Older torchvision stores ``url`` on the enum member itself.
            monkeypatch.setattr(weights, 'url', str(path))
        return weights

    def test_croma(self) -> None:
        """Constructing the large model without weights succeeds."""
        croma_large()

    def test_croma_weights(self, mocked_weights: WeightsEnum) -> None:
        """Loading the (mocked) pretrained weights succeeds."""
        croma_large(weights=mocked_weights)

    @pytest.mark.slow
    def test_croma_download(self, weights: WeightsEnum) -> None:
        """Downloading the real pretrained weights succeeds (slow)."""
        croma_large(weights=weights)
6 changes: 6 additions & 0 deletions torchgeo/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .api import get_model, get_model_weights, get_weight, list_models
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
from .croma import CROMA, CROMABase_Weights, CROMALarge_Weights, croma_base, croma_large
from .dofa import (
DOFA,
DOFABase16_Weights,
Expand Down Expand Up @@ -35,6 +36,11 @@
'ChangeMixin',
'ChangeStar',
'ChangeStarFarSeg',
'CROMA',
'CROMABase_Weights',
'CROMALarge_Weights',
'croma_base',
'croma_large',
'DOFA',
'dofa_small_patch16_224',
'dofa_base_patch16_224',
Expand Down
Loading

0 comments on commit b2b6516

Please sign in to comment.