diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-01 16:53:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-01 15:53:42 +0100 |
commit | 2ed78ab336c99080b1f8830f48ea40e2e1026249 (patch) | |
tree | 86932671d0a531f5ba0b5aaa55bdf039acd4d388 /candle-core/src/quantized/mod.rs | |
parent | 237323c2bcfde1b7f881d2b71e21be27b3f73838 (diff) | |
download | candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.tar.gz candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.tar.bz2 candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.zip |
Support for quantized tensors in the python api. (#706)
* Add more pyo3 support.
* Add some support for quantized tensors in pyo3.
* Add an arc layer on qmatmul.
* Add the quantized matmul.
* Quantization support.
* More quantization support.
* Test the python quantization.
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r-- | candle-core/src/quantized/mod.rs | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d87d2d5a..5c2bb2b2 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -230,12 +230,20 @@ impl QTensor { } #[derive(Debug)] -pub struct QMatMul(QTensor); +pub struct QMatMul(std::sync::Arc<QTensor>); impl QMatMul { - pub fn from_qtensor(qtensor: QTensor) -> Self { + pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self { Self(qtensor) } + + pub fn from_qtensor(qtensor: QTensor) -> Self { + Self(std::sync::Arc::new(qtensor)) + } + + pub fn inner(&self) -> &std::sync::Arc<QTensor> { + &self.0 + } } impl crate::CustomOp1 for QTensor { @@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor { impl QMatMul { pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs.apply_op1_no_bwd(&self.0) + xs.apply_op1_no_bwd(self.0.as_ref()) } } |