summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c/model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-01 17:23:07 +0100
committerGitHub <noreply@github.com>2023-08-01 17:23:07 +0100
commita27239f3d9b77ad4c300de38d43c6ad64d6b5ea6 (patch)
tree8f31406d35aff7b5c6aecbfbdac773cf31574fce /candle-examples/examples/llama2-c/model.rs
parentbabee9f011805f59868b67053bdb8cce0e221e18 (diff)
downloadcandle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.gz
candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.bz2
candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.zip
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default. * Add the training module and load the training dataset. * Random dataset iterator. * Proper valid-loss computation. * Compute the evaluation loss. * Add more substance to the training loop.
Diffstat (limited to 'candle-examples/examples/llama2-c/model.rs')
-rw-r--r--candle-examples/examples/llama2-c/model.rs15
1 file changed, 15 insertions, 0 deletions
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 618bf67c..4e7015dd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -15,6 +15,21 @@ pub struct Config {
pub norm_eps: f64,
}
+impl Config {
+ pub fn tiny() -> Self {
+ Self {
+ dim: 288,
+ hidden_dim: 768,
+ n_layers: 6,
+ n_heads: 6,
+ n_kv_heads: 6,
+ vocab_size: 32000,
+ seq_len: 256,
+ norm_eps: 1e-5,
+ }
+ }
+}
+
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,