
[WIP] Ports generative networks #7196

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
@@ -30,6 +30,7 @@
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
from .segresnet_block import ResBlock
from .selfattention import SABlock
from .spade_norm import SPADE
from .squeeze_and_excitation import (
ChannelSELayer,
ResidualSELayer,
97 changes: 97 additions & 0 deletions monai/networks/blocks/spade_norm.py
@@ -0,0 +1,97 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import ADN, Convolution


class SPADE(nn.Module):
"""
SPADE normalisation block based on the 2019 paper by Park et al. (doi: https://doi.org/10.48550/arXiv.1903.07291)

Args:
label_nc: number of semantic labels
norm_nc: number of output channels
kernel_size: kernel size
spatial_dims: number of spatial dimensions
hidden_channels: number of channels in the intermediate gamma and beta layers
norm: type of base normalisation used before applying the SPADE normalisation
norm_params: parameters for the base normalisation
"""

def __init__(
self,
label_nc: int,
norm_nc: int,
kernel_size: int = 3,
spatial_dims: int = 2,
hidden_channels: int = 64,
norm: str | tuple = "INSTANCE",
norm_params: dict | None = None,
) -> None:
super().__init__()

if norm_params is None:
norm_params = {}
if len(norm_params) != 0:
norm = (norm, norm_params)
self.param_free_norm = ADN(
act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc
)
self.mlp_shared = Convolution(
spatial_dims=spatial_dims,
in_channels=label_nc,
out_channels=hidden_channels,
kernel_size=kernel_size,
norm=None,
padding=kernel_size // 2,
act="LEAKYRELU",
)
self.mlp_gamma = Convolution(
spatial_dims=spatial_dims,
in_channels=hidden_channels,
out_channels=norm_nc,
kernel_size=kernel_size,
padding=kernel_size // 2,
act=None,
)
self.mlp_beta = Convolution(
spatial_dims=spatial_dims,
in_channels=hidden_channels,
out_channels=norm_nc,
kernel_size=kernel_size,
padding=kernel_size // 2,
act=None,
)

def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
"""
Args:
x: input tensor
segmap: input segmentation map (B x C x [spatial dims]) where C is the number of semantic channels.
The map will be interpolated to the spatial dimensions of x internally.
"""

# Part 1. generate parameter-free normalized activations
normalized = self.param_free_norm(x)

# Part 2. produce scaling and bias conditioned on semantic map
segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
out = normalized * (1 + gamma) + beta
return out
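
For context, a minimal usage sketch of the SPADE block added above (a 2D toy configuration; the tensor shapes and label count are illustrative assumptions, not values taken from this PR):

import torch
from monai.networks.blocks import SPADE

# hypothetical example: normalise 32-channel features conditioned on a 6-channel semantic map
spade = SPADE(label_nc=6, norm_nc=32, spatial_dims=2, hidden_channels=64)
x = torch.randn(2, 32, 64, 64)        # features to be normalised (B x norm_nc x H x W)
segmap = torch.randn(2, 6, 128, 128)  # semantic map (B x label_nc x H x W), resized to x internally
out = spade(x, segmap)                # modulated output, same shape as x: (2, 32, 64, 64)

The block first applies the parameter-free base normalisation to x, then scales and shifts it with the gamma and beta maps predicted from segmap, so the output stays spatially aligned with the semantic layout.
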
1 change: 1 addition & 0 deletions monai/networks/layers/__init__.py
@@ -37,4 +37,5 @@
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
from .vector_quantizer import EMAQuantizer, VectorQuantizer
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
233 changes: 233 additions & 0 deletions monai/networks/layers/vector_quantizer.py
@@ -0,0 +1,233 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence, Tuple

import torch
from torch import nn

__all__ = ["VectorQuantizer", "EMAQuantizer"]


class EMAQuantizer(nn.Module):
"""
Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
58d9a2746493717a7c9252938da7efa6006f3739.

This module is not compatible with TorchScript when used within a DistributedDataParallel module. This is due
to the lack of TorchScript support for the torch.distributed module, as per https://github.com/pytorch/pytorch/issues/41353
on 22/10/2022. If you want to TorchScript your model, set `ddp_sync` to False.

Args:
spatial_dims: number of spatial dimensions.
num_embeddings: number of atomic elements in the codebook.
embedding_dim: number of channels of the input and atomic elements.
commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
decay: EMA decay. Defaults to 0.99.
epsilon: epsilon value. Defaults to 1e-5.
embedding_init: initialization method for the codebook. Defaults to "normal".
ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
"""

def __init__(
self,
spatial_dims: int,
num_embeddings: int,
embedding_dim: int,
commitment_cost: float = 0.25,
decay: float = 0.99,
epsilon: float = 1e-5,
embedding_init: str = "normal",
ddp_sync: bool = True,
):
super().__init__()
self.spatial_dims: int = spatial_dims
self.embedding_dim: int = embedding_dim
self.num_embeddings: int = num_embeddings

if self.spatial_dims not in [2, 3]:
raise ValueError(
f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}."
)

self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)
if embedding_init == "normal":
# Nothing to do here: nn.Embedding already initializes its weights from a normal distribution by default
pass
elif embedding_init == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear")
self.embedding.weight.requires_grad = False

self.commitment_cost: float = commitment_cost

self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings))
self.register_buffer("ema_w", self.embedding.weight.data.clone())
# declare types for mypy
self.ema_cluster_size: torch.Tensor
self.ema_w: torch.Tensor
self.decay: float = decay
self.epsilon: float = epsilon

self.ddp_sync: bool = ddp_sync

# Precalculating required permutation shapes
self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]
self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list(
range(1, self.spatial_dims + 1)
)

def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given an input, projects it to the quantized space and returns the additional tensors needed for the EMA loss.

Args:
inputs: Encoding space tensors

Returns:
torch.Tensor: Flattened version of the input of shape [B*D*H*W, C].
torch.Tensor: One-hot representation of the quantization indices of shape [B*D*H*W, self.num_embeddings].
torch.Tensor: Quantization indices of shape [B, D, H, W].

"""
with torch.autocast(device_type="cuda", enabled=False):
encoding_indices_view = list(inputs.shape)
del encoding_indices_view[1]

inputs = inputs.float()

# Converting to channel last format
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)

# Calculate Euclidean distances
distances = (
(flat_input**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
- 2 * torch.mm(flat_input, self.embedding.weight.t())
)

# Mapping distances to indexes
encoding_indices = torch.max(-distances, dim=1)[1]
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()

# Quantize and reshape
encoding_indices = encoding_indices.view(encoding_indices_view)

return flat_input, encodings, encoding_indices

def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
"""
Given encoding indices of shape [B, D, H, W], embeds them in the quantized space
[B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the
decoder.

Args:
embedding_indices: Tensor in channel last format which holds indices referencing atomic
elements from self.embedding

Returns:
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
"""
with torch.autocast(device_type="cuda", enabled=False):
embedding: torch.Tensor = (
self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
)
return embedding

def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
"""
TorchScript does not support torch.distributed.all_reduce. This function is a workaround based on the
example at https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused

Args:
encodings_sum: the sum over positions of the one-hot encodings, i.e. how often each codebook
entry was used.
dw: the product of the transposed one-hot encodings and the flattened input.

Returns:
None
"""
if self.ddp_sync and torch.distributed.is_initialized():
torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
else:
pass

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
flat_input, encodings, encoding_indices = self.quantize(inputs)
quantized = self.embed(encoding_indices)

# Use EMA to update the embedding vectors
if self.training:
with torch.no_grad():
encodings_sum = encodings.sum(0)
dw = torch.mm(encodings.t(), flat_input)

if self.ddp_sync:
self.distributed_synchronization(encodings_sum, dw)

self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))

# Laplace smoothing of the cluster size
n = self.ema_cluster_size.sum()
weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))
self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))

# Encoding Loss
loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)

# Straight Through Estimator
quantized = inputs + (quantized - inputs).detach()

return quantized, loss, encoding_indices


class VectorQuantizer(torch.nn.Module):
"""
Vector Quantization wrapper that is needed as a workaround for AMP, isolating the non-fp16-compatible parts of
the quantization in their own class.

Args:
quantizer (torch.nn.Module): quantizer module that returns the quantized representation, the loss and the
index-based quantized representation.
"""

def __init__(self, quantizer: EMAQuantizer):
super().__init__()

self.quantizer: EMAQuantizer = quantizer

self.perplexity: torch.Tensor = torch.rand(1)

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
quantized, loss, encoding_indices = self.quantizer(inputs)
# Perplexity calculations
avg_probs = (
torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)
.float()
.div(encoding_indices.numel())
)

self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

return loss, quantized

def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
return self.quantizer.embed(embedding_indices=embedding_indices)

def quantize(self, encodings: torch.Tensor) -> torch.Tensor:
output = self.quantizer(encodings)
encoding_indices: torch.Tensor = output[2]
return encoding_indices
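
For context, a minimal usage sketch of the quantizer pair added above (a 2D toy configuration; the codebook size, embedding dimension and tensor shapes are illustrative assumptions, not values taken from this PR):

import torch
from monai.networks.layers import EMAQuantizer, VectorQuantizer

ema = EMAQuantizer(spatial_dims=2, num_embeddings=128, embedding_dim=16)
vq = VectorQuantizer(quantizer=ema)

z = torch.randn(4, 16, 32, 32)  # encoder output (B x embedding_dim x H x W)
loss, quantized = vq(z)         # commitment loss and straight-through quantized codes
indices = vq.quantize(z)        # codebook indices of shape (B, H, W)
z_q = vq.embed(indices)         # back to (B, embedding_dim, H, W) for the decoder

The EMA codebook update only runs when the module is in training mode, so calling vq.eval() before inference leaves the codebook untouched while still returning the quantized representation.
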
8 changes: 8 additions & 0 deletions monai/networks/nets/__init__.py
@@ -14,9 +14,11 @@
from .ahnet import AHnet, Ahnet, AHNet
from .attentionunet import AttentionUnet
from .autoencoder import AutoEncoder
from .autoencoderkl import AutoencoderKL
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
from .classifier import Classifier, Critic, Discriminator
from .controlnet import ControlNet
from .daf3d import DAF3D
from .densenet import (
DenseNet,
@@ -34,6 +36,7 @@
densenet201,
densenet264,
)
from .diffusion_model_unet import DiffusionModelUNet
from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
from .dynunet import DynUNet, DynUnet, Dynunet
from .efficientnet import (
@@ -52,6 +55,7 @@
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .quicknat import Quicknat
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
@@ -102,13 +106,17 @@
seresnext50,
seresnext101,
)
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .transformer import DecoderOnlyTransformer
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
from .vit import ViT
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
from .voxelmorph import VoxelMorph, VoxelMorphUNet
from .vqvae import VQVAE