Checkpoint Refactor (#65)
* fixed grid effect

* blew up commit history

* fixed commit

* added modulus model version and checkpoint

* refactored model registry

* example registry

* updated model registry

* modified activations

* save git hash changed to verbose

* Fixed most issues

* ~90% functionality implemented

* added docstring about json input

* added map location here

* updated init method

* black formatted

* removed example external package

* Update test_from_torch.py

* black formatted from_torch model

---------

Co-authored-by: oliver <ohennigh@nvidia.com>
loliverhennigh authored Aug 4, 2023
1 parent 3883689 commit 656b15e
Showing 26 changed files with 890 additions and 167 deletions.
7 changes: 4 additions & 3 deletions modulus/models/dlwp/dlwp.py
@@ -19,6 +19,7 @@
from dataclasses import dataclass

import modulus
from modulus.models.layers import get_activation
from modulus.models.meta import ModelMetaData
from modulus.models.module import Module
from typing import Tuple, Union
@@ -212,7 +213,7 @@ class DLWP(Module):
nr_initial_channels : int
Number of channels in the initial convolution. This governs the overall channels
in the model.
activation_fn : nn.Module
activation_fn : str
Activation function for the convolutions
depth : int
Depth for the U-Net
@@ -242,7 +243,7 @@ def __init__(
nr_input_channels: int,
nr_output_channels: int,
nr_initial_channels: int = 64,
activation_fn: nn.Module = nn.LeakyReLU(0.1),
activation_fn: str = "leaky_relu",
depth: int = 2,
clamp_activation: Tuple[Union[float, int, None], Union[float, int, None]] = (
None,
@@ -254,7 +255,7 @@
self.nr_input_channels = nr_input_channels
self.nr_output_channels = nr_output_channels
self.nr_initial_channels = nr_initial_channels
self.activation_fn = activation_fn
self.activation_fn = get_activation(activation_fn)
self.depth = depth
self.clamp_activation = clamp_activation

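The activation is now chosen by name and resolved inside the constructor via get_activation. A minimal usage sketch of the new signature (channel counts are illustrative, and the import path assumes DLWP is re-exported from modulus.models.dlwp):

from modulus.models.dlwp import DLWP

# The activation is passed as a string rather than an nn.Module instance
model = DLWP(
    nr_input_channels=2,         # illustrative value
    nr_output_channels=4,        # illustrative value
    nr_initial_channels=64,
    activation_fn="leaky_relu",  # resolved to nn.LeakyReLU(0.1) by get_activation
    depth=2,
)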
56 changes: 34 additions & 22 deletions modulus/models/fno/fno.py
@@ -23,6 +23,7 @@
from dataclasses import dataclass
from ..meta import ModelMetaData
from ..module import Module
from ..mlp import FullyConnected

# ===================================================================
# ===================================================================
@@ -649,10 +650,16 @@ class FNO(Module):
Parameters
----------
decoder_net : modulus.Module
Pointwise decoder network, input feature size should match `latent_channels`
in_channels : int
Number of input channels
out_channels : int
Number of output channels
decoder_layers : int, optional
Number of decoder layers, by default 1
decoder_layer_size : int, optional
Number of neurons in decoder layers, by default 32
decoder_activation_fn : str, optional
Activation function for decoder, by default "silu"
dimension : int
Model dimensionality (supports 1, 2, 3).
latent_channels : int, optional
@@ -665,24 +672,19 @@
Domain padding for spectral convolutions, by default 8
padding_type : str, optional
Type of padding for spectral convolutions, by default "constant"
activation_fn : nn.Module, optional
Activation function, by default nn.GELU
activation_fn : str, optional
Activation function, by default "gelu"
coord_features : bool, optional
Use coordinate grid as additional feature map, by default True
Example
-------
>>> # define the decoder net
>>> decoder = modulus.models.mlp.FullyConnected(
... in_features=32,
... out_features=3,
... num_layers=2,
... layer_size=16,
... )
>>> # define the 2d FNO model
>>> model = modulus.models.fno.FNO(
... decoder_net=decoder,
... in_channels=4,
... out_channels=3,
... decoder_layers=2,
... decoder_layer_size=32,
... dimension=2,
... latent_channels=32,
... num_fno_layers=2,
@@ -701,15 +703,18 @@ class FNO(Module):

def __init__(
self,
decoder_net: Module,
in_channels: int,
dimension: int,
out_channels: int,
decoder_layers: int = 1,
decoder_layer_size: int = 32,
decoder_activation_fn: str = "silu",
dimension: int = 2,
latent_channels: int = 32,
num_fno_layers: int = 4,
num_fno_modes: Union[int, List[int]] = 16,
padding: int = 8,
padding_type: str = "constant",
activation_fn: nn.Module = nn.GELU(),
activation_fn: str = "gelu",
coord_features: bool = True,
) -> None:
super().__init__(meta=MetaData())
@@ -718,11 +723,17 @@ def __init__(
self.num_fno_modes = num_fno_modes
self.padding = padding
self.padding_type = padding_type
self.activation_fn = activation_fn
self.activation_fn = layers.get_activation(activation_fn)
self.coord_features = coord_features
self.var_dim = decoder_net.meta.var_dim

# decoder net
self.decoder_net = decoder_net
self.decoder_net = FullyConnected(
in_features=latent_channels,
layer_size=decoder_layer_size,
out_features=out_channels,
num_layers=decoder_layers,
activation_fn=decoder_activation_fn,
)

if dimension == 1:
FNOModel = FNO1DEncoder
@@ -757,16 +768,17 @@ def __init__(
)

def forward(self, x: Tensor) -> Tensor:
# Fourier encoder
y_latent = self.spec_encoder(x)

# Reshape to pointwise inputs if not a conv FC model
y_shape = y_latent.shape
if self.var_dim == -1:
y_latent, y_shape = self.grid_to_points(y_latent)
y_latent, y_shape = self.grid_to_points(y_latent)

# Decoder
y = self.decoder_net(y_latent)

# Convert back into grid
if self.var_dim == -1:
y = self.points_to_grid(y, y_shape)
y = self.points_to_grid(y, y_shape)

return y
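Migration sketch for the refactored FNO: the decoder_net argument is gone and the pointwise FullyConnected decoder is now built inside the model from the decoder_* arguments. The values mirror the docstring example above; the [batch, in_channels, H, W] input convention and the import path are assumptions:

import torch
from modulus.models.fno import FNO

# Old API (pre-refactor): FNO(decoder_net=FullyConnected(...), in_channels=4, dimension=2, ...)
# New API: the decoder is configured directly on FNO
model = FNO(
    in_channels=4,
    out_channels=3,
    decoder_layers=2,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=2,
)
x = torch.randn(8, 4, 32, 32)  # illustrative [batch, in_channels, H, W] input
y = model(x)                   # expected to come back as [8, 3, 32, 32]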
2 changes: 2 additions & 0 deletions modulus/models/graphcast/__init__.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .graph_cast_net import GraphCastNet
10 changes: 7 additions & 3 deletions modulus/models/graphcast/graph_cast_net.py
@@ -19,6 +19,7 @@
from typing import Any
from dataclasses import dataclass

from modulus.models.layers import get_activation
from modulus.models.gnn_layers.utils import set_checkpoint_fn, CuGraphCSC
from modulus.models.gnn_layers.embedder import (
GraphCastEncoderEmbedder,
@@ -79,8 +80,8 @@ class GraphCastNet(Module):
Number of neurons in each hidden layer, by default 512
aggregation : str, optional
Message passing aggregation method ("sum", "mean"), by default "sum"
activation_fn : nn.Module, optional
Type of activation function, by default nn.SiLU()
activation_fn : str, optional
Type of activation function, by default "silu"
norm_type : str, optional
Normalization type, by default "LayerNorm"
use_cugraphops_encoder : bool, default=False
@@ -121,7 +122,7 @@ def __init__(
hidden_layers: int = 1,
hidden_dim: int = 512,
aggregation: str = "sum",
activation_fn: nn.Module = nn.SiLU(),
activation_fn: str = "silu",
norm_type: str = "LayerNorm",
use_cugraphops_encoder: bool = False,
use_cugraphops_processor: bool = False,
@@ -139,6 +140,9 @@
)
self.has_static_data = static_dataset_path is not None

# Set activation function
activation_fn = get_activation(activation_fn)

# Get the static data
if self.has_static_data:
self.static_data = StaticData(
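The same string-to-module pattern recurs across this refactor: constructors accept the activation name and resolve it in __init__. A generic sketch of that pattern (the ConvBlock module below is hypothetical, not part of this diff):

import torch.nn as nn
from modulus.models.layers import get_activation

class ConvBlock(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, channels: int, activation_fn: str = "silu"):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = get_activation(activation_fn)  # str -> nn.Module

    def forward(self, x):
        return self.act(self.conv(x))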
2 changes: 1 addition & 1 deletion modulus/models/layers/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .activations import Identity, Stan, SquarePlus
from .activations import Identity, Stan, SquarePlus, get_activation
from .weight_norm import WeightNormLinear
from .spectral_layers import (
SpectralConv1d,
55 changes: 55 additions & 0 deletions modulus/models/layers/activations.py
@@ -95,3 +95,58 @@ def __init__(self):

def forward(self, x: Tensor) -> Tensor:
return 0.5 * (x + torch.sqrt(x * x + self.b))


# Dictionary of activation functions
ACT2FN = {
"relu": nn.ReLU,
"leaky_relu": (nn.LeakyReLU, {"negative_slope": 0.1}),
"prelu": nn.PReLU,
"relu6": nn.ReLU6,
"elu": nn.ELU,
"selu": nn.SELU,
"silu": nn.SiLU,
"gelu": nn.GELU,
"sigmoid": nn.Sigmoid,
"logsigmoid": nn.LogSigmoid,
"softplus": nn.Softplus,
"softshrink": nn.Softshrink,
"softsign": nn.Softsign,
"tanh": nn.Tanh,
"tanhshrink": nn.Tanhshrink,
"threshold": (nn.Threshold, {"threshold": 1.0, "value": 1.0}),
"hardtanh": nn.Hardtanh,
"identity": Identity,
"stan": Stan,
"squareplus": SquarePlus,
}


def get_activation(activation: str) -> nn.Module:
"""Returns an activation function given a string
Parameters
----------
activation : str
String identifier for the desired activation function
Returns
-------
Activation function
Raises
------
KeyError
If the specified activation function is not found in the dictionary
"""
try:
activation = activation.lower()
module = ACT2FN[activation]
if isinstance(module, tuple):
return module[0](**module[1])
else:
return module()
except KeyError:
raise KeyError(
f"Activation function {activation} not found. Available options are: {list(ACT2FN.keys())}"
)
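A short usage sketch of the helper as defined above (the error case follows from the KeyError re-raise):

from modulus.models.layers import get_activation

gelu = get_activation("gelu")         # returns nn.GELU()
leaky = get_activation("LEAKY_RELU")  # case-insensitive; returns nn.LeakyReLU(negative_slope=0.1)
get_activation("swish")               # raises KeyError listing the available names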
9 changes: 5 additions & 4 deletions modulus/models/mlp/fully_connected.py
@@ -15,11 +15,11 @@
import torch
import torch.nn as nn
import modulus
from modulus.models.layers import FCLayer, get_activation

from torch import Tensor
from dataclasses import dataclass
from typing import Optional, Union, List
from modulus.models.layers import FCLayer
from ..meta import ModelMetaData
from ..module import Module

@@ -53,8 +53,8 @@ class FullyConnected(Module):
Size of output features, by default 512
num_layers : int, optional
Number of hidden layers, by default 6
activation_fn : Union[nn.Module, List[nn.Module]], optional
Activation function to use, by default nn.SILU
activation_fn : Union[str, List[str]], optional
Activation function to use, by default 'silu'
skip_connections : bool, optional
Add skip connections every 2 hidden layers, by default False
adaptive_activations : bool, optional
@@ -77,7 +77,7 @@ def __init__(
layer_size: int = 512,
out_features: int = 512,
num_layers: int = 6,
activation_fn: Union[nn.Module, List[nn.Module]] = nn.SiLU(),
activation_fn: Union[str, List[str]] = "silu",
skip_connections: bool = False,
adaptive_activations: bool = False,
weight_norm: bool = False,
@@ -97,6 +97,7 @@ def __init__(
activation_fn = activation_fn + [activation_fn[-1]] * (
num_layers - len(activation_fn)
)
activation_fn = [get_activation(a) for a in activation_fn]

self.layers = nn.ModuleList()

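Usage sketch for string-based activations in FullyConnected (sizes are illustrative). A list shorter than num_layers is padded with its last entry, per the snippet above; a single string is assumed to apply to every hidden layer:

from modulus.models.mlp import FullyConnected

# Single string for all hidden layers
mlp = FullyConnected(in_features=32, layer_size=64, out_features=3,
                     num_layers=4, activation_fn="silu")

# Per-layer activations: ["relu", "tanh"] is padded to ["relu", "tanh", "tanh", "tanh"]
mlp_mixed = FullyConnected(in_features=32, layer_size=64, out_features=3,
                           num_layers=4, activation_fn=["relu", "tanh"])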