Revision of BYOL module and tests (#874)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: otaj <ota@lightning.ai>
Commit: d8ff64f (1 parent: 6f58d71)

Showing 4 changed files with 166 additions and 136 deletions.
@@ -1,43 +1,78 @@
-from torch import nn
+from typing import Tuple, Union
+
+from torch import Tensor, nn
 
 from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
 from pl_bolts.utils.stability import under_review
 
 
 @under_review()
 class MLP(nn.Module):
-    def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256):
+    """MLP architecture used as projectors in online and target networks and predictors in the online network.
+
+    Args:
+        input_dim (int, optional): Input dimension. Defaults to 2048.
+        hidden_dim (int, optional): Hidden layer dimension. Defaults to 4096.
+        output_dim (int, optional): Output dimension. Defaults to 256.
+
+    Note:
+        Default values for input, hidden, and output dimensions are based on values used in BYOL.
+    """
+
+    def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None:
         super().__init__()
-        self.output_dim = output_dim
-        self.input_dim = input_dim
-
         self.model = nn.Sequential(
-            nn.Linear(input_dim, hidden_size, bias=False),
-            nn.BatchNorm1d(hidden_size),
+            nn.Linear(input_dim, hidden_dim, bias=False),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, output_dim, bias=True),
+            nn.Linear(hidden_dim, output_dim, bias=True),
         )
 
-    def forward(self, x):
-        x = self.model(x)
-        return x
+    def forward(self, x: Tensor) -> Tensor:
+        return self.model(x)
 
 
 @under_review()
 class SiameseArm(nn.Module):
-    def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
+    """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single
+    class.
+
+    Args:
+        encoder (Union[str, nn.Module], optional): Online and target network encoder architecture.
+            Defaults to "resnet50".
+        encoder_out_dim (int, optional): Output dimension of encoder. Defaults to 2048.
+        projector_hidden_dim (int, optional): Online and target network projector network hidden dimension.
+            Defaults to 4096.
+        projector_out_dim (int, optional): Online and target network projector network output dimension.
+            Defaults to 256.
+    """
+
+    def __init__(
+        self,
+        encoder: Union[str, nn.Module] = "resnet50",
+        encoder_out_dim: int = 2048,
+        projector_hidden_dim: int = 4096,
+        projector_out_dim: int = 256,
+    ) -> None:
         super().__init__()
 
         if isinstance(encoder, str):
-            encoder = torchvision_ssl_encoder(encoder)
-        # Encoder
-        self.encoder = encoder
-        # Projector
-        self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
-        # Predictor
-        self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)
-
-    def forward(self, x):
+            self.encoder = torchvision_ssl_encoder(encoder)
+        else:
+            self.encoder = encoder
+
+        self.projector = MLP(encoder_out_dim, projector_hidden_dim, projector_out_dim)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         y = self.encoder(x)[0]
         z = self.projector(y)
-        h = self.predictor(z)
-        return y, z, h
+        return y, z
+
+    def encode(self, x: Tensor) -> Tensor:
+        """Returns the encoded representation of a view. This method does not calculate the projection as in the
+        forward method.
+
+        Args:
+            x (Tensor): sample to be encoded
+        """
+        return self.encoder(x)[0]
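The revision removes the predictor from `SiameseArm`, so a single class now describes both the online and target networks; `forward` returns only the representation and projection, and the BYOL asymmetry lives in the training step. Below is a minimal sketch of how the revised pieces might compose, assuming the usual pl_bolts module path and the negative-cosine loss from the BYOL paper; none of this code is part of the commit itself.

```python
import torch
import torch.nn.functional as F

from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm

online = SiameseArm()  # encoder + projector
target = SiameseArm()  # in real training this is an EMA copy of the online arm
target.load_state_dict(online.state_dict())
for p in target.parameters():
    p.requires_grad = False

# The predictor is no longer inside SiameseArm; BYOL applies it to the online
# projection only, so it is instantiated separately (dimensions per the docstring).
predictor = MLP(input_dim=256, hidden_dim=4096, output_dim=256)

view1 = torch.randn(4, 3, 224, 224)  # dummy augmented views
view2 = torch.randn(4, 3, 224, 224)

_, z1 = online(view1)  # forward returns (representation, projection)
with torch.no_grad():
    _, z2 = target(view2)

p1 = predictor(z1)
# One direction of the BYOL loss: 2 - 2 * cosine similarity
loss = 2 - 2 * F.cosine_similarity(p1, z2.detach(), dim=-1).mean()

# For downstream use, encode() exposes backbone features without the projection:
features = online.encode(view1)
```

In a full training step the loss is symmetrized by swapping the views between the two arms, and the target parameters are updated as an exponential moving average of the online parameters.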