summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/tests/quantized_tests.rs 4
-rw-r--r-- candle-kernels/src/quantized.cu 24
2 files changed, 24 insertions, 4 deletions
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index d767531a..a2629341 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -738,10 +738,6 @@ macro_rules! quantized_matmul {
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> {
- if device.is_cuda() {
- // TODO Enable Cuda GGML sometime maybe.
- return Ok(());
- }
test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(())
}
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index 762395d8..f8becbbc 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -877,6 +877,30 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#endif
}
+extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+ const int i = blockIdx.x;
+
+ // assume 32 threads
+ const int tid = threadIdx.x;
+ const int il = tid/8;
+ const int ir = tid%8;
+ const int ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ float * y = yy + 256*i + 32*ir + 8*il;
+
+ const block_q8_0 * x = (const block_q8_0 *)vx + ib;
+ const float d = __half2float(x->d);
+
+ const int8_t * q = x->qs + 8*il;
+
+ for (int l = 0; l < 8; ++l) {
+ y[l] = d * q[l];
+ }
+}
+
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q8_K * x = (const block_q8_K *) vx;