PyTorch Re-implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
- Install the SimCLR package
$ git clone https://github.com/4pygmalion/SimCLR.git
$ cd SimCLR
$ python3 -m pip install .
$ python3 example/test_run.py --n_epochs 15 -d cifar10
|################################| train | EPOCH 0: [9/9] | eta:1 | total_loss: 2.5731 |top3: 0.51429
|################################| val | EPOCH 0: [9/9] | eta:1 | total_loss: 2.68113 |top3: 0.38571
....
|################################| train | EPOCH 14: [9/9] | eta:1 | total_loss: 2.37914 |top3: 0.65714
|################################| val | EPOCH 14: [9/9] | eta:1 | total_loss: 2.45937 |top3: 0.48571
|################################| train | EPOCH 15: [9/9] | eta:1 | total_loss: 2.49873 |top3: 0.57143
|################################| val | EPOCH 15: [9/9] | eta:1 | total_loss: 2.47701 |top3: 0.5
- Build a visual encoder
import torch
import torchvision
from SimCLR.models import ResNetEncoder
encoder = ResNetEncoder(out_features=1024)
- Build a dataset that yields pairs of augmented images
import os
import glob
from torchvision.transforms import RandomCrop, GaussianBlur, RandomEqualize, RandomRotation, RandomAffine, Compose, ToTensor
from SimCLR.datasets import SimCLRDataSet
from matplotlib import pyplot as plt
images = glob.glob(...)
dataset = SimCLRDataSet(
images,
augmentations=[
GaussianBlur((3,3)),
RandomEqualize(),
RandomRotation(degrees=(0, 360)),
RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
],
transforms=Compose([ToTensor()]),
device="cuda:1"
)
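For readers who want to see the idea behind a two-view dataset without the package, here is a minimal sketch. `TwoViewDataset` and `random_flip_and_noise` are hypothetical names used only for illustration; they are not part of this package, and the internals of `SimCLRDataSet` may differ. The sketch assumes each item is a `(C, H, W)` tensor and that the augmentation draws fresh randomness on every call.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoViewDataset(Dataset):
    """Hypothetical stand-in: returns two randomly augmented views per image."""

    def __init__(self, images, augment):
        self.images = images    # sequence of (C, H, W) tensors
        self.augment = augment  # callable: tensor -> randomly augmented tensor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # Two independent calls draw different random augmentations,
        # so the two views of the same image differ.
        return self.augment(image), self.augment(image)


def random_flip_and_noise(image):
    """Toy augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(1).item() < 0.5:
        image = torch.flip(image, dims=[-1])
    return image + 0.05 * torch.randn_like(image)


# Random tensors stand in for real images here.
toy_images = [torch.rand(3, 32, 32) for _ in range(8)]
two_view_dataset = TwoViewDataset(toy_images, random_flip_and_noise)
view_a, view_b = two_view_dataset[0]

# The default collate function batches each view separately.
batch_a, batch_b = next(iter(DataLoader(two_view_dataset, batch_size=4)))
```

Each batch is then a pair of tensors of shape `(4, 3, 32, 32)`, one per view, which is the form a contrastive loss needs.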
- Forwarding into the encoder and calculating SimCLR loss
from SimCLR.trainer import SimCLRTrainer
from SimCLR.loss import SimCLRLoss
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
criterion = SimCLRLoss(temperature=0.5)
- Train your model with SimCLRTrainer
device = "cuda:1"  # same device the dataset above uses
encoder.to(device)
optimizer = torch.optim.Adam(encoder.parameters())
trainer = SimCLRTrainer(encoder, optimizer, criterion, device=device)
loss_meter, topk_meter = trainer.run_epoch(dataloader, phase="train", epoch=1)
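The training logs report a `top3` value, and `run_epoch` returns a `topk_meter`. How the package computes this metric is not shown here, so the following is only a sketch of one plausible definition: the fraction of rows whose positive logit is among the k largest. The function name `topk_accuracy` and the assumption that the positive sits in column 0 are illustrative.

```python
import torch


def topk_accuracy(logits, k=3, target_col=0):
    """Fraction of rows whose target column is among the k largest logits."""
    topk = logits.topk(k, dim=1).indices    # (rows, k) column indices
    hits = (topk == target_col).any(dim=1)  # True where the positive is in the top k
    return hits.float().mean().item()


# Toy logits: column 0 holds the positive, the other columns hold negatives.
logits = torch.tensor([
    [5.0, 1.0, 0.0, 2.0, 3.0],  # positive ranked 1st -> hit
    [1.0, 5.0, 4.0, 3.0, 0.0],  # positive ranked 4th -> miss
    [3.0, 5.0, 4.0, 0.0, 1.0],  # positive ranked 3rd -> hit
    [0.0, 1.0, 2.0, 3.0, 4.0],  # positive ranked 5th -> miss
])
accuracy = topk_accuracy(logits, k=3)  # 2 hits out of 4 rows
```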
- Build the similarity matrix ($s_{i,j}$): the similarity between embedding vectors $i$ and $j$ produced by the visual encoder, with shape (2N, 2N)
- Remove the diagonal elements of the similarity matrix ($s_{i,j}, i=j$)
- Split the positive samples (key-value) from the negative samples
- Concatenate the two matrices above, positives first (the positive sample is always in the 0-th column of this matrix, so its label is always 0)
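The steps above can be sketched in plain PyTorch as follows. This is an illustrative sketch, not the package's `SimCLRLoss`: the function name `nt_xent_loss`, the use of cosine similarity, and the returned `(loss, logits)` pair are assumptions. It expects two `(N, D)` embedding batches, where row `i` of each batch comes from the same source image.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z_a, z_b, temperature=0.5):
    """Contrastive loss sketch; z_a and z_b are (N, D) embeddings of the two views."""
    n = z_a.shape[0]
    z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)  # (2N, D), unit length

    # 1. Similarity matrix s_{i,j}: cosine similarity of all pairs, shape (2N, 2N).
    sim = z @ z.T

    # 2. Remove the diagonal (s_{i,i}); each row keeps its 2N - 1 other entries.
    eye = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim[~eye].view(2 * n, 2 * n - 1)

    # 3. Split positives from negatives. The positive of row i is the other
    #    view of the same image, i.e. the row with the same image id.
    ids = torch.arange(n, device=z.device).repeat(2)
    pos_mask = (ids[:, None] == ids[None, :])[~eye].view(2 * n, 2 * n - 1)
    positives = sim[pos_mask].view(2 * n, 1)
    negatives = sim[~pos_mask].view(2 * n, 2 * n - 2)

    # 4. Concatenate with the positive first, so the target class is 0 for every row.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(2 * n, dtype=torch.long, device=z.device)
    return F.cross_entropy(logits, labels), logits


torch.manual_seed(0)
z_a = torch.randn(4, 16)
z_b = z_a + 0.1 * torch.randn(4, 16)  # slightly perturbed second view
loss, logits = nt_xent_loss(z_a, z_b)
```

Placing the positive in column 0 turns the loss into an ordinary cross-entropy over 2N - 1 candidates with a constant target, which avoids building a per-row label tensor.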
- Inspired by https://github.com/sthalles/SimCLR