summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-02 14:14:02 +0100
committerGitHub <noreply@github.com>2023-08-02 14:14:02 +0100
commit4f17290ce05963ae3416f8224ddda77eb67be299 (patch)
tree33de62a391a13e264ffe74e1572e38b4973fd4de /candle-examples/examples/llama2-c
parent0902846f25ad35afd532853336f86fff2656e4c0 (diff)
downloadcandle-4f17290ce05963ae3416f8224ddda77eb67be299.tar.gz
candle-4f17290ce05963ae3416f8224ddda77eb67be299.tar.bz2
candle-4f17290ce05963ae3416f8224ddda77eb67be299.zip
Use AdamW in the llama2 training. (#308)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/training.rs11
1 file changed, 9 insertions, 2 deletions
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 92aa90e6..e55c686c 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -150,12 +150,16 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
- let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+ let params = candle_nn::ParamsAdamW {
+ lr: args.learning_rate,
+ ..Default::default()
+ };
+ let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
- sgd.backward_step(&loss)?;
+ opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
@@ -163,6 +167,9 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}");
}
+ if batch_index > 0 && batch_index % 1000 == 0 {
+ varmap.save("checkpoint.safetensors")?
+ }
}
Ok(())
}