summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-12 11:30:24 +0100
committerGitHub <noreply@github.com>2024-03-12 11:30:24 +0100
commitff03fd3fb314980d3273ffc49826d764541d76e2 (patch)
treedaf07e953c992146e67f593e94819ef66822dfe8 /candle-transformers
parentdf5f69444e438a7cd8d8ab4971579bf309b72114 (diff)
downloadcandle-ff03fd3fb314980d3273ffc49826d764541d76e2.tar.gz
candle-ff03fd3fb314980d3273ffc49826d764541d76e2.tar.bz2
candle-ff03fd3fb314980d3273ffc49826d764541d76e2.zip
Expose some helper functions to create quantized models. (#1837)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/with_tracing.rs6
-rw-r--r--candle-transformers/src/quantized_nn.rs8
-rw-r--r--candle-transformers/src/quantized_var_builder.rs1
3 files changed, 15 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 383ae71c..2ffec724 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -116,6 +116,12 @@ impl QMatMul {
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
+
+ pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> {
+ let inner = candle::quantized::QMatMul::from_arc(ws)?;
+ let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+ Ok(Self { inner, span })
+ }
}
impl Module for QMatMul {
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 21c88430..bb0a8641 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -35,6 +35,14 @@ pub struct Linear {
}
impl Linear {
+ pub fn from_arc(
+ weight: std::sync::Arc<candle::quantized::QTensor>,
+ bias: Option<Tensor>,
+ ) -> Result<Self> {
+ let weight = QMatMul::from_weights(weight)?;
+ Ok(Self { weight, bias })
+ }
+
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
index bfd0629f..a963e311 100644
--- a/candle-transformers/src/quantized_var_builder.rs
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -3,6 +3,7 @@ use candle::{Device, Result, Shape};
use std::sync::Arc;
// VarBuilder specialized for QTensors
+#[derive(Clone)]
pub struct VarBuilder {
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
path: Vec<String>,