Refactor AutoencoderKlMaisi (Project-MONAI#7993)
Fixes Project-MONAI#7988.

### Description

Refactor `AutoencoderKlMaisi` to use MONAI core components (`AutoencoderKL`, `SpatialAttentionBlock`, and `AEKLResBlock`) instead of the optional `generative` package.
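
As a rough usage sketch (not part of this PR), the refactored class can be imported and instantiated without the `generative` package, and the new `include_fc` / `use_combined_linear` flags are forwarded to the MONAI-core attention blocks. The argument values below are illustrative assumptions (small enough for CPU); attention blocks typically also require `einops`, and `num_splits` is assumed to divide the spatial size along `dim_split` at every resolution level.

```python
# Illustrative sketch only -- argument values are assumptions, not MAISI defaults.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

model = AutoencoderKlMaisi(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_channels=(4, 4, 4),
    num_res_blocks=(1, 1, 1),
    attention_levels=(False, False, True),  # attention only at the deepest level
    latent_channels=4,
    norm_num_groups=4,
    include_fc=False,            # new: forwarded to SpatialAttentionBlock
    use_combined_linear=False,   # new: single vs. separate qkv projections
    use_flash_attention=False,
    num_splits=2,                # MAISI-specific tensor splitting
    dim_split=1,
)

x = torch.randn(1, 1, 16, 16, 16)
with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(x)
print(reconstruction.shape, z_mu.shape, z_sigma.shape)
```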


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Pengfei Guo <pengfeig@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
3 people authored and rcremese committed Sep 2, 2024
1 parent 61ef0da commit b622b7c
Showing 3 changed files with 46 additions and 32 deletions.
68 changes: 43 additions & 25 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -13,25 +13,17 @@

import gc
import logging
from typing import TYPE_CHECKING, Sequence, cast
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import Convolution
from monai.utils import optional_import
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.utils.type_conversion import convert_to_tensor

AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")

if TYPE_CHECKING:
from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
else:
AutoencoderKLType = cast(type, AutoencoderKL)

# Set up logging configuration
logger = logging.getLogger(__name__)

@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
in_channels: Number of input channels.
num_channels: Sequence of block output channels.
out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
num_res_blocks: Number of residual blocks (see ResBlock) per level.
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
norm_num_groups: Number of groups for the group norm layers.
norm_eps: Epsilon for the normalization.
attention_levels: Indicate which level from num_channels contain an attention block.
with_nonlocal_attn: If True, use non-local attention block.
include_fc: whether to include the final linear layer in the attention block. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
@@ -547,6 +541,8 @@ def __init__(
print_info: bool = False,
save_mem: bool = True,
with_nonlocal_attn: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
@@ -603,11 +599,13 @@ def __init__(
input_channel = output_channel
if attention_levels[i]:
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=input_channel,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
@@ -626,7 +624,7 @@ def __init__(

if with_nonlocal_attn:
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@ def __init__(
)

blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
num_channels: Sequence of block output channels.
in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks (see ResBlock) per level.
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
norm_num_groups: Number of groups for the group norm layers.
norm_eps: Epsilon for the normalization.
attention_levels: Indicate which level from num_channels contain an attention block.
with_nonlocal_attn: If True, use non-local attention block.
include_fc: whether to include the final linear layer in the attention block. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ def __init__(
print_info: bool = False,
save_mem: bool = True,
with_nonlocal_attn: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
use_convtranspose: bool = False,
) -> None:
@@ -758,7 +762,7 @@ def __init__(

if with_nonlocal_attn:
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@ def __init__(
)
)
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@ def __init__(

if reversed_attention_levels[i]:
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
@@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class AutoencoderKlMaisi(AutoencoderKLType):
class AutoencoderKlMaisi(AutoencoderKL):
"""
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
norm_eps: Epsilon for the normalization.
with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
include_fc: whether to include the final linear layer. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
use_checkpointing: If True, use activation checkpointing.
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ def __init__(
norm_eps: float = 1e-6,
with_encoder_nonlocal_attn: bool = False,
with_decoder_nonlocal_attn: bool = False,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
use_checkpointing: bool = False,
use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ def __init__(
norm_eps,
with_encoder_nonlocal_attn,
with_decoder_nonlocal_attn,
use_flash_attention,
use_checkpointing,
use_convtranspose,
include_fc,
use_combined_linear,
use_flash_attention,
)

self.encoder = MaisiEncoder(
self.encoder: nn.Module = MaisiEncoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channels=num_channels,
@@ -945,6 +959,8 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_encoder_nonlocal_attn,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
num_splits=num_splits,
dim_split=dim_split,
@@ -953,7 +969,7 @@ def __init__(
save_mem=save_mem,
)

self.decoder = MaisiDecoder(
self.decoder: nn.Module = MaisiDecoder(
spatial_dims=spatial_dims,
num_channels=num_channels,
in_channels=latent_channels,
@@ -963,6 +979,8 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_decoder_nonlocal_attn,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
use_convtranspose=use_convtranspose,
num_splits=num_splits,
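
For context, here is a small standalone sketch of the two MONAI-core blocks that replace the `generative` package's `AttentionBlock` and `ResBlock` in the file above. Channel sizes are arbitrary assumptions, and the attention block needs `einops` at runtime.

```python
# Standalone sketch of the replacement blocks (values are illustrative).
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock

attn = SpatialAttentionBlock(
    spatial_dims=3,
    num_channels=8,
    norm_num_groups=8,
    norm_eps=1e-6,
    include_fc=False,           # new pass-through flag in MaisiEncoder/MaisiDecoder
    use_combined_linear=False,  # new pass-through flag in MaisiEncoder/MaisiDecoder
    use_flash_attention=False,
)
res = AEKLResBlock(
    spatial_dims=3,
    in_channels=8,
    norm_num_groups=8,
    norm_eps=1e-6,
    out_channels=8,
)

x = torch.randn(1, 8, 4, 4, 4)
print(attn(x).shape, res(x).shape)  # both preserve the (1, 8, 4, 4, 4) shape
```
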
4 changes: 2 additions & 2 deletions monai/networks/nets/autoencoderkl.py
@@ -532,7 +532,7 @@ def __init__(
"`num_channels`."
)

self.encoder = Encoder(
self.encoder: nn.Module = Encoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
channels=channels,
@@ -546,7 +546,7 @@ def __init__(
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.decoder = Decoder(
self.decoder: nn.Module = Decoder(
spatial_dims=spatial_dims,
channels=channels,
in_channels=latent_channels,
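
The `self.encoder: nn.Module` / `self.decoder: nn.Module` annotations above widen the declared attribute type so that subclasses such as `AutoencoderKlMaisi` can assign their own encoder/decoder classes without a static-typing conflict. A minimal sketch of the pattern (class names are illustrative, not from MONAI):

```python
# Minimal illustration of the annotation pattern (hypothetical classes).
import torch.nn as nn


class Base(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Declared as nn.Module, not as the concrete Conv3d type.
        self.encoder: nn.Module = nn.Conv3d(1, 8, 3)


class Child(Base):
    def __init__(self) -> None:
        super().__init__()
        # Fine under mypy: still an nn.Module, even though it is a different class.
        self.encoder = nn.Sequential(nn.Conv3d(1, 8, 3), nn.ReLU())
```
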
6 changes: 1 addition & 5 deletions tests/test_autoencoderkl_maisi.py
@@ -16,16 +16,13 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
from monai.networks import eval_mode
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
_, has_einops = optional_import("einops")
_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -79,7 +76,6 @@
CASES = CASES_NO_ATTENTION


@unittest.skipUnless(has_generative, "monai-generative required")
class TestAutoencoderKlMaisi(unittest.TestCase):

@parameterized.expand(CASES)
