summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/quantized_nn.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-03-09 11:06:04 +0100
committer: GitHub <noreply@github.com> 2024-03-09 11:06:04 +0100
commit: dd00482ea3456111482ec1cee045d2ae8efaf8ba (patch)
tree: 1bc4d566d8c8599f887eb8f8a1ed07be2afb7715 /candle-transformers/src/quantized_nn.rs
parent: 936f6a48407ee111f52742cf48eccc61f6b62325 (diff)
download: candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.gz
candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.bz2
candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.zip
Quantized version of the metavoice model. (#1824)
* Quantized version of the metavoice model.
* Integrate the quantized version of metavoice.
Diffstat (limited to 'candle-transformers/src/quantized_nn.rs')
-rw-r--r-- candle-transformers/src/quantized_nn.rs 10
1 files changed, 10 insertions, 0 deletions
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 99e8d45b..21c88430 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -50,6 +50,16 @@ impl Module for Linear {
}
}
+/// Quantized `Linear` built from `vb`; "bias" is fetched and dequantized only when `bias` is true.
+pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+    let bias_tensor = match bias {
+        true => Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?),
+        false => None,
+    };
+    let qweight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear { weight: qweight, bias: bias_tensor })
+}
+
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
let weight = QMatMul::new(in_dim, out_dim, vb)?;