diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-02 17:17:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-02 17:17:46 +0100 |
commit | 089fc3b5847668469cad740f29412d19d9e9fecf (patch) | |
tree | 57affdeeee6607066dfbc416db8ca5617ac7f0af /candle-core/src/quantized/mod.rs | |
parent | e04c789230c609c285991b78c29f1d6eef0d104f (diff) | |
download | candle-089fc3b5847668469cad740f29412d19d9e9fecf.tar.gz candle-089fc3b5847668469cad740f29412d19d9e9fecf.tar.bz2 candle-089fc3b5847668469cad740f29412d19d9e9fecf.zip |
Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.
* Fix the config file paths.
* Use the standard matmul where possible.
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r-- | candle-core/src/quantized/mod.rs | 29 |
1 file changed, 19 insertions, 10 deletions
```diff
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 61fabc63..94e6bd23 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -232,19 +232,25 @@ impl QTensor {
 }
 
 #[derive(Clone, Debug)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+pub enum QMatMul {
+    QTensor(std::sync::Arc<QTensor>),
+    Tensor(Tensor),
+}
 
 impl QMatMul {
-    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
-        Self(qtensor)
-    }
-
-    pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(qtensor))
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
+        let t = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => {
+                let tensor = qtensor.dequantize(&Device::Cpu)?;
+                Self::Tensor(tensor)
+            }
+            _ => Self::QTensor(qtensor),
+        };
+        Ok(t)
     }
 
-    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
-        &self.0
+    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+        Self::from_arc(std::sync::Arc::new(qtensor))
     }
 }
 
@@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(self.0.as_ref())
+        match self {
+            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+            Self::Tensor(t) => xs.matmul(&t.t()?),
+        }
     }
 }
```