Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove momentum updating from val step and add separate val queue #631

Merged
merged 37 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7416a19
Add remove momentum updating from val step and add separate val queue
May 5, 2021
0f4b8ed
Remove momentum updating from val step and add separate val queue
May 5, 2021
cecf0d1
Merge branch 'fix-moco-validation' of https://github.com/maximzubkov/…
May 5, 2021
fc8c193
Merge branch 'master' into fix-moco-validation
mergify[bot] May 10, 2021
371d2e3
Merge branch 'master' into fix-moco-validation
Borda May 10, 2021
e067661
Merge branch 'master' into fix-moco-validation
mergify[bot] May 10, 2021
e4c3153
chlog
Borda May 10, 2021
d12f28b
Merge branch 'fix-moco-validation' of https://github.com/maximzubkov/…
Borda May 10, 2021
0888b66
Merge branch 'master' into fix-moco-validation
mergify[bot] May 10, 2021
dc1a515
Merge branch 'master' into fix-moco-validation
mergify[bot] May 11, 2021
e6328e3
Merge branch 'master' into fix-moco-validation
mergify[bot] May 11, 2021
80ae1b8
Merge branch 'master' into fix-moco-validation
maximzubkov May 12, 2021
cee2f9b
Merge branch 'master' into fix-moco-validation
mergify[bot] May 14, 2021
5cb2a31
Merge branch 'master' into fix-moco-validation
mergify[bot] May 17, 2021
370f23e
Fix val queue init
May 24, 2021
db708ea
Merge branch 'fix-moco-validation' of https://github.com/maximzubkov/…
May 24, 2021
f951c0a
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 15, 2021
02f182b
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 15, 2021
daa6674
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 15, 2021
071cd19
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 15, 2021
775af54
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 16, 2021
6f885c7
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 16, 2021
39e11dd
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 16, 2021
402db79
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 16, 2021
a53449a
v0.3.4 & changelog
Borda Jun 17, 2021
d3cf77c
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 17, 2021
b117d14
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 17, 2021
c5017c7
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 21, 2021
50701cb
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 21, 2021
a2e39c4
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 24, 2021
bae891a
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 24, 2021
188ef7e
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 24, 2021
0c1fca4
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 25, 2021
0a7e203
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 25, 2021
087639e
Update changelog
akihironitta Jun 27, 2021
be3dfec
Merge branch 'master' into fix-moco-validation
mergify[bot] Jun 29, 2021
6d31284
Merge branch 'master' into fix-moco-validation
mergify[bot] Jul 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Removed momentum updating from val step and add separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631))
- Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))


Expand Down
34 changes: 21 additions & 13 deletions pl_bolts/models/self_supervised/moco/moco2_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def __init__(

self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

# create the validation queue
self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
self.queue = nn.functional.normalize(self.val_queue, dim=0)
akihironitta marked this conversation as resolved.
Show resolved Hide resolved

self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))

def init_encoders(self, base_encoder):
"""
Override to add your own encoders
Expand All @@ -142,21 +148,21 @@ def _momentum_update_key_encoder(self):
param_k.data = param_k.data * em + param_q.data * (1. - em)

@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
# gather keys before updating queue
if self.trainer.use_ddp or self.trainer.use_ddp2:
keys = concat_all_gather(keys)

batch_size = keys.shape[0]

ptr = int(self.queue_ptr)
ptr = int(queue_ptr)
assert self.hparams.num_negatives % batch_size == 0 # for simplicity

# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.T
queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.hparams.num_negatives # move pointer

self.queue_ptr[0] = ptr
queue_ptr[0] = ptr

@torch.no_grad()
def _batch_shuffle_ddp(self, x): # pragma: no cover
Expand Down Expand Up @@ -205,11 +211,12 @@ def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no cover

return x_gather[idx_this]

def forward(self, img_q, img_k):
def forward(self, img_q, img_k, queue):
"""
Input:
im_q: a batch of query images
im_k: a batch of key images
queue: a queue from which to pick negative samples
Output:
logits, targets
"""
Expand All @@ -220,7 +227,6 @@ def forward(self, img_q, img_k):

# compute key features
with torch.no_grad(): # no gradient to keys
self._momentum_update_key_encoder() # update the key encoder

# shuffle for making use of BN
if self.trainer.use_ddp or self.trainer.use_ddp2:
Expand All @@ -238,7 +244,7 @@ def forward(self, img_q, img_k):
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])

# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
Expand All @@ -250,10 +256,7 @@ def forward(self, img_q, img_k):
labels = torch.zeros(logits.shape[0], dtype=torch.long)
labels = labels.type_as(logits)

# dequeue and enqueue
self._dequeue_and_enqueue(k)

return logits, labels
return logits, labels, k

def training_step(self, batch, batch_idx):
# in STL10 we pass in both lab+unl for online ft
Expand All @@ -264,7 +267,10 @@ def training_step(self, batch, batch_idx):

(img_1, img_2), _ = batch

output, target = self(img_q=img_1, img_k=img_2)
self._momentum_update_key_encoder() # update the key encoder
output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue)
self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr) # dequeue and enqueue

loss = F.cross_entropy(output.float(), target.long())

acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
Expand All @@ -282,7 +288,9 @@ def validation_step(self, batch, batch_idx):

(img_1, img_2), labels = batch

output, target = self(img_q=img_1, img_k=img_2)
output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue)
self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr) # dequeue and enqueue

loss = F.cross_entropy(output, target.long())

acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
Expand Down