summary | refs | log | tree | commit | diff
path: root/candle-core/src/quantized/mod.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-10-02 17:17:46 +0100
committer GitHub <noreply@github.com> 2023-10-02 17:17:46 +0100
commit 089fc3b5847668469cad740f29412d19d9e9fecf (patch)
tree 57affdeeee6607066dfbc416db8ca5617ac7f0af /candle-core/src/quantized/mod.rs
parent e04c789230c609c285991b78c29f1d6eef0d104f (diff)
download candle-089fc3b5847668469cad740f29412d19d9e9fecf.tar.gz
candle-089fc3b5847668469cad740f29412d19d9e9fecf.tar.bz2
candle-089fc3b5847668469cad740f29412d19d9e9fecf.zip
Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.
* Fix the config file paths.
* Use the standard matmul where possible.
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r-- candle-core/src/quantized/mod.rs 29
1 file changed, 19 insertions, 10 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 61fabc63..94e6bd23 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -232,19 +232,25 @@ impl QTensor {
}
#[derive(Clone, Debug)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+pub enum QMatMul {
+ QTensor(std::sync::Arc<QTensor>),
+ Tensor(Tensor),
+}
impl QMatMul {
- pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
- Self(qtensor)
- }
-
- pub fn from_qtensor(qtensor: QTensor) -> Self {
- Self(std::sync::Arc::new(qtensor))
+ pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
+ let t = match qtensor.dtype() {
+ GgmlDType::F32 | GgmlDType::F16 => {
+ let tensor = qtensor.dequantize(&Device::Cpu)?;
+ Self::Tensor(tensor)
+ }
+ _ => Self::QTensor(qtensor),
+ };
+ Ok(t)
}
- pub fn inner(&self) -> &std::sync::Arc<QTensor> {
- &self.0
+ pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+ Self::from_arc(std::sync::Arc::new(qtensor))
}
}
@@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.apply_op1_no_bwd(self.0.as_ref())
+ match self {
+ Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+ Self::Tensor(t) => xs.matmul(&t.t()?),
+ }
}
}