summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/metavoice/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/metavoice/main.rs')
-rw-r--r-- candle-examples/examples/metavoice/main.rs | 44
1 files changed, 36 insertions, 8 deletions
diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs
index 7635277c..7a7ec3e4 100644
--- a/candle-examples/examples/metavoice/main.rs
+++ b/candle-examples/examples/metavoice/main.rs
@@ -11,6 +11,7 @@ use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
+use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
@@ -26,6 +27,11 @@ enum ArgDType {
Bf16,
}
+enum Transformer {
+ Normal(transformer::Model),
+ Quantized(qtransformer::Model),
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -40,6 +46,10 @@ struct Args {
#[arg(long)]
prompt: String,
+ /// Use the quantized version of the model.
+ #[arg(long)]
+ quantized: bool,
+
/// The guidance scale.
#[arg(long, default_value_t = 3.0)]
guidance_scale: f64,
@@ -116,10 +126,6 @@ fn main() -> Result<()> {
};
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
- let first_stage_weights = match &args.first_stage_weights {
- Some(w) => std::path::PathBuf::from(w),
- None => repo.get("first_stage.safetensors")?,
- };
let second_stage_weights = match &args.second_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("second_stage.safetensors")?,
@@ -135,10 +141,27 @@ fn main() -> Result<()> {
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
- let first_stage_vb =
- unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
+
let first_stage_config = transformer::Config::cfg1b_v0_1();
- let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
+ let mut first_stage_model = if args.quantized {
+ let filename = match &args.first_stage_weights {
+ Some(w) => std::path::PathBuf::from(w),
+ None => repo.get("first_stage_q4k.gguf")?,
+ };
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+ let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
+ Transformer::Quantized(first_stage_model)
+ } else {
+ let first_stage_weights = match &args.first_stage_weights {
+ Some(w) => std::path::PathBuf::from(w),
+ None => repo.get("first_stage.safetensors")?,
+ };
+ let first_stage_vb =
+ unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
+ let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
+ Transformer::Normal(first_stage_model)
+ };
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
@@ -178,7 +201,12 @@ fn main() -> Result<()> {
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::stack(&[&input, &input], 0)?;
- let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
+ let logits = match &mut first_stage_model {
+ Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
+ Transformer::Quantized(m) => {
+ m.forward(&input, &spk_emb, tokens.len() - context_size)?
+ }
+ };
let logits0 = logits.i((0, 0))?;
let logits1 = logits.i((1, 0))?;
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;