diff options
Diffstat (limited to 'candle-examples/examples/metavoice/main.rs')
-rw-r--r-- | candle-examples/examples/metavoice/main.rs | 44 |
1 files changed, 36 insertions, 8 deletions
diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index 7635277c..7a7ec3e4 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -11,6 +11,7 @@ use std::io::Write; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::encodec; use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer}; +use candle_transformers::models::quantized_metavoice::transformer as qtransformer; use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; @@ -26,6 +27,11 @@ enum ArgDType { Bf16, } +enum Transformer { + Normal(transformer::Model), + Quantized(qtransformer::Model), +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -40,6 +46,10 @@ struct Args { #[arg(long)] prompt: String, + /// Use the quantized version of the model. + #[arg(long)] + quantized: bool, + /// The guidance scale. #[arg(long, default_value_t = 3.0)] guidance_scale: f64, @@ -116,10 +126,6 @@ fn main() -> Result<()> { }; let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?; - let first_stage_weights = match &args.first_stage_weights { - Some(w) => std::path::PathBuf::from(w), - None => repo.get("first_stage.safetensors")?, - }; let second_stage_weights = match &args.second_stage_weights { Some(w) => std::path::PathBuf::from(w), None => repo.get("second_stage.safetensors")?, @@ -135,10 +141,27 @@ fn main() -> Result<()> { ArgDType::F16 => DType::F16, ArgDType::Bf16 => DType::BF16, }; - let first_stage_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_config = transformer::Config::cfg1b_v0_1(); - let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + let mut first_stage_model = if args.quantized { + let filename = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage_q4k.gguf")?, + }; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; + let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?; + Transformer::Quantized(first_stage_model) + } else { + let first_stage_weights = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage.safetensors")?, + }; + let first_stage_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + Transformer::Normal(first_stage_model) + }; let second_stage_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? }; @@ -178,7 +201,12 @@ fn main() -> Result<()> { let ctxt = &tokens[start_pos..]; let input = Tensor::new(ctxt, &device)?; let input = Tensor::stack(&[&input, &input], 0)?; - let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?; + let logits = match &mut first_stage_model { + Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?, + Transformer::Quantized(m) => { + m.forward(&input, &spk_emb, tokens.len() - context_size)? + } + }; let logits0 = logits.i((0, 0))?; let logits1 = logits.i((1, 0))?; let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?; |