
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

4pygmalion/SimCLR


SimCLR


PyTorch Re-implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Install

  • Install the SimCLR package
$ git clone https://github.com/4pygmalion/SimCLR.git
$ cd SimCLR
$ python3 -m pip install .

Test

$ python3 example/test_run.py --n_epochs 15 -d cifar10
 |################################| train | EPOCH 0: [9/9] | eta:1 | total_loss: 2.5731 |top3: 0.51429 
 |################################| val | EPOCH 0: [9/9] | eta:1 | total_loss: 2.68113 |top3: 0.38571
....
 |################################| train | EPOCH 14: [9/9] | eta:1 | total_loss: 2.37914 |top3: 0.65714
 |################################| val | EPOCH 14: [9/9] | eta:1 | total_loss: 2.45937 |top3: 0.48571
 |################################| train | EPOCH 15: [9/9] | eta:1 | total_loss: 2.49873 |top3: 0.57143
 |################################| val | EPOCH 15: [9/9] | eta:1 | total_loss: 2.47701 |top3: 0.5

Example

  1. Build a visual encoder
import torch
import torchvision
from SimCLR.models import ResNetEncoder
encoder = ResNetEncoder(out_features=1024)
  2. Build a dataset that returns two augmented views of each image
import os
import glob

from torchvision.transforms import RandomCrop, GaussianBlur, RandomEqualize, RandomRotation, RandomAffine, Compose, ToTensor
from SimCLR.datasets import SimCLRDataSet
from matplotlib import pyplot as plt

images = glob.glob(...)  # paths to the training images
dataset = SimCLRDataSet(
    images, 
    augmentations=[
        GaussianBlur((3,3)), 
        RandomEqualize(), 
        RandomRotation(degrees=(0, 360)), 
        RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
    ],
    transforms=Compose([ToTensor()]),
    device="cuda:1"
)
  3. Create a data loader and the SimCLR loss used to score the encoder's outputs
from SimCLR.trainer import SimCLRTrainer
from SimCLR.loss import SimCLRLoss
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
criterion = SimCLRLoss(temperature=0.5)
  4. Train your model with SimCLRTrainer
device = "cuda:1"  # should match the device passed to SimCLRDataSet
encoder.to(device)  # move the model before creating the optimizer
optimizer = torch.optim.Adam(encoder.parameters())
trainer = SimCLRTrainer(encoder, optimizer, criterion, device=device)
loss_meter, topk_meter = trainer.run_epoch(dataloader, phase="train", epoch=1)
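The dataset built above relies on `SimCLRDataSet` returning two differently augmented views of the same image, which together form a positive pair. The README does not show how the package does this, so the following is only a framework-free sketch of the idea. `TwoViewDataset` is a hypothetical name, and applying each augmentation independently with probability `p` is an assumption, not the package's documented behavior.

```python
import random
from typing import Any, Callable, List, Sequence, Tuple


class TwoViewDataset:
    """Hypothetical minimal two-view dataset (not SimCLR.datasets.SimCLRDataSet).

    Each item is augmented twice, independently, producing a positive pair.
    """

    def __init__(self, items: Sequence[Any], augmentations: List[Callable[[Any], Any]],
                 p: float = 0.5, seed: int = 0):
        self.items = items
        self.augmentations = augmentations
        self.p = p  # assumption: each augmentation fires independently with probability p
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.items)

    def _augment(self, x: Any) -> Any:
        # Apply the augmentations in order, each one with probability p
        for aug in self.augmentations:
            if self.rng.random() < self.p:
                x = aug(x)
        return x

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x = self.items[idx]
        # Two independent draws of the random pipeline give two views of the same item
        return self._augment(x), self._augment(x)


# Toy usage: "images" are short lists; the augmentations are a flip and a rescale
def flip(v):
    return v[::-1]


def double(v):
    return [2 * a for a in v]


ds = TwoViewDataset([[1, 2, 3], [4, 5, 6]], [flip, double], seed=42)
view_a, view_b = ds[0]
print(len(ds), view_a, view_b)
```

With torchvision, the entries of `augmentations` could be transforms such as `RandomRotation(degrees=(0, 360))`, since torchvision transforms are plain callables; a real image dataset would also load each file from its path and apply `ToTensor()` to both views.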

Illustration of SimCLR matrix operation

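  1. Build the similarity matrix $s_{i,j}$ of shape (2N, 2N), where $s_{i,j}$ is the similarity between the i-th and j-th embeddings produced by the visual encoder for the 2N augmented images in a batch
  2. Remove the diagonal elements ($s_{i,j}$ with $i=j$), i.e., each embedding's similarity with itself
  3. Split the remaining entries into positive samples (the other augmented view of the same image) and negative samples
  4. Concatenate the two matrices so that the positive similarity sits in column 0 of every row; the target label for the loss is therefore always 0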
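The four steps above can be sketched in NumPy, chosen here so the example runs without a GPU; the same operations map directly onto torch tensors. This is not the package's code: `simclr_logits`, `nt_xent` and `topk_accuracy` are hypothetical names, and the sketch assumes that rows k and k + N of the embedding matrix are the two views of image k and that $s_{i,j}$ is cosine similarity, as in the SimCLR paper. Taking cross-entropy over each row with target 0 then gives the NT-Xent loss with temperature $\tau$, where $p(i)$ is the index of the other view of sample i:

$$\ell_i = -\log \frac{\exp(s_{i,p(i)}/\tau)}{\sum_{k \neq i} \exp(s_{i,k}/\tau)}, \qquad \mathcal{L} = \frac{1}{2N}\sum_{i=1}^{2N} \ell_i$$

```python
import numpy as np


def simclr_logits(z: np.ndarray, temperature: float = 0.5) -> np.ndarray:
    """Hypothetical sketch of steps 1-4.

    z: (2N, D) embeddings; rows k and k + N are assumed to be the two
    augmented views of image k.
    Returns a (2N, 2N - 1) logit matrix whose column 0 holds the positive.
    """
    two_n = z.shape[0]
    n = two_n // 2
    # L2-normalise rows so that z @ z.T is cosine similarity
    z = z / np.linalg.norm(z, axis=1, keepdims=True)

    # Step 1: (2N, 2N) similarity matrix, scaled by the temperature
    sim = z @ z.T / temperature

    # Step 2: drop the diagonal (self-similarity) -> (2N, 2N - 1)
    off_diag = ~np.eye(two_n, dtype=bool)
    sim = sim[off_diag].reshape(two_n, two_n - 1)

    # Step 3: locate each row's positive. The partner of row i is (i + N) mod 2N;
    # removing the diagonal shifts every column to the right of i left by one.
    rows = np.arange(two_n)
    partner = (rows + n) % two_n
    pos_col = np.where(partner > rows, partner - 1, partner)
    positives = sim[rows, pos_col]                        # (2N,)
    neg_mask = np.ones_like(sim, dtype=bool)
    neg_mask[rows, pos_col] = False
    negatives = sim[neg_mask].reshape(two_n, two_n - 2)   # (2N, 2N - 2)

    # Step 4: concatenate so the positive sits in column 0 of every row
    return np.concatenate([positives[:, None], negatives], axis=1)


def nt_xent(logits: np.ndarray) -> float:
    """Cross-entropy with target class 0 for every row, averaged over the 2N rows."""
    row_max = logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    log_sum_exp = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    return float(np.mean(log_sum_exp - logits[:, 0]))


def topk_accuracy(logits: np.ndarray, k: int = 3) -> float:
    """Fraction of rows whose positive (column 0) is among the k largest logits."""
    top_k = np.argsort(-logits, axis=1)[:, :k]
    return float(np.mean((top_k == 0).any(axis=1)))


rng = np.random.default_rng(0)
views = rng.normal(size=(8, 16))  # toy batch: N = 4 images, 2 views each, D = 16
logits = simclr_logits(views, temperature=0.5)  # shape (8, 7)
print(logits.shape, nt_xent(logits), topk_accuracy(logits, k=3))
```

`topk_accuracy` illustrates one plausible reading of the `top3` value in the training log (how often the positive ranks among the three highest-scoring candidates); that interpretation is an assumption about what the package reports.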

Reference
