summary | refs | log | tree | commit | diff
path: root/candle-core/src/quantized/mod.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-01 16:53:42 +0200
committer: GitHub <noreply@github.com> 2023-09-01 15:53:42 +0100
commit: 2ed78ab336c99080b1f8830f48ea40e2e1026249 (patch)
tree: 86932671d0a531f5ba0b5aaa55bdf039acd4d388 /candle-core/src/quantized/mod.rs
parent: 237323c2bcfde1b7f881d2b71e21be27b3f73838 (diff)
download: candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.tar.gz
candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.tar.bz2
candle-2ed78ab336c99080b1f8830f48ea40e2e1026249.zip
Support for quantized tensors in the python api. (#706)
* Add more pyo3 support.
* Add some support for quantized tensors in pyo3.
* Add an arc layer on qmatmul.
* Add the quantized matmul.
* Quantization support.
* More quantization support.
* Test the python quantization.
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r-- candle-core/src/quantized/mod.rs 14
1 file changed, 11 insertions, 3 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index d87d2d5a..5c2bb2b2 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -230,12 +230,20 @@ impl QTensor {
}
#[derive(Debug)]
-pub struct QMatMul(QTensor);
+pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
- pub fn from_qtensor(qtensor: QTensor) -> Self {
+ pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
}
+
+ pub fn from_qtensor(qtensor: QTensor) -> Self {
+ Self(std::sync::Arc::new(qtensor))
+ }
+
+ pub fn inner(&self) -> &std::sync::Arc<QTensor> {
+ &self.0
+ }
}
impl crate::CustomOp1 for QTensor {
@@ -279,6 +287,6 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.apply_op1_no_bwd(&self.0)
+ xs.apply_op1_no_bwd(self.0.as_ref())
}
}