Skip to content

Commit

Permalink
Removing L2-norm in contrastive loss (L2-norm already present in CosS…
Browse files Browse the repository at this point in the history
…im) (#6550)

### Description

The `forward` method of the `ContrastiveLoss` performs L2-normalization
before computing cosine similarity. The
[`torch.nn.functional.cosine_similarity`](https://pytorch.org/docs/stable/generated/torch.nn.functional.cosine_similarity.html)
method already handles this pre-processing to make sure that `input` and
`target` lie on the surface of the unit hypersphere. This step involves
an unnecessary cost and, thus, can be removed.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
  • Loading branch information
Lucas-rbnt authored May 24, 2023
1 parent 8dd004a commit ef2bd45
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions monai/losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
batch_size = input.shape[0]

norm_i = F.normalize(input, dim=1)
norm_j = F.normalize(target, dim=1)

negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)

repr = torch.cat([norm_i, norm_j], dim=0)
repr = torch.cat([input, target], dim=0)
sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
sim_ij = torch.diag(sim_matrix, batch_size)
sim_ji = torch.diag(sim_matrix, -batch_size)
Expand Down

0 comments on commit ef2bd45

Please sign in to comment.