diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-04 11:57:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-04 11:57:05 +0100 |
commit | 50be8a98ba08295ec3ff46d0a779937bc06d369e (patch) | |
tree | 950e602693e5601e7c24a216adbd1426973b49c0 /candle-examples/examples/stable-lm/main.rs | |
parent | 58cc896e692936e36f3c68cf33ce949a7298bd4d (diff) | |
download | candle-50be8a98ba08295ec3ff46d0a779937bc06d369e.tar.gz candle-50be8a98ba08295ec3ff46d0a779937bc06d369e.tar.bz2 candle-50be8a98ba08295ec3ff46d0a779937bc06d369e.zip |
Quantized support for stable-lm2. (#1654)
* Quantized support for stable-lm2.
* Quantized support for v2-zephyr.
Diffstat (limited to 'candle-examples/examples/stable-lm/main.rs')
-rw-r--r-- | candle-examples/examples/stable-lm/main.rs | 29 |
1 files changed, 24 insertions, 5 deletions
```diff
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 415c6e7e..abe7020c 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -162,7 +162,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 1000)]
     sample_len: usize,
 
     #[arg(long)]
@@ -171,7 +171,7 @@ struct Args {
     #[arg(long, default_value = "main")]
     revision: String,
 
-    #[arg(long, default_value = "v1-orig")]
+    #[arg(long, default_value = "v2")]
     which: Which,
 
     #[arg(long)]
@@ -239,7 +239,14 @@ fn main() -> Result<()> {
     ));
     let tokenizer_filename = match args.tokenizer_file {
         Some(file) => std::path::PathBuf::from(file),
-        None => repo.get("tokenizer.json")?,
+        None => match args.which {
+            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
+                repo.get("tokenizer.json")?
+            }
+            Which::V2 | Which::V2Zephyr => api
+                .model("lmz/candle-stablelm".to_string())
+                .get("tokenizer-gpt4.json")?,
+        },
     };
     let filenames = match args.weight_files {
         Some(files) => files
@@ -247,8 +254,20 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => match (args.which, args.quantized) {
-            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
-            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+            (Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V2, true) => {
+                let gguf = api
+                    .model("lmz/candle-stablelm".to_string())
+                    .get("stablelm-2-1_6b-q4k.gguf")?;
+                vec![gguf]
+            }
+            (Which::V2Zephyr, true) => {
+                let gguf = api
+                    .model("lmz/candle-stablelm".to_string())
+                    .get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
+                vec![gguf]
+            }
+            (Which::V1Zephyr | Which::Code, true) => {
                 anyhow::bail!("Quantized {:?} variant not supported.", args.which)
             }
             (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
```