summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-10-28 20:00:39 +0200
committer: GitHub <noreply@github.com> 2023-10-28 19:00:39 +0100
commit: 012ae0090e70da67987a0308ef18587e9e8a8e44 (patch)
tree: b6f50c7e0d460d7abcaa2110568ec7af11ebd7d3 /candle-examples/examples/llama2-c
parent: 95a857cf57c56a34ecdaae5372f2a13ebd900001 (diff)
download: candle-012ae0090e70da67987a0308ef18587e9e8a8e44.tar.gz
candle-012ae0090e70da67987a0308ef18587e9e8a8e44.tar.bz2
candle-012ae0090e70da67987a0308ef18587e9e8a8e44.zip
Infer the config for llama2-c. (#1208)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs      14
-rw-r--r--  candle-examples/examples/llama2-c/training.rs   2
2 files changed, 13 insertions, 3 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index a3f01ae2..0ceb27af 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
- let config = Config::tiny();
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+ let (_vocab_size, dim) = vb
+ .get_no_shape("model.embed_tokens.weight")?
+ .shape()
+ .dims2()?;
+ let config = match dim {
+ 64 => Config::tiny_260k(),
+ 288 => Config::tiny_15m(),
+ 512 => Config::tiny_42m(),
+ 768 => Config::tiny_110m(),
+ _ => anyhow::bail!("no config for dim {dim}"),
+ };
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
@@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
- let config = Config::tiny();
+ let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 150a3272..b2aa0889 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
- let config = Config::tiny();
+ let config = Config::tiny_15m();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);