summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-lm/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-04 11:57:05 +0100
committerGitHub <noreply@github.com>2024-02-04 11:57:05 +0100
commit50be8a98ba08295ec3ff46d0a779937bc06d369e (patch)
tree950e602693e5601e7c24a216adbd1426973b49c0 /candle-examples/examples/stable-lm/main.rs
parent58cc896e692936e36f3c68cf33ce949a7298bd4d (diff)
downloadcandle-50be8a98ba08295ec3ff46d0a779937bc06d369e.tar.gz
candle-50be8a98ba08295ec3ff46d0a779937bc06d369e.tar.bz2
candle-50be8a98ba08295ec3ff46d0a779937bc06d369e.zip
Quantized support for stable-lm2. (#1654)
* Quantized support for stable-lm2. * Quantized support for v2-zephyr.
Diffstat (limited to 'candle-examples/examples/stable-lm/main.rs')
-rw-r--r--candle-examples/examples/stable-lm/main.rs29
1 files changed, 24 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 415c6e7e..abe7020c 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -162,7 +162,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
@@ -171,7 +171,7 @@ struct Args {
#[arg(long, default_value = "main")]
revision: String,
- #[arg(long, default_value = "v1-orig")]
+ #[arg(long, default_value = "v2")]
which: Which,
#[arg(long)]
@@ -239,7 +239,14 @@ fn main() -> Result<()> {
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
- None => repo.get("tokenizer.json")?,
+ None => match args.which {
+ Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
+ repo.get("tokenizer.json")?
+ }
+ Which::V2 | Which::V2Zephyr => api
+ .model("lmz/candle-stablelm".to_string())
+ .get("tokenizer-gpt4.json")?,
+ },
};
let filenames = match args.weight_files {
Some(files) => files
@@ -247,8 +254,20 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match (args.which, args.quantized) {
- (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
- (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+ (Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
+ (Which::V2, true) => {
+ let gguf = api
+ .model("lmz/candle-stablelm".to_string())
+ .get("stablelm-2-1_6b-q4k.gguf")?;
+ vec![gguf]
+ }
+ (Which::V2Zephyr, true) => {
+ let gguf = api
+ .model("lmz/candle-stablelm".to_string())
+ .get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
+ vec![gguf]
+ }
+ (Which::V1Zephyr | Which::Code, true) => {
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
}
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {