diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 50e03de500747..d7b2c3394f590 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4609,8 +4609,8 @@ static __global__ void rope( template static __global__ void rope_neox( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - float ext_factor, float attn_factor, rope_corr_dims corr_dims + const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4619,23 +4619,25 @@ static __global__ void rope_neox( } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col/2; + const int ib = col / n_dims; + const int ic = col % n_dims; + + const int i = row*ncols + ib*n_dims + ic/2; const int i2 = row/p_delta_rows; - // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero - const float cur_rot = -float(col)/ncols; + float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*powf(freq_base, cur_rot); + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; - const float x1 = x[i + ncols/2]; + const float x1 = x[i + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } static __global__ void rope_glm_f32( @@ -5738,20 +5740,26 @@ static void rope_cuda( template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.0f / n_dims; + if (pos == nullptr) { rope_neox<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims ); } else { rope_neox<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims ); } } @@ -6706,15 +6714,14 @@ inline void ggml_cuda_op_rope( GGML_ASSERT(false); rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream); } else if (is_neox) { - GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); if (src0->type == GGML_TYPE_F32) { rope_neox_cuda( - (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda( - (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else {