forked from d-michail/firecastnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·55 lines (38 loc) · 1.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python3
from seasfire.data import SeasFireDataModule
from seasfire.firecastnet_lit import FireCastNetLit
from seasfire.gru_lit import GRULit
from seasfire.conv_gru_lit import ConvGRULit
from seasfire.conv_lstm_lit import ConvLSTMLit
from seasfire.utae_lit import UTAELit
from seasfire.cli import SeasfireLightningCLI
import logging
# Module-level logger; the model wrapper classes below log through it.
logger = logging.getLogger(__name__)
class GRU(GRULit):
    """Short-named wrapper around ``GRULit`` that announces itself in the log.

    All behaviour is inherited from ``GRULit``. The only addition is an INFO
    log line emitted from ``configure_optimizers``.

    NOTE(review): this subclass presumably exists so the Lightning CLI can
    select the model by the short name ``GRU``; confirm against the configs.
    """

    def configure_optimizers(self):
        """Log which model is active, then defer to ``GRULit``.

        Returns:
            Whatever ``GRULit.configure_optimizers`` returns, unchanged.
        """
        # Plain string literal: the message has no placeholders, so an
        # f-string prefix was pointless (ruff F541).
        logger.info("⚡ Using GRU ⚡")
        return super().configure_optimizers()
class ConvGRU(ConvGRULit):
    """Short-named wrapper around ``ConvGRULit`` that announces itself in the log.

    All behaviour is inherited from ``ConvGRULit``. The only addition is an
    INFO log line emitted from ``configure_optimizers``.

    NOTE(review): this subclass presumably exists so the Lightning CLI can
    select the model by the short name ``ConvGRU``; confirm against the
    configs.
    """

    def configure_optimizers(self):
        """Log which model is active, then defer to ``ConvGRULit``.

        Returns:
            Whatever ``ConvGRULit.configure_optimizers`` returns, unchanged.
        """
        # Plain string literal: the message has no placeholders, so an
        # f-string prefix was pointless (ruff F541).
        logger.info("⚡ Using ConvGRU ⚡")
        return super().configure_optimizers()
class ConvLSTM(ConvLSTMLit):
    """Short-named wrapper around ``ConvLSTMLit`` that announces itself in the log.

    All behaviour is inherited from ``ConvLSTMLit``. The only addition is an
    INFO log line emitted from ``configure_optimizers``.

    NOTE(review): this subclass presumably exists so the Lightning CLI can
    select the model by the short name ``ConvLSTM``; confirm against the
    configs.
    """

    def configure_optimizers(self):
        """Log which model is active, then defer to ``ConvLSTMLit``.

        Returns:
            Whatever ``ConvLSTMLit.configure_optimizers`` returns, unchanged.
        """
        # Plain string literal: the message has no placeholders, so an
        # f-string prefix was pointless (ruff F541).
        logger.info("⚡ Using ConvLSTM ⚡")
        return super().configure_optimizers()
class UTAE(UTAELit):
    """Short-named wrapper around ``UTAELit`` that announces itself in the log.

    All behaviour is inherited from ``UTAELit``. The only addition is an INFO
    log line emitted from ``configure_optimizers``.

    NOTE(review): this subclass presumably exists so the Lightning CLI can
    select the model by the short name ``UTAE``; confirm against the configs.
    """

    def configure_optimizers(self):
        """Log which model is active, then defer to ``UTAELit``.

        Returns:
            Whatever ``UTAELit.configure_optimizers`` returns, unchanged.
        """
        # Plain string literal: the message has no placeholders, so an
        # f-string prefix was pointless (ruff F541).
        logger.info("⚡ Using UTAE ⚡")
        return super().configure_optimizers()
class FireCastNet(FireCastNetLit):
    """Short-named wrapper around ``FireCastNetLit`` that announces itself in the log.

    All behaviour is inherited from ``FireCastNetLit``. The only addition is
    an INFO log line emitted from ``configure_optimizers``.

    NOTE(review): this subclass presumably exists so the Lightning CLI can
    select the model by the short name ``FireCastNet``; confirm against the
    configs.
    """

    def configure_optimizers(self):
        """Log which model is active, then defer to ``FireCastNetLit``.

        Returns:
            Whatever ``FireCastNetLit.configure_optimizers`` returns, unchanged.
        """
        # Plain string literal: the message has no placeholders, so an
        # f-string prefix was pointless (ruff F541).
        logger.info("⚡ Using FireCastNet ⚡")
        return super().configure_optimizers()
def main():
    """Script entry point: configure logging, then build the Lightning CLI.

    Root logging is configured at INFO level so the model wrappers'
    "Using ..." messages are emitted. ``SeasfireLightningCLI`` is constructed
    with ``SeasFireDataModule`` as the data module; no model class is fixed
    here.

    NOTE(review): this assumes the CLI parses ``sys.argv`` and runs the
    requested subcommand inside its constructor (LightningCLI's default
    ``run=True`` behaviour). Confirm in ``seasfire.cli``.
    """
    logging.basicConfig(level=logging.INFO)
    # Only the constructor call matters here. The previous version bound the
    # result to a local (`cli`) that was never read (ruff F841).
    SeasfireLightningCLI(
        datamodule_class=SeasFireDataModule,
    )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()