Skip to content

Commit

Permalink
silly typo
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 30, 2023
1 parent 7873cf9 commit fe9ad7b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ class ZoobotTree(GenericLightningModule):
question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0".
channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1.
use_imagenet_weights (bool, optional): Load weights pretrained on ImageNet (NOT galaxies!). Defaults to False.
test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True.
timm_kwargs (dict, optional): passed to timm.create_model e.g. drop_path_rate=0.2 for effnet. Defaults to {}.
learning_rate (float, optional): AdamW learning rate. Defaults to 1e-3.
Expand Down Expand Up @@ -297,7 +296,7 @@ def get_encoder_dim(encoder, input_size, channels):
def get_pytorch_encoder(
architecture_name='efficientnet_b0',
channels=1,
use_imagenet_weights=False,
# use_imagenet_weights=False,
**timm_kwargs
) -> nn.Module:
"""
Expand Down Expand Up @@ -333,7 +332,7 @@ def get_pytorch_encoder(
if architecture_name == 'efficientnet':
logging.warning('efficientnet variant not specified - please set architecture_name=efficientnet_b0 (or similar)')
architecture_name = 'efficientnet_b0'
return timm.create_model(architecture_name, in_chans=channels, num_classes=0, pretrained=use_imagenet_weights, **timm_kwargs)
return timm.create_model(architecture_name, in_chans=channels, num_classes=0, **timm_kwargs)


def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) -> torch.nn.Sequential:
Expand Down

0 comments on commit fe9ad7b

Please sign in to comment.