summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-01 12:57:55 +0200
committerGitHub <noreply@github.com>2024-10-01 12:57:55 +0200
commitdef4c6cdeef78e437846efcb46a23006f539dee4 (patch)
tree60a9bd3a552f2c295a64297f41c546af4ed14ab6 /candle-core/src/quantized
parent888d886dd8d5cac2558064060c59a4b51b6aa530 (diff)
downloadcandle-def4c6cdeef78e437846efcb46a23006f539dee4.tar.gz
candle-def4c6cdeef78e437846efcb46a23006f539dee4.tar.bz2
candle-def4c6cdeef78e437846efcb46a23006f539dee4.zip
Cuda quantized mmv bugfix. (#2526)
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--candle-core/src/quantized/cuda.rs26
1 files changed, 25 insertions, 1 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index b0df4997..3c24c0e5 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -321,7 +321,7 @@ fn mul_mat_via_q8_1(
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
- k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+ k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
@@ -707,4 +707,28 @@ mod test {
assert_eq!(vs[15], 13138824.0);
Ok(())
}
+
+ // The following test used to fail under compute-sanitizer until #2526.
+ #[test]
+ fn cuda_mm_q8_1_pad() -> Result<()> {
+ let dev = CudaDevice::new(0)?;
+ let (x_rows, ncols, y_cols) = (4, 16, 2048);
+ let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
+ let y = dev.htod_sync_copy(&vs).w()?;
+ let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
+ xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+ let cuda_storage = mul_mat_via_q8_1(
+ &xs.data,
+ &y.slice(..),
+ /* dtype */ GgmlDType::Q4_0,
+ /* x_rows */ x_rows,
+ /* x_cols */ ncols,
+ /* y_rows */ ncols,
+ /* y_cols */ y_cols,
+ &dev,
+ )?;
+ let vs = cuda_storage.as_cuda_slice::<f32>()?;
+ let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+ Ok(())
+ }
}