diff options
-rw-r--r-- | candle-core/tests/quantized_tests.rs | 4 | ||||
-rw-r--r-- | candle-kernels/src/quantized.cu | 24 |
2 files changed, 24 insertions, 4 deletions
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index d767531a..a2629341 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -738,10 +738,6 @@ macro_rules! quantized_matmul { // stable. https://github.com/rust-lang/rust/issues/29599 ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => { fn $fn_name(device: &Device) -> Result<()> { - if device.is_cuda() { - // TODO Enable Cuda GGML sometime maybe. - return Ok(()); - } test_matmul(device, (1, 3, 4, 256), $dtype)?; Ok(()) } diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index 762395d8..f8becbbc 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -877,6 +877,30 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f #endif } +extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { + const int i = blockIdx.x; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + float * y = yy + 256*i + 32*ir + 8*il; + + const block_q8_0 * x = (const block_q8_0 *)vx + ib; + const float d = __half2float(x->d); + + const int8_t * q = x->qs + 8*il; + + for (int l = 0; l < 8; ++l) { + y[l] = d * q[l]; + } +} + extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) { const block_q8_K * x = (const block_q8_K *) vx; |