Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
static constexpr int qi = QI1_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_0> {
static constexpr int qk = QK2_0;
static constexpr int qr = QR2_0;
static constexpr int qi = QI2_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q2_0:
return dequantize_block_cont_cuda<QK2_0, QR2_0, dequantize_q2_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
Expand Down Expand Up @@ -771,6 +773,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q2_0:
return dequantize_block_cont_cuda<QK2_0, QR2_0, dequantize_q2_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
Expand Down Expand Up @@ -828,6 +832,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
return convert_unary_cuda<float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q2_0:
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
Expand All @@ -851,6 +857,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
return convert_unary_cuda<float, nv_bfloat16>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q2_0:
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
Expand All @@ -874,6 +882,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
return convert_unary_cuda<half, float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q2_0:
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
Expand Down
20 changes: 20 additions & 0 deletions ggml/src/ggml-cuda/dequantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const in
v.y = (2*bit_1 - 1) * d;
}

static __device__ __forceinline__ void dequantize_q2_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q2_0 * x = (const block_q2_0 *) vx;

const float d = x[ib].d;

// Q2_0: 2 bits per element, 4 elements per byte.
// Stored code c in {0,1,2,3} maps to symbol s = c - 1 in {-1, 0, +1, +2}.
const int byte_index_0 = iqs / 4;
const int bit_offset_0 = (iqs % 4) * 2;

const int byte_index_1 = (iqs + 1) / 4;
const int bit_offset_1 = ((iqs + 1) % 4) * 2;

const int c0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 0x3;
const int c1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 0x3;

v.x = (c0 - 1) * d;
v.y = (c1 - 1) * d;
}

static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/getrows.cu
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q2_0:
get_rows_cuda_q<QK2_0, QR2_0, dequantize_q2_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5202,6 +5202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -5240,6 +5241,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_Q1_0:
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
break;
case GGML_TYPE_Q2_0:
mul_mat_q_case<GGML_TYPE_Q2_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
Expand Down Expand Up @@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t

switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down
106 changes: 106 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
Expand Down Expand Up @@ -192,6 +193,7 @@ static constexpr __device__ int get_mmq_y_device() {
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q2_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
Expand Down Expand Up @@ -237,6 +239,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q2_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
Expand Down Expand Up @@ -395,6 +398,101 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

constexpr int blocks_per_iter = MMQ_ITER_K / QK2_0;
constexpr int threads_per_row = blocks_per_iter * QI2_0;
constexpr int nrows = warp_size / threads_per_row;
constexpr int scale_entries_per_block = QK2_0 / QK8_1;
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;

const int txi = threadIdx.x % threads_per_row;
const int kbx = txi / QI2_0;
const int kqsx = txi % QI2_0;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

if (need_check) {
i = min(i, i_max);
}

const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + kbx;
// Each 32-element chunk occupies 8 bytes of qs (32 elements * 2 bits = 64 bits)
const int qs_offset = 8*kqsx;
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
const int qs1 = bxi->qs[qs_offset + 4] | (bxi->qs[qs_offset + 5] << 8) |
(bxi->qs[qs_offset + 6] << 16) | (bxi->qs[qs_offset + 7] << 24);

// Unpack 32 2-bit codes into 8 int32s, each holding 4 signed int8s in {-1,0,1,2}.
int unpacked_bytes[8];
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int shift = j * 8;
const int codes = (qs0 >> shift) & 0xFF;
const int c0 = ((codes >> 0) & 0x3) - 1;
const int c1 = ((codes >> 2) & 0x3) - 1;
const int c2 = ((codes >> 4) & 0x3) - 1;
const int c3 = ((codes >> 6) & 0x3) - 1;
unpacked_bytes[j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
}
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int shift = j * 8;
const int codes = (qs1 >> shift) & 0xFF;
const int c0 = ((codes >> 0) & 0x3) - 1;
const int c1 = ((codes >> 2) & 0x3) - 1;
const int c2 = ((codes >> 4) & 0x3) - 1;
const int c3 = ((codes >> 6) & 0x3) - 1;
unpacked_bytes[4 + j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
}

const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}

const int ksx = threadIdx.x % scale_entries_per_row;
const int scale_block = ksx / scale_entries_per_block;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;

if (need_check) {
i = min(i, i_max);
}

const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + scale_block;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
#else
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
Expand Down Expand Up @@ -3273,6 +3371,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_0> {
static constexpr int vdr = VDR_Q2_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
case GGML_TYPE_Q2_0: return vec_dot_q2_0_q8_1;
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
Expand Down Expand Up @@ -38,6 +39,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
case GGML_TYPE_Q2_0: return VDR_Q2_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
Expand Down Expand Up @@ -988,6 +990,12 @@ static void mul_mat_vec_q_switch_type(
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q2_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/template-instances/generate_cu_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"

TYPES_MMQ = [
"GGML_TYPE_Q1_0",
"GGML_TYPE_Q1_0", "GGML_TYPE_Q2_0",
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-q2_0.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_Q2_0);
61 changes: 61 additions & 0 deletions ggml/src/ggml-cuda/vecdotq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block

#define VDR_Q2_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
#define VDR_Q2_0_Q8_1_MMQ 2 // Q2_0 group 64: 128 bits (4 ints) per block, 2 32-element chunks

#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4

Expand Down Expand Up @@ -717,6 +720,64 @@ static __device__ __forceinline__ float vec_dot_q1_0_q8_1(
return d1 * d8 * sumi;
}

static __device__ __forceinline__ float vec_dot_q2_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

const block_q2_0 * bq2_0 = (const block_q2_0 *) vbq + kbx;

// Q2_0: 128 elements with ONE scale, 2 bits per element (4 elements per byte)
// Q8_1: 32 elements per block with individual scales
// iqs selects which of the 4 chunks of 32 elements to process (0-3)

Comment on lines +728 to +731
const float d2 = bq2_0->d;

// Process only the chunk specified by iqs
const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;

// Load 64 bits (8 bytes) for this chunk from Q2_0: bytes [8*iqs, 8*iqs+8)
const int offset = iqs * 8;
const int v0 = bq2_0->qs[offset + 0] | (bq2_0->qs[offset + 1] << 8) |
(bq2_0->qs[offset + 2] << 16) | (bq2_0->qs[offset + 3] << 24);
const int v1 = bq2_0->qs[offset + 4] | (bq2_0->qs[offset + 5] << 8) |
(bq2_0->qs[offset + 6] << 16) | (bq2_0->qs[offset + 7] << 24);

// Unpack 32 2-bit codes into 8 int32s, each holding 4 signed int8 symbols in {-1,0,1,2}.
// Stored code c in {0,1,2,3} -> symbol s = c - 1.
int vi_bytes[8];
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int shift = j * 8;
const int codes = (v0 >> shift) & 0xFF;
const int c0 = ((codes >> 0) & 0x3) - 1;
const int c1 = ((codes >> 2) & 0x3) - 1;
const int c2 = ((codes >> 4) & 0x3) - 1;
const int c3 = ((codes >> 6) & 0x3) - 1;
vi_bytes[j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
}
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int shift = j * 8;
const int codes = (v1 >> shift) & 0xFF;
const int c0 = ((codes >> 0) & 0x3) - 1;
const int c1 = ((codes >> 2) & 0x3) - 1;
const int c2 = ((codes >> 4) & 0x3) - 1;
const int c3 = ((codes >> 6) & 0x3) - 1;
vi_bytes[4 + j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
}

// Compute dot product for this 32-element chunk
int sumi = 0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int u = get_int_b4(bq8_1_chunk->qs, j);
sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi);
}

// Apply Q2_0's single scale and this chunk's Q8_1 scale
const float d8 = __low2float(bq8_1_chunk->ds);
return d2 * d8 * sumi;
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

Expand Down