Skip to content

Commit

Permalink
Fewer iters, more ops per iter
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed May 21, 2023
1 parent 637483b commit a06f7ec
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 76 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
option(LLAMA_BLAS "llama: use BLAS" OFF)
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
set(LLAMA_CUDA_BX "32" CACHE STRING "llama: x block size for dmmv CUDA kernels")
set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)

option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
Expand Down Expand Up @@ -185,8 +185,8 @@ if (LLAMA_CUBLAS)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)

add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_CUDA_DMMV_BLOCK_X=${LLAMA_CUDA_BX})
add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})

if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
Expand Down
8 changes: 5 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,16 @@ ifdef LLAMA_CUBLAS
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
ifdef LLAMA_CUDA_BX
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=$(LLAMA_CUDA_BX)
else
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=32
endif # LLAMA_CUDA_BY
ifdef LLAMA_CUDA_BY
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
else
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
endif # LLAMA_CUDA_BY
ifdef LLAMA_CUDA_UNROLL
NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
endif # LLAMA_CUDA_UNROLL
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif # LLAMA_CUBLAS
Expand Down
132 changes: 61 additions & 71 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

#define WARP_SIZE 32

#define CUDA_MUL_BLOCK_SIZE 256

#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
#define GGML_CUDA_DMMV_BLOCK_X 32
#ifndef GGML_CUDA_DMMV_BLOCK_X
#define GGML_CUDA_DMMV_BLOCK_X 32 // can by set by compiler option LLAMA_CUDA_BY
#endif
#ifndef GGML_CUDA_DMMV_BLOCK_Y
#define GGML_CUDA_DMMV_BLOCK_Y 1 // can by set by compiler option LLAMA_CUDA_BY
#endif
Expand Down Expand Up @@ -204,32 +210,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
dequantize_kernel(vx, ib, iqs, v0, v1);
}

template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;

const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;


float tmp = 0; // partial sum for thread in warp

#ifdef GGML_CUDA_UNROLL
#pragma unroll
#endif
for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid;
const int ib = (row*ncols + col)/qk; // block index
const int iqs = (col%qk)/qr; // quant index
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index

// dequantize
float v0, v1;
dequantize_kernel(vx, ib, iqs, v0, v1);

// matrix multiplication
tmp += v0 * y[iybs + iqs + 0];
tmp += v1 * y[iybs + iqs + y_offset];
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
// process 2 vals per j iter

// dequantize
float v0, v1;
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val

// matrix multiplication
tmp += v0 * y[iybs + iqs + j/qr + 0];
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
}
}

// sum up partial sums and write back result
Expand Down Expand Up @@ -274,72 +288,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

template<dequantize_kernel_t dequantize_kernel, int qk, int qr>
static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
const int ncols, const int nrows, cudaStream_t stream) {
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);

// Use a switch statement for ncols so the compiler can unroll all loops:
switch (ncols) {
case 4096:
dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 5120:
dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 6656:
dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 8192:
dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 11008:
dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 13824:
dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 17920:
dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
case 22016:
dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
break;
default:
fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
GGML_ASSERT(false);
break;
}
}

static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<dequantize_q4_0, QK4_0, QR4_0>(vx, y, dst, ncols, nrows, stream);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<dequantize_q4_1, QK4_1, QR4_1>(vx, y, dst, ncols, nrows, stream);
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<dequantize_q5_0, QK5_0, QR5_0>(vx, y, dst, ncols, nrows, stream);
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<dequantize_q5_1, QK5_1, QR5_1>(vx, y, dst, ncols, nrows, stream);
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<dequantize_q8_0, QK8_0, QR8_0>(vx, y, dst, ncols, nrows, stream);
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
Expand All @@ -348,7 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
}

static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
dequantize_mul_mat_vec_cuda<convert_f16, 1, 1>(vx, y, dst, ncols, nrows, stream);
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16>
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
Expand Down

0 comments on commit a06f7ec

Please sign in to comment.