Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU-compatible upper bound and lower bound algorithms to AMReX_Algorithm #2958

Merged
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions Src/Base/AMReX_Algorithm.H
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,57 @@ namespace amrex
return hi;
}

template<typename ItType, typename ValType>
AMREX_GPU_HOST_DEVICE
ItType upper_bound (ItType first, ItType last, const ValType& val)
{
#if AMREX_DEVICE_COMPILE
std::ptrdiff_t count = last-first;
while(count>0){
auto it = first;
const auto step = count/2;
it += step;
if (!(val<*it)){
first = ++it;
count -= step + 1;
}
else{
count = step;
}
}

return first;
#else
return std::upper_bound(first, last, val);
#endif
}

template<typename ItType, typename ValType>
AMREX_GPU_HOST_DEVICE
ItType lower_bound (ItType first, ItType last, const ValType& val)
{
#ifdef AMREX_DEVICE_COMPILE
std::ptrdiff_t count = last-first;
while(count>0)
{
auto it = first;
const auto step = count/2;
it += step;
if (!(val<=*it)){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be *it < value, because the algorithm is only supposed to use <, not <=.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we actually need <= with this formulation. Otherwise with the vector [0,1,2,3,4,5,5,6,6,7] passing 0 as value gives 1 instead of 0 and passing 7 gives END instead of 7.

However, it could be reformulated as if (val > *it ) . I can do that if you think that it would be better

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused. Is if (val > *it) mathematically equivalent to if (*it <val) (which I suggested) for your example? Maybe you didn't remove !?

Copy link
Member

@WeiqunZhang WeiqunZhang Sep 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only use <, not >, because a data type with < but without > works with std::lower_bound.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry! My fault... I misread your comment: I thought that you meant !(val<*it) !
You are totally right!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't make difference for arithmetic types. But the user will expect this to work too.

struct foo {
    int i;
    bool operator< (foo const& rhs) const {
        return i < rhs.i;
    }
    // No operator> defined
};

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do the correction right now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry again, I've implemented your suggestion now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problems. Thanks for your hard work!

first = ++it;
count -= step + 1;
}
else{
count = step;
}
}

return first;
#else
return std::lower_bound(first, last, val);
#endif
}

namespace detail {

struct clzll_tag {};
Expand Down