Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contrastive loss layer for training siamese nets #959

Merged
merged 1 commit into from
Sep 19, 2014

Conversation

nickcarlevaris
Copy link

Hi All,

I've started work on a contrastive loss layer that, when combined with weight sharing #546, can be used to train siamese nets.

The layer implements:
loss = 1/2 (y d + (1-y) \max(margin-d, 0))
d = \sum_i (a_i - b_i)^2
where d is the distance between two features a and b, and y is binary, indicating if the two features are similar or dissimilar.

This loss function was proposed in:
Raia Hadsell, Sumit Chopra, Yann LeCun "Dimensionality Reduction by Learning an Invariant Mapping"

I still need to implement the GPU version and flesh out the tests, but before I go too far, is this something that you would be interested in merging? I can also add a small example based on MNIST for documentation if that is of interest.

Thanks,
Nick

@shelhamer
Copy link
Member

A PR including this loss layer in CPU and GPU implementations, with tests, and an example siamese net model on MNIST would certainly be welcome!

It will be nice to bundle a use case of weight sharing to make it less of an experts-only feature.

@nickcarlevaris nickcarlevaris force-pushed the contrastive_loss branch 4 times, most recently from c9ba74e to b55b5c8 Compare August 25, 2014 17:52
@nickcarlevaris
Copy link
Author

This is ready for review if you have a second. I've pushed the CPU and GPU implementations, tests, and an example in examples/siamese. I also rebased it against dev.

Let me know if there are any updates or changes you would like me to make.

Thanks,
Nick

diff_sq_.cpu_data(), // (a_i-b_i)^2
summer_vec_.cpu_data(),
Dtype(0.0),
dist_sq_.mutable_cpu_data()); // \Sum (a_i-b_i)^2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you're not using the dot product as in the Euclidean Loss layer? I think the dot product method should be faster.

  caffe_sub(
      count,
      bottom[0]->cpu_data(),
      bottom[1]->cpu_data(),
      diff_.mutable_cpu_data());
  Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data());

@ashafaei
Copy link
Contributor

Great job @nickcarlevaris.

@nickcarlevaris
Copy link
Author

@ashafaei, thanks for taking a look.

Unlike the Euclidean Loss layer, which needs the total sum-of-squares difference between the bottom blobs, the contrastive loss layer needs the squared distances between each row of the bottom blobs. This is stored in dist_sq_ in the code. Originally, I didn't use the dot product because I thought calling it once per row would it would be slower.

Based on your comments I went back and tried a few things --- it turns out that for cpu_forward, it is faster to use the dot product and just call it once per row. However, for the gpu_forward, it was much faster to not call the dot product multiple times, and instead do the elementwise difference and square, followed by a matrix vector multiplication to sum along the rows.

I've updated the PR accordingly, using the dot product for the cpu_forward and the matrix version for the gpu_forward.

@amiralush
Copy link

@nickcarlevaris thanks for the siamese example. I've walked it through and it seems to be converging nicely on the mnist dataset. However when I tried using it on a subset of 200 categories from Imagenet It doesn't converge. Could you speculate on the cause? I've tried different variations of networks including the imagenet architecture with pre-trained weigths. Nothing seems to work.

Thanks again!
Amir A.

@shelhamer
Copy link
Member

Thanks for the loss and the nice example! Please switch the paths to fit the new standard of running from the Caffe project root adopted in #1003 for merge.

" labels = np.fromstring(f.read(n), dtype=np.uint8)\n",
" \n",
"# scale and reshape\n",
"images = images.reshape(n, 1, 28, 28).astype(np.float32) / 255. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please configure the net by net.set_input_scale('data', 0.00390625) instead of manually dividing to match the prototxt definition.

@nickcarlevaris
Copy link
Author

@shelhamer, thanks for the review. I'm out of town for a week and don't have my dev machine but I'll update the PR right when I get back.

@amiralush, I haven't tried training with the Imagenet data, but in general you may need to increase the size of the output space. Also, make sure that if you are using a margin of 1.0 that the weight initialization produces values roughly spread around a unit sphere in the output space.

for (int i = 0; i < bottom[0]->num(); ++i) {
dist_sq_.mutable_cpu_data()[i] = caffe_cpu_dot(channels,
diff_.cpu_data() + (i*channels), diff_.cpu_data() + (i*channels));
if (bottom[2]->cpu_data()[i]) { // similar pairs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nickcarlevaris @shelhamer Isn't it safer to recast the data to int type before doing this comparison? I'm not really sure how it will be interpreted by the compiler, but I think this is one of those situations where you should be explicit. This also applies to other locations where you make comparison based on the label. Look at the SoftMax loss for instance. They always say something like

bottom_diff[i * dim + static_cast<int>(label[i * spatial_dim + j])

lelayf added a commit to crossgradient/caffe that referenced this pull request Sep 6, 2014
add numpy formatting option to display entire array (read_npy)
… example using shared weights and the contrastive loss.
@nickcarlevaris
Copy link
Author

@shelhamer and @ashafaei, I've update the PR based on the changes you suggested. I've also added doxygen style comments to the layer and rebased. Let me know if it needs anything else before merging.

Also, in the future, do you prefer that commits to a PR are made with --amend to squash multiple commits? Or should I leave them in so that the PR comments don't reference outdated diffs?

@okn2020
Copy link

okn2020 commented Sep 13, 2014

I think there is some sort of issue with restarting siamese net training from snapshot (with solver parameters) and by simple fine-tuning. When training restarted accuracy at first iteration is ok but quickly drops. I think it is related to net with shared weights, as I do not see the same behavior on nets without sharing.. any ideas?

@okn2020
Copy link

okn2020 commented Sep 13, 2014

@nickcarlevaris I am using latest Dev and merged your contrastive loss function and examples into it..

@amiralush
Copy link

@shelhamer, @okn2020 I have also experienced this when using a pre-trained network with weight sharing. It seems like the weight sharing update (Net::Update) mechanism is flawed. I've used identical pairs as input and computed the L2-distance between shared layers of the siamese network, the diff was not zero as I expected.

@shelhamer
Copy link
Member

@nickcarlevaris thanks for the update. Squashing / rebasing is a nice last step before merge to tidy up. This is ready to merge once we double-check this resuming issue.

@jeffdonahue have you encountered #959 (comment) ?

@okn2020
Copy link

okn2020 commented Sep 16, 2014

@nickcarlevaris What do you think about @chyh1990 siamese implementation https://github.com/chyh1990/caffe/tree/veri ? My understanding that there k-way softmax is used on the top of the net and then contrastive loss is injected right below. Is it essentially the same? In recent papers I see people first train multi-class classification net with softmax, then replace top layer with contrastive loss and fine-tune it as siamese net.

@shelhamer shelhamer mentioned this pull request Sep 18, 2014
@shelhamer
Copy link
Member

@okn2020 I have reproduced the divergence on resume issue when restoring the iteration 20,000 snapshot.

I haven't however investigated the problem. I'm inclined to merge this as a useful example whether troubled by resuming or not and then let a fix follow.

@shelhamer shelhamer merged commit d149c9a into BVLC:dev Sep 19, 2014
shelhamer added a commit that referenced this pull request Sep 19, 2014
  Add contrastive loss layer, tests, and a siamese network example
mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
  Add contrastive loss layer, tests, and a siamese network example
@shelhamer shelhamer mentioned this pull request Oct 2, 2014
8 tasks
@shelhamer
Copy link
Member

@okn2020 @amiralush the divergence on resuming or fine-tuning issue was fixed by #594 since reshapes no longer trigger re-allocation in all cases.

@okn2020
Copy link

okn2020 commented Oct 12, 2014

@shelhamer thank you, will try it out! @nickcarlevaris @shelhamer not sure if it is right place to ask, could you give some tips how to combine this contrastive loss layer with k-way softmax in one net? If I train features layer separately with k-way softmax and then fine-tune with contrastive loss resulting accuracy is very low.

RazvanRanca pushed a commit to RazvanRanca/caffe that referenced this pull request Nov 4, 2014
  Add contrastive loss layer, tests, and a siamese network example
@xiaoyong
Copy link

@okn2020 You may check out the DeepID2 paper:
Yi Sun, Xiaogang Wang, Xiaoou Tang. "Deep Learning Face Representation by Joint Identification-Verification".

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants