Skip to content

Commit

Permalink
More concepts
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Dec 16, 2023
1 parent f6f5beb commit e0fd465
Show file tree
Hide file tree
Showing 13 changed files with 19 additions and 42 deletions.
3 changes: 3 additions & 0 deletions include/etl/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ concept same_dimensions = etl_expr<T> && decay_traits<T>::dimensions() == decay_
template <typename T, typename E>
concept same_order = etl_expr<T> && decay_traits<T>::storage_order == decay_traits<E>::storage_order;

template <typename T, typename E>
concept same_dimensions_and_order = same_dimensions<T, E> && same_order<T, E>;

template <typename T, size_t D>
concept exact_dimensions = etl_expr<T> && decay_traits<T>::dimensions() == D;

Expand Down
4 changes: 1 addition & 3 deletions include/etl/expr/batch_embedding_lookup_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ struct etl_traits<etl::batch_embedding_lookup_expr<A, B>> {
* \return the DDth dimension of the expression
*/
template <size_t DD>
static constexpr size_t dim() {
static_assert(DD < 3, "Invalid dimensions access");

static constexpr size_t dim() requires(DD < 3) {
if (DD == 0) {
return decay_traits<A>::template dim<0>();
} else if (DD == 1) {
Expand Down
4 changes: 1 addition & 3 deletions include/etl/expr/batch_k_minus_scale_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1571,10 +1571,8 @@ struct etl_traits<etl::batch_k_minus_scale_expr<A, B, C>> {
* \param value The expression
* \return The transpose of the given expression.
*/
template <etl_1d A, etl_expr B, etl_1d C>
template <etl_1d A, etl_2d_or_4d B, etl_1d C>
batch_k_minus_scale_expr<detail::build_type<A>, detail::build_type<B>, detail::build_type<C>> batch_k_minus_scale(const A& a, const B& b, const C& c) {
static_assert(is_4d<B> || is_2d<B>, "etl::batch_k_minus_scale is only defined for 2D/4D RHS");

return {a, b, c};
}

Expand Down
4 changes: 1 addition & 3 deletions include/etl/expr/batch_k_scale_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1491,10 +1491,8 @@ struct etl_traits<etl::batch_k_scale_expr<A, B>> {
* \param value The expression
* \return The transpose of the given expression.
*/
template <etl_1d A, etl_expr B>
template <etl_1d A, etl_2d_or_4d B>
batch_k_scale_expr<detail::build_type<A>, detail::build_type<B>> batch_k_scale(const A& a, const B& b) {
static_assert(is_2d<B> || is_4d<B>, "etl::batch_k_scale is only defined for 2D/4D RHS");

return {a, b};
}

Expand Down
4 changes: 1 addition & 3 deletions include/etl/expr/batch_k_scale_plus_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1571,10 +1571,8 @@ struct etl_traits<etl::batch_k_scale_plus_expr<A, B, C>> {
* \param value The expression
* \return The transpose of the given expression.
*/
template <etl_1d A, etl_expr B, etl_1d C>
template <etl_1d A, etl_2d_or_4d B, etl_1d C>
batch_k_scale_plus_expr<detail::build_type<A>, detail::build_type<B>, detail::build_type<C>> batch_k_scale_plus(const A& a, const B& b, const C& c) {
static_assert(is_2d<B> || is_4d<B>, "etl::batch_k_scale_plus is only defined for 2D/4D RHS");

return {a, b, c};
}

Expand Down
7 changes: 1 addition & 6 deletions include/etl/expr/batch_softmax_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@ struct batch_softmax_expr : base_temporary_expr_un<batch_softmax_expr<A, Stable>
* \param a The input matrix
* \þaram c The output matrix
*/
template <same_dimensions<A> C>
template <same_dimensions_and_order<A> C>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const C& c) {
static constexpr etl::order order_lhs = decay_traits<C>::storage_order;
static constexpr etl::order order_rhs = decay_traits<A>::storage_order;

static_assert(order_lhs == order_rhs, "Cannot change storage order");

if constexpr (all_fast<A, C>) {
static_assert(decay_traits<A>::size() == decay_traits<C>::size(), "Invalid size");
} else {
Expand Down
3 changes: 1 addition & 2 deletions include/etl/expr/bias_batch_mean_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,7 @@ struct etl_traits<etl::bias_batch_mean_2d_expr<A, Mean>> {
* \return the DDth dimension of the expression
*/
template <size_t DD>
static constexpr size_t dim() {
static_assert(DD == 0, "Invalid dimensions access");
static constexpr size_t dim() requires(DD == 0) {
return decay_traits<A>::template dim<1>();
}

Expand Down
3 changes: 1 addition & 2 deletions include/etl/expr/bias_batch_var_4d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,7 @@ struct etl_traits<etl::bias_batch_var_4d_expr<A, B>> {
* \return the DDth dimension of the expression
*/
template <size_t DD>
static constexpr size_t dim() {
static_assert(DD == 0, "Invalid dimensions access");
static constexpr size_t dim() requires(DD == 0) {
return decay_traits<A>::template dim<1>();
}

Expand Down
3 changes: 1 addition & 2 deletions include/etl/expr/embedding_gradients_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ struct etl_traits<etl::embedding_gradients_expr<A, B, C>> {
* \return the DDth dimension of the expression
*/
template <size_t DD>
static constexpr size_t dim() {
static_assert(DD < 2, "Invalid dimensions access");
static constexpr size_t dim() requires(DD < 2) {
return DD == 0 ? decay_traits<C>::template dim<0>() : decay_traits<B>::template dim<1>();
}

Expand Down
3 changes: 1 addition & 2 deletions include/etl/expr/embedding_lookup_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ struct etl_traits<etl::embedding_lookup_expr<A, B>> {
* \return the DDth dimension of the expression
*/
template <size_t DD>
static constexpr size_t dim() {
static_assert(DD < 2, "Invalid dimensions access");
static constexpr size_t dim() requires(DD < 2) {
return DD == 0 ? decay_traits<A>::template dim<0>() : decay_traits<B>::template dim<1>();
}

Expand Down
4 changes: 1 addition & 3 deletions include/etl/impl/pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ namespace etl::impl {
*
* \return The implementation to use
*/
template <typename X, typename Y>
template <etl_dma X, etl_dma Y>
constexpr etl::pool_impl select_default_pool_impl(bool no_gpu) {
static_assert(all_dma<X, Y>, "DMA should be ensured at this point");

if (cudnn_enabled && all_floating<X, Y> && !no_gpu) {
return etl::pool_impl::CUDNN;
}
Expand Down
16 changes: 5 additions & 11 deletions include/etl/op/dim_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,35 +124,31 @@ struct dim_view {
* \brief Returns a pointer to the first element in memory.
* \return a pointer tot the first element in memory.
*/
memory_type memory_start() noexcept {
static_assert(is_dma<T> && D == 1, "This expression does not have direct memory access");
memory_type memory_start() noexcept requires(etl_dma<T> && D == 1) {
return sub.memory_start() + i * subsize(sub);
}

/*!
* \brief Returns a pointer to the first element in memory.
* \return a pointer tot the first element in memory.
*/
const_memory_type memory_start() const noexcept {
static_assert(is_dma<T> && D == 1, "This expression does not have direct memory access");
const_memory_type memory_start() const noexcept requires(etl_dma<T> && D == 1){
return sub.memory_start() + i * subsize(sub);
}

/*!
* \brief Returns a pointer to the past-the-end element in memory.
* \return a pointer tot the past-the-end element in memory.
*/
memory_type memory_end() noexcept {
static_assert(is_dma<T> && D == 1, "This expression does not have direct memory access");
memory_type memory_end() noexcept requires(etl_dma<T> && D == 1){
return sub.memory_start() + (i + 1) * subsize(sub);
}

/*!
* \brief Returns a pointer to the past-the-end element in memory.
* \return a pointer tot the past-the-end element in memory.
*/
const_memory_type memory_end() const noexcept {
static_assert(is_dma<T> && D == 1, "This expression does not have direct memory access");
const_memory_type memory_end() const noexcept requires(etl_dma<T> && D == 1){
return sub.memory_start() + (i + 1) * subsize(sub);
}

Expand Down Expand Up @@ -323,9 +319,7 @@ struct etl_traits<etl::dim_view<T, D>> {
* \return the D2th dimension of an expression of this type
*/
template <size_t D2>
static constexpr size_t dim() {
static_assert(D2 == 0, "Invalid dimension");

static constexpr size_t dim() requires(D2 == 0) {
return size();
}

Expand Down
3 changes: 1 addition & 2 deletions include/etl/sparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ struct sparse_matrix_impl;
* \tparam D The number of dimensions
*/
template <typename T, size_t D>
requires(D == 2)
struct sparse_matrix_impl<T, sparse_storage::COO, D> final : dyn_base<sparse_matrix_impl<T, sparse_storage::COO, D>, T, D> {
static constexpr size_t n_dimensions = D; ///< The number of dimensions
static constexpr sparse_storage storage_format = sparse_storage::COO; ///< The sparse storage scheme
Expand All @@ -236,8 +237,6 @@ struct sparse_matrix_impl<T, sparse_storage::COO, D> final : dyn_base<sparse_mat
friend struct sparse_detail::sparse_reference<this_type>;
friend struct sparse_detail::sparse_reference<const this_type>;

static_assert(n_dimensions == 2, "Only 2D sparse matrix are supported");

private:
using base_type::_dimensions;
using base_type::_size;
Expand Down

0 comments on commit e0fd465

Please sign in to comment.