Skip to content

Commit

Permalink
cuda : support non-pow-2 number of experts
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Apr 1, 2024
1 parent 8c2f7b8 commit 4531b02
Showing 1 changed file with 43 additions and 15 deletions.
58 changes: 43 additions & 15 deletions ggml-cuda/argsort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,79 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
}

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
static __global__ void k_argsort_f32_i32(const float * x, int * dst, int * dst_pad, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols) return;
if (col >= ncols_pad) {
return;
}

const float * x_row = x + row * ncols;
int * dst_row = dst + row * ncols;
int * dst_row = dst_pad + row * ncols_pad;

// initialize indices
if (col < ncols) {
dst_row[col] = col;
}
dst_row[col] = col;

__syncthreads();

for (int k = 2; k <= ncols; k *= 2) {
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}

// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}

static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}

static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
static void argsort_f32_i32_cuda(ggml_backend_cuda_context & ctx, const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
GGML_ASSERT((ncols & (ncols - 1)) == 0);
const int ncols_pad = next_power_of_2(ncols);

ggml_cuda_pool_alloc<int> dst_padded_alloc;
int * dst_padded = dst;
if (ncols_pad > ncols) {
dst_padded = dst_padded_alloc.alloc(ctx.pool(), nrows * ncols_pad);
}

const dim3 block_dims(ncols, 1, 1);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, dst_padded, ncols, ncols_pad);
} else {
GGML_ASSERT(false);
}
Expand All @@ -73,5 +101,5 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
argsort_f32_i32_cuda(ctx, src0_d, (int *)dst_d, ncols, nrows, order, stream);
}

0 comments on commit 4531b02

Please sign in to comment.