Skip to content

Commit

Permalink
feat: added LogReg into all self-supervised learning examples (#1325)
Browse files Browse the repository at this point in the history
* feat: added LogReg into simCLR

* fix: fixed typo

* Update examples/self_supervised/simCLR.py

Co-authored-by: Sergey Kolesnikov <scitator@gmail.com>

* feat: added logistic scorer into byol

Co-authored-by: Sergey Kolesnikov <scitator@gmail.com>
  • Loading branch information
Nimrais and Scitator committed Oct 12, 2021
1 parent 3feea30 commit 93eedf0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/self_supervised/barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
criterion=criterion,
optimizer=optimizer,
callbacks=callbacks,
loaders=get_loaders(arargs.dataset, args.batch_size, args.num_workersgs),
loaders=get_loaders(args.dataset, args.batch_size, args.num_workers),
verbose=True,
num_epochs=epochs,
valid_loader="train",
Expand Down
17 changes: 17 additions & 0 deletions examples/self_supervised/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse

from common import add_arguments, get_contrastive_model, get_loaders
from sklearn.linear_model import LogisticRegression

from torch.optim import Adam

Expand Down Expand Up @@ -50,6 +51,22 @@ def set_requires_grad(model, val):
),
loaders="train",
),
dl.SklearnModelCallback(
feature_key="embedding_origin",
target_key="target",
train_loader="train",
valid_loaders="valid",
model_fn=LogisticRegression,
predict_key="sklearn_predict",
predict_method="predict_proba",
),
dl.OptimizerCallback(metric_key="loss"),
dl.ControlFlowCallback(
dl.AccuracyCallback(
target_key="target", input_key="sklearn_predict", topk_args=(1, 3)
),
loaders="valid",
),
]

runner = SelfSupervisedRunner()
Expand Down
19 changes: 18 additions & 1 deletion examples/self_supervised/simCLR.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse

from common import add_arguments, get_contrastive_model, get_loaders
from sklearn.linear_model import LogisticRegression

from torch.optim import Adam

Expand All @@ -26,7 +27,23 @@
callbacks = [
dl.CriterionCallback(
input_key="projection_left", target_key="projection_right", metric_key="loss"
)
),
dl.SklearnModelCallback(
feature_key="embedding_origin",
target_key="target",
train_loader="train",
valid_loaders="valid",
model_fn=LogisticRegression,
predict_key="sklearn_predict",
predict_method="predict_proba",
),
dl.OptimizerCallback(metric_key="loss"),
dl.ControlFlowCallback(
dl.AccuracyCallback(
target_key="target", input_key="sklearn_predict", topk_args=(1, 3)
),
loaders="valid",
),
]

runner = SelfSupervisedRunner()
Expand Down

0 comments on commit 93eedf0

Please sign in to comment.