diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-28 20:00:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-28 19:00:39 +0100 |
commit | 012ae0090e70da67987a0308ef18587e9e8a8e44 (patch) | |
tree | b6f50c7e0d460d7abcaa2110568ec7af11ebd7d3 /candle-examples/examples/llama2-c | |
parent | 95a857cf57c56a34ecdaae5372f2a13ebd900001 (diff) | |
download | candle-012ae0090e70da67987a0308ef18587e9e8a8e44.tar.gz candle-012ae0090e70da67987a0308ef18587e9e8a8e44.tar.bz2 candle-012ae0090e70da67987a0308ef18587e9e8a8e44.zip |
Infer the config for llama2-c. (#1208)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 14 | ||||
-rw-r--r-- | candle-examples/examples/llama2-c/training.rs | 2 |
2 files changed, 13 insertions, 3 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index a3f01ae2..0ceb27af 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let config = Config::tiny(); let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let (_vocab_size, dim) = vb + .get_no_shape("model.embed_tokens.weight")? + .shape() + .dims2()?; + let config = match dim { + 64 => Config::tiny_260k(), + 288 => Config::tiny_15m(), + 512 => Config::tiny_42m(), + 768 => Config::tiny_110m(), + _ => anyhow::bail!("no config for dim {dim}"), + }; let freq_cis_real = vb .get( (config.seq_len, config.head_size() / 2), @@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); (model, config) } else if is_safetensors { - let config = Config::tiny(); + let config = Config::tiny_15m(); let tensors = candle::safetensors::load(config_path, &device)?; let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let cache = model::Cache::new(true, &config, vb.pp("rot"))?; diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs index 150a3272..b2aa0889 100644 --- a/candle-examples/examples/llama2-c/training.rs +++ b/candle-examples/examples/llama2-c/training.rs @@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> { ); let varmap = candle_nn::VarMap::new(); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); - let config = Config::tiny(); + let config = Config::tiny_15m(); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); |