-rw-r--r--  candle-examples/examples/llama2-c/main.rs         14
-rw-r--r--  candle-examples/examples/llama2-c/training.rs      2
-rw-r--r--  candle-transformers/src/models/llama2_c.rs         41
-rw-r--r--  candle-transformers/src/quantized_var_builder.rs   10
4 files changed, 63 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index a3f01ae2..0ceb27af 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let config = Config::tiny();
         let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let (_vocab_size, dim) = vb
+            .get_no_shape("model.embed_tokens.weight")?
+            .shape()
+            .dims2()?;
+        let config = match dim {
+            64 => Config::tiny_260k(),
+            288 => Config::tiny_15m(),
+            512 => Config::tiny_42m(),
+            768 => Config::tiny_110m(),
+            _ => anyhow::bail!("no config for dim {dim}"),
+        };
         let freq_cis_real = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
@@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
         (model, config)
     } else if is_safetensors {
-        let config = Config::tiny();
+        let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
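
Note on the main.rs hunks above: instead of hard-coding Config::tiny(), the gguf branch now reads the shape of model.embed_tokens.weight and picks a preset from the embedding dimension. A minimal standalone sketch of that selection step, assuming candle_transformers::models::llama2_c::Config and the anyhow crate are available; pick_config is a hypothetical helper name, not part of the patch:

    use candle_transformers::models::llama2_c::Config;

    // Map the hidden dimension read from the checkpoint to a config preset,
    // mirroring the match added in main.rs; unknown sizes are an error.
    fn pick_config(dim: usize) -> anyhow::Result<Config> {
        Ok(match dim {
            64 => Config::tiny_260k(),
            288 => Config::tiny_15m(),
            512 => Config::tiny_42m(),
            768 => Config::tiny_110m(),
            _ => anyhow::bail!("no config for dim {dim}"),
        })
    }
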
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 150a3272..b2aa0889 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
-    let config = Config::tiny();
+    let config = Config::tiny_15m();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 07a6e2f2..753770fb 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -17,7 +17,20 @@ pub struct Config {
 }
 
 impl Config {
-    pub fn tiny() -> Self {
+    pub fn tiny_260k() -> Self {
+        Self {
+            dim: 64,
+            hidden_dim: 768,
+            n_layers: 5,
+            n_heads: 8,
+            n_kv_heads: 4,
+            vocab_size: 32000,
+            seq_len: 512,
+            norm_eps: 1e-5,
+        }
+    }
+
+    pub fn tiny_15m() -> Self {
         Self {
             dim: 288,
             hidden_dim: 768,
@@ -29,6 +42,32 @@ impl Config {
             norm_eps: 1e-5,
         }
     }
+
+    pub fn tiny_42m() -> Self {
+        Self {
+            dim: 512,
+            hidden_dim: 768,
+            n_layers: 8,
+            n_heads: 8,
+            n_kv_heads: 8,
+            vocab_size: 32000,
+            seq_len: 1024,
+            norm_eps: 1e-5,
+        }
+    }
+
+    pub fn tiny_110m() -> Self {
+        Self {
+            dim: 768,
+            hidden_dim: 768,
+            n_layers: 12,
+            n_heads: 12,
+            n_kv_heads: 12,
+            vocab_size: 32000,
+            seq_len: 1024,
+            norm_eps: 1e-5,
+        }
+    }
 }
 
 #[derive(Clone)]
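
The constructors above replace the single Config::tiny() preset with one preset per checkpoint size (260K, 15M, 42M, 110M parameters). A small sketch of how a caller might pick a preset by a size label; the string keys are illustrative and not an existing CLI flag, and it assumes head_size() is dim / n_heads (so tiny_42m gives 512 / 8 = 64 as used by the rotary tensors in main.rs):

    use candle_transformers::models::llama2_c::Config;

    // Illustrative lookup from a size label to the new presets.
    fn config_for(which: &str) -> Option<Config> {
        match which {
            "260k" => Some(Config::tiny_260k()),
            "15m" => Some(Config::tiny_15m()),
            "42m" => Some(Config::tiny_42m()),
            "110m" => Some(Config::tiny_110m()),
            _ => None,
        }
    }
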
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
index 259496d6..810802e8 100644
--- a/candle-transformers/src/quantized_var_builder.rs
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -77,6 +77,16 @@ impl VarBuilder {
         }
     }
 
+    pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
+        let path = self.path(name);
+        match self.data.get(&path) {
+            None => {
+                candle::bail!("cannot find tensor {name}")
+            }
+            Some(qtensor) => Ok(qtensor.clone()),
+        }
+    }
+
     pub fn device(&self) -> &Device {
         &self.device
     }
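
get_no_shape complements the existing shape-checked get: it returns the raw quantized tensor so the caller can inspect its shape first, which is what the main.rs change uses to auto-detect the config. A minimal usage sketch, assuming from_gguf takes a single path argument as in the patch and that the path points at a llama2.c checkpoint converted to gguf:

    use candle_transformers::quantized_var_builder::VarBuilder;

    // Read (vocab_size, dim) from the embedding tensor without asserting a
    // shape up front; errors if the tensor is missing or not 2-dimensional.
    fn embed_dims(path: &str) -> anyhow::Result<(usize, usize)> {
        let vb = VarBuilder::from_gguf(path)?;
        let embed = vb.get_no_shape("model.embed_tokens.weight")?;
        let (vocab_size, dim) = embed.shape().dims2()?;
        // e.g. dim == 288 would select Config::tiny_15m() in the example above.
        Ok((vocab_size, dim))
    }
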