diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-09 11:06:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-09 11:06:04 +0100 |
commit | dd00482ea3456111482ec1cee045d2ae8efaf8ba (patch) | |
tree | 1bc4d566d8c8599f887eb8f8a1ed07be2afb7715 /candle-examples/examples/metavoice/main.rs | |
parent | 936f6a48407ee111f52742cf48eccc61f6b62325 (diff) | |
download | candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.gz candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.bz2 candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.zip |
Quantized version of the metavoice model. (#1824)
* Quantized version of the metavoice model.
* Integrate the quantized version of metavoice.
Diffstat (limited to 'candle-examples/examples/metavoice/main.rs')
-rw-r--r-- | candle-examples/examples/metavoice/main.rs | 44 |
1 files changed, 36 insertions, 8 deletions
diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index 7635277c..7a7ec3e4 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -11,6 +11,7 @@ use std::io::Write; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::encodec; use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer}; +use candle_transformers::models::quantized_metavoice::transformer as qtransformer; use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; @@ -26,6 +27,11 @@ enum ArgDType { Bf16, } +enum Transformer { + Normal(transformer::Model), + Quantized(qtransformer::Model), +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -40,6 +46,10 @@ struct Args { #[arg(long)] prompt: String, + /// Use the quantized version of the model. + #[arg(long)] + quantized: bool, + /// The guidance scale. #[arg(long, default_value_t = 3.0)] guidance_scale: f64, @@ -116,10 +126,6 @@ fn main() -> Result<()> { }; let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?; - let first_stage_weights = match &args.first_stage_weights { - Some(w) => std::path::PathBuf::from(w), - None => repo.get("first_stage.safetensors")?, - }; let second_stage_weights = match &args.second_stage_weights { Some(w) => std::path::PathBuf::from(w), None => repo.get("second_stage.safetensors")?, @@ -135,10 +141,27 @@ fn main() -> Result<()> { ArgDType::F16 => DType::F16, ArgDType::Bf16 => DType::BF16, }; - let first_stage_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_config = transformer::Config::cfg1b_v0_1(); - let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + let mut first_stage_model = if args.quantized { + let filename = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage_q4k.gguf")?, + }; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; + let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?; + Transformer::Quantized(first_stage_model) + } else { + let first_stage_weights = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage.safetensors")?, + }; + let first_stage_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + Transformer::Normal(first_stage_model) + }; let second_stage_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? }; @@ -178,7 +201,12 @@ fn main() -> Result<()> { let ctxt = &tokens[start_pos..]; let input = Tensor::new(ctxt, &device)?; let input = Tensor::stack(&[&input, &input], 0)?; - let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?; + let logits = match &mut first_stage_model { + Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?, + Transformer::Quantized(m) => { + m.forward(&input, &spk_emb, tokens.len() - context_size)? + } + }; let logits0 = logits.i((0, 0))?; let logits1 = logits.i((1, 0))?; let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?; |