-rw-r--r--  candle-examples/examples/stable-lm/README.md  |  5
-rw-r--r--  candle-examples/examples/stable-lm/main.rs    | 61
-rw-r--r--  candle-transformers/src/models/stable_lm.rs   | 27
3 files changed, 77 insertions, 16 deletions
diff --git a/candle-examples/examples/stable-lm/README.md b/candle-examples/examples/stable-lm/README.md
index ad3e4a5b..485812d3 100644
--- a/candle-examples/examples/stable-lm/README.md
+++ b/candle-examples/examples/stable-lm/README.md
@@ -8,6 +8,11 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated so you will have to request access on the Hub in
order to be able to use it.
+Other available models are StableLM-2, Stable-Code-3B, and the Zephyr fine-tuned variants of both generations.
+
+StableLM-2 uses a Tiktoken-based GPT-3.5/GPT-4 tokenizer that is not supported by Candle. To run it, download a mostly compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+and pass it via the `--tokenizer-file` argument.
+
## Running some example
```bash
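# As an illustration (the release build and --prompt flag are assumed from the
# existing example; --which, --tokenizer-file and --sample-len are the flags
# defined in this example), StableLM-2 can be run with the downloaded
# tokenizer like so:
wget https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json
cargo run --example stable-lm --release -- \
    --which v2 \
    --tokenizer-file tokenizer.json \
    --prompt "What is the most efficient sorting algorithm?" \
    --sample-len 150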
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index ccd924a4..415c6e7e 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@@ -122,6 +122,16 @@ impl TextGeneration {
}
}
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
+enum Which {
+ V1Orig,
+ V1,
+ V1Zephyr,
+ V2,
+ V2Zephyr,
+ Code,
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -155,12 +165,15 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
- #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
- model_id: String,
+ #[arg(long)]
+ model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
+ #[arg(long, default_value = "v1-orig")]
+ which: Which,
+
#[arg(long)]
tokenizer_file: Option<String>,
@@ -207,8 +220,20 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => match args.which {
+ Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
+ Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
+ Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
+ Which::Code => "stabilityai/stable-code-3b".to_string(),
+ Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
+ Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
+ },
+ };
+
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ model_id,
RepoType::Model,
args.revision,
));
@@ -221,19 +246,35 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
- None => {
- if args.quantized {
- vec![repo.get("model-q4k.gguf")?]
- } else {
+ None => match (args.which, args.quantized) {
+ (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
+ (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+ anyhow::bail!("Quantized {:?} variant not supported.", args.which)
+ }
+ (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
vec![repo.get("model.safetensors")?]
}
- }
+ (Which::Code, false) => {
+ candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+ }
+ },
};
+
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
- let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+ let config = match args.which {
+ Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
+ Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
+ let config_filename = repo.get("config.json")?;
+ let config = std::fs::read_to_string(config_filename)?;
+ let mut config: Config = serde_json::from_str(&config)?;
+ config.set_use_flash_attn(args.use_flash_attn);
+ config
+ }
+ };
+
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
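
The `serde_json::from_str::<Config>` call above relies on the `#[serde(default)]` attributes added to `Config` below: older `config.json` files that predate `use_qkv_bias` (and all of them, in the case of `use_flash_attn`) still deserialize cleanly. A minimal, self-contained sketch of that mechanism, with a hypothetical `MiniConfig` standing in for the real struct:

```rust
use serde::Deserialize;

// Hypothetical stand-in for stable_lm::Config; only the defaulting
// behavior of the two new fields is shown.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    vocab_size: usize,
    #[serde(default)]
    use_qkv_bias: bool, // absent from v1 config.json files -> false
    #[serde(default)]
    use_flash_attn: bool, // never present in config.json -> false
}

fn main() -> Result<(), serde_json::Error> {
    // A config.json without the new field still parses; the missing
    // booleans fall back to Default::default(), i.e. false.
    let cfg: MiniConfig = serde_json::from_str(r#"{ "vocab_size": 50304 }"#)?;
    assert!(!cfg.use_qkv_bias && !cfg.use_flash_attn);
    println!("{cfg:?}");
    Ok(())
}
```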
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index ef06ea99..a49b8282 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
use std::sync::Arc;
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
pub(crate) max_position_embeddings: usize,
pub(crate) norm_eps: f64,
pub(crate) use_cache: bool,
- pub(crate) use_flash_attn: bool,
+ #[serde(default)]
+ pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+ #[serde(default)]
+ pub(crate) use_flash_attn: bool, // Not in config.json
}
impl Config {
@@ -35,6 +39,7 @@ impl Config {
rope_theta: 10_000.,
max_position_embeddings: 4096,
norm_eps: 1e-5,
+ use_qkv_bias: false,
use_cache: true,
use_flash_attn,
}
@@ -51,6 +56,10 @@ impl Config {
pub fn num_kv_groups(&self) -> usize {
self.num_attention_heads / self.num_key_value_heads
}
+
+ pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+ self.use_flash_attn = use_flash_attn
+ }
}
#[derive(Debug)]
@@ -179,9 +188,15 @@ impl Attention {
let head_dim = cfg.head_dim();
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
- let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
- let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
- let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let linear_layer = if cfg.use_qkv_bias {
+ linear
+ } else {
+ linear_no_bias
+ };
+
+ let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
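
The `linear_layer` selection works because `linear` and `linear_no_bias` share the same signature, so the `if` expression yields a plain function pointer and the bias decision is made once per attention block rather than at every call site. A dependency-free sketch of the same pattern (the `fake_*` constructors are illustrative stand-ins for the candle helpers):

```rust
// Illustrative stand-ins for with_tracing::{linear, linear_no_bias}.
fn fake_linear(inp: usize, out: usize) -> String {
    format!("Linear {inp}x{out} (with bias)")
}
fn fake_linear_no_bias(inp: usize, out: usize) -> String {
    format!("Linear {inp}x{out} (no bias)")
}

fn main() {
    let use_qkv_bias = true; // would come from Config, as in the patch
    // Both functions coerce to the same `fn` pointer type, so the branch
    // picks a constructor once and reuses it for the q/k/v projections.
    let linear_layer: fn(usize, usize) -> String =
        if use_qkv_bias { fake_linear } else { fake_linear_no_bias };
    println!("{}", linear_layer(2560, 2560)); // q_proj
    println!("{}", linear_layer(2560, 2560)); // k_proj, and so on
}
```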