author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-01 17:23:07 +0100
committer  GitHub <noreply@github.com>                2023-08-01 17:23:07 +0100
commit     a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6
tree       8f31406d35aff7b5c6aecbfbdac773cf31574fce /candle-examples/examples/llama2-c/model.rs
parent     babee9f011805f59868b67053bdb8cce0e221e18
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator.
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
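
The bullet points above outline the new training path. For orientation only, here is a minimal sketch of what a single training step could look like on top of this example's model: it assumes candle-nn's `AdamW`, `Optimizer`, and `loss::cross_entropy` helpers, a `Llama` type from the example's `model.rs`, and a `forward(x, index_pos)` call that returns per-position logits; the actual training module added by this commit may be organized differently.

    // Hedged sketch, not the code from this commit. `Llama` is the model type
    // defined in candle-examples/examples/llama2-c/model.rs; its forward signature
    // and per-position logits output are assumptions.
    use candle_core::{Result, Tensor};
    use candle_nn::{loss, AdamW, Optimizer};

    fn train_step(model: &Llama, opt: &mut AdamW, tokens: &Tensor, targets: &Tensor) -> Result<f32> {
        // tokens/targets: (batch, seq_len) u32 token ids, targets shifted left by one.
        let logits = model.forward(tokens, 0)?; // assumed shape (batch, seq_len, vocab_size)
        let loss = loss::cross_entropy(&logits.flatten_to(1)?, &targets.flatten_to(1)?)?;
        opt.backward_step(&loss)?; // backward pass + AdamW parameter update in one call
        loss.to_scalar::<f32>()
    }

The same loss computation, run without the `backward_step` call, would correspond to the validation/evaluation loss mentioned in the commit message.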
Diffstat (limited to 'candle-examples/examples/llama2-c/model.rs')
-rw-r--r--  candle-examples/examples/llama2-c/model.rs | 15
1 file changed, 15 insertions, 0 deletions
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 618bf67c..4e7015dd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -15,6 +15,21 @@ pub struct Config {
     pub norm_eps: f64,
 }
 
+impl Config {
+    pub fn tiny() -> Self {
+        Self {
+            dim: 288,
+            hidden_dim: 768,
+            n_layers: 6,
+            n_heads: 6,
+            n_kv_heads: 6,
+            vocab_size: 32000,
+            seq_len: 256,
+            norm_eps: 1e-5,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
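
The `Config::tiny()` constructor added above matches the smallest llama2.c configuration (288-dim, 6 layers, 6 heads). As a hedged usage sketch, a trainable model could be built from it roughly as follows; the `Cache::new` and `Llama::load` signatures are assumed from the rest of `model.rs` and may differ, as may the `VarMap`-based setup:

    // Hedged sketch: constructing a trainable tiny model from the new config.
    // Config, Cache, and Llama come from this model.rs; their constructor
    // signatures are assumptions, not taken from this commit.
    use candle_core::{DType, Device, Result};
    use candle_nn::{VarBuilder, VarMap};

    fn build_tiny_model(device: &Device) -> Result<(VarMap, Llama)> {
        let config = Config::tiny(); // the 288-dim / 6-layer config from the diff above
        let varmap = VarMap::new(); // tracks Vars so an optimizer can update them later
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        let cache = Cache::new(false, &config, vb.pp("rot"))?; // kv-cache disabled while training
        let model = Llama::load(vb, &cache, config)?;
        Ok((varmap, model))
    }

Keeping the `VarMap` around is what lets `varmap.all_vars()` be handed to an optimizer such as `AdamW`, rather than loading frozen weights from a checkpoint.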