diff options
-rw-r--r-- | candle-transformers/src/models/with_tracing.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/quantized_nn.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/quantized_var_builder.rs | 1 |
3 files changed, 15 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 383ae71c..2ffec724 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -116,6 +116,12 @@ impl QMatMul { let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); Ok(Self { inner, span }) } + + pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> { + let inner = candle::quantized::QMatMul::from_arc(ws)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } } impl Module for QMatMul { diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs index 21c88430..bb0a8641 100644 --- a/candle-transformers/src/quantized_nn.rs +++ b/candle-transformers/src/quantized_nn.rs @@ -35,6 +35,14 @@ pub struct Linear { } impl Linear { + pub fn from_arc( + weight: std::sync::Arc<candle::quantized::QTensor>, + bias: Option<Tensor>, + ) -> Result<Self> { + let weight = QMatMul::from_weights(weight)?; + Ok(Self { weight, bias }) + } + pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self { Self { weight, bias } } diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index bfd0629f..a963e311 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -3,6 +3,7 @@ use candle::{Device, Result, Shape}; use std::sync::Arc; // VarBuilder specialized for QTensors +#[derive(Clone)] pub struct VarBuilder { data: Arc<std::collections::HashMap<String, Arc<QTensor>>>, path: Vec<String>, |