ArgMaxLayer with top k predictions #615
blob_top_(new Blob<Dtype>()) {
: blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
  blob_top_(new Blob<Dtype>()),
  top_k_(10) {
I would use top_k_(5)
@sguada, thanks for reviewing! The new functionality is now tested more thoroughly and is ready to be merged.
for (int j = 0; j < top_k_; ++j) {
  top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2] =
      top_k_results.top().first;
  top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2 + 1] =
@kloudkl, is there a cleaner way to index into top_data here and down on line 73? E.g. using Blob::offset() seems like it'd be a lot cleaner/easier to read.
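To illustrate the indexing suggestion, here is a minimal sketch of how a row-major 4-D offset helper can replace the hand-written arithmetic. `Shape4D` and `Offset` are hypothetical stand-ins written for this example, not Caffe code, and the (num, top_k, 2, 1) layout is an assumption inferred from the index expression in the diff.

```cpp
#include <cassert>

// Hypothetical stand-in for a 4-D blob shape: (num, channels, height, width).
struct Shape4D {
  int num;
  int channels;
  int height;
  int width;
};

// Row-major flat index into a 4-D array with the given shape.
inline int Offset(const Shape4D& s, int n, int c = 0, int h = 0, int w = 0) {
  assert(n >= 0 && n < s.num);
  assert(c >= 0 && c < s.channels);
  assert(h >= 0 && h < s.height);
  assert(w >= 0 && w < s.width);
  return ((n * s.channels + c) * s.height + h) * s.width + w;
}
```

Under the assumed layout, `Offset(shape, i, top_k - 1 - j, 0)` equals `i * 2 * top_k + (top_k - 1 - j) * 2`, and passing `1` as the last argument addresses the second element of the pair.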
@kloudkl, this may just be a biased preference for my own code, but for the sake of consistent implementation of the same functionality, I'd rather see the truncated insertion sort from the
The sorting algorithm is only a small part of the whole system, so it may not be worth much optimization. In terms of readability and conciseness, std::sort is no doubt the winner. If you agree, I'd like to unify all the sorting code to reuse it directly rather than "reinvent new code".
@kloudkl we don't need to sort all of the probabilities, just pick the top 5 (or k). That's what you're doing with the priority queue, and that's what I'm doing with the truncated insertion sort in
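For readers unfamiliar with the term, a truncated insertion sort keeps a sorted buffer of at most k entries and discards everything that cannot enter it. The sketch below is a generic illustration of that idea, not the code from either PR; the name `TopKByInsertion` and the (index, score) pair layout are assumptions made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns up to k (index, score) pairs with the highest scores, best first.
// On equal scores the entry seen first stays ahead.
template <typename Dtype>
std::vector<std::pair<int, Dtype> > TopKByInsertion(
    const std::vector<Dtype>& scores, int k) {
  assert(k >= 0);
  std::vector<std::pair<int, Dtype> > best;
  if (k == 0) return best;
  const std::size_t limit = static_cast<std::size_t>(k);
  for (std::size_t i = 0; i < scores.size(); ++i) {
    const std::pair<int, Dtype> entry(static_cast<int>(i), scores[i]);
    if (best.size() < limit) {
      best.push_back(entry);           // buffer not full yet
    } else if (entry.second > best.back().second) {
      best.back() = entry;             // evict the current worst entry
    } else {
      continue;                        // cannot enter the top k
    }
    // Move the new entry toward the front until the buffer is sorted again.
    for (std::size_t j = best.size() - 1;
         j > 0 && best[j].second > best[j - 1].second; --j) {
      std::swap(best[j], best[j - 1]);
    }
  }
  return best;
}
```

Each input element costs at most k comparisons, so the whole pass is O(n k), which is cheap when k is small (such as 5) compared with sorting all n scores.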
@shuokay, very cool!
template<typename Dtype>
bool int_Dtype_pair_greater(std::pair<int, Dtype> a,
                            std::pair<int, Dtype> b) {
  return a.second > b.second || (a.second == b.second && a.first > b.first);
@kloudkl is there a reason we need a stable sort here? Can't we drop the second term of the OR? It seems unlikely that classes will have exactly the same probability and, if they do, unimportant that we keep them in the (arbitrary) order of the assigned indices. The second term seems to me to just add clutter to the code.
Oh, also, I just realized this was in common_layers.hpp. That doesn't seem like the right place for this. Can we just put it in argmax_layer.cpp as _int_Dtype_pair_greater()? I think that'd make more sense.
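To make the comparator discussion concrete, here is a small sketch contrasting a score-only comparator with one that also breaks ties on the index, both used with std::partial_sort. The names `ScoreGreater`, `ScoreGreaterWithTieBreak`, and `SelectTopK` are hypothetical and chosen only for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Orders (index, score) pairs by descending score only.
// Pairs with equal scores may come out in either order.
template <typename Dtype>
bool ScoreGreater(const std::pair<int, Dtype>& a,
                  const std::pair<int, Dtype>& b) {
  return a.second > b.second;
}

// Same ordering, but equal scores are ordered by descending index,
// which makes the result deterministic.
template <typename Dtype>
bool ScoreGreaterWithTieBreak(const std::pair<int, Dtype>& a,
                              const std::pair<int, Dtype>& b) {
  return a.second > b.second || (a.second == b.second && a.first > b.first);
}

// Moves the k best pairs, in order, to the front of *pairs.
template <typename Dtype, typename Compare>
void SelectTopK(std::vector<std::pair<int, Dtype> >* pairs, int k,
                Compare comp) {
  assert(k >= 0 && static_cast<std::size_t>(k) <= pairs->size());
  std::partial_sort(pairs->begin(), pairs->begin() + k, pairs->end(), comp);
}
```

Both comparators are valid strict weak orderings, so either is safe to pass to std::partial_sort. The tie-break matters only when reproducible output is wanted for equal scores, for example in tests that compare exact indices.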
Yes, I like this solution quite a bit better than both the original solutions. I didn't know about Other than that, the only concern I have with this PR is the line comment I made above.
Oh, one more thing. There's a lint error introduced:
@kloudkl, can you add
@robwhess, everything you requested is in place.
@kloudkl, that all looks good. One more very minor request: can you make the
In Caffe, there is not a single function or method with a leading or trailing underscore. It's better to follow the convention.
OK. That's fine, though to be fair, there also aren't any static functions in the
Cool, tests are all passing for me. @sguada, @shelhamer, @jeffdonahue, @longjon, @sergeyk, I think this PR can be merged.
Oh, wait, never mind, don't merge yet. @kloudkl, the lint error I mentioned above now also applies to
@robwhess, it is added. Thanks for your help!
Cool. I think we're ready to merge here.
Could this be extended to support a vector of bottom blob labels instead of a single blob?
@bhack I think that's out of the scope of this PR. This should be merged as is.
OK, looks good to me, but I'm traveling and only took a quick glance, so @longjon please review and merge. Thanks for your work everybody!
}
std::partial_sort(
    bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
    bottom_data_vector.end(), int_Dtype_pair_greater<Dtype>);
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++)
Please use curly braces for loop bodies. (Although Google C++ style guide doesn't require it for single line statements, as far as I know we always use explicit curly braces in Caffe.)
Looks good except as noted. I like using
All done. Any more concerns?
To meet the need for multiple top predictions raised in #499 and #598, the argmax layer is extended to output the top k results.
Unlike the implementation of the top k accuracy layer in #531, the argmax layer doesn't assume the input probabilities to be sorted. It picks the top k results with a priority queue.
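The priority-queue approach described above can be sketched roughly as follows. This is an illustration under assumptions, not the PR's code: the function name `TopKWithHeap` and the (score, index) pair layout are invented for the example. A min-heap of at most k entries keeps the smallest retained score on top, so it can be evicted cheaply, and the output is filled back to front, much like the `top_k_ - 1 - j` indexing in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Returns up to k (score, index) pairs with the highest scores, best first.
// The input does not need to be sorted.
template <typename Dtype>
std::vector<std::pair<Dtype, int> > TopKWithHeap(
    const std::vector<Dtype>& scores, int k) {
  assert(k >= 0);
  typedef std::pair<Dtype, int> Entry;  // (score, index)
  // Min-heap: the smallest retained entry is always on top.
  std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry> > heap;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    heap.push(Entry(scores[i], static_cast<int>(i)));
    if (heap.size() > static_cast<std::size_t>(k)) {
      heap.pop();  // drop the current smallest so only k entries remain
    }
  }
  // The heap yields the worst retained entry first, so fill from the back.
  std::vector<Entry> result(heap.size());
  for (std::size_t j = result.size(); j > 0; --j) {
    result[j - 1] = heap.top();
    heap.pop();
  }
  return result;
}
```

With n scores this takes O(n log k) time and O(k) extra memory, and it makes no assumption about the order of the input probabilities.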
Candidate reviewers:
@sguada who authored #421 Argmax layer
@robwhess who authored #531 Top-k accuracy
@shuokay who asked #598 How to make top-k prediction