summary refs log tree commit diff
path: root/candle-examples/examples/llama2-c/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs  40
1 files changed, 33 insertions, 7 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index c02c65b9..8b64fdd2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -4,6 +4,7 @@
extern crate intel_mkl_src;
mod model;
+mod training;
mod weights;
use clap::{Parser, Subcommand};
@@ -64,19 +65,33 @@ struct EvaluationCmd {
which_model: String,
}
+#[derive(Parser, Debug, Clone)]
+pub struct TrainingCmd {
+ /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+ /// script from llama2.c https://github.com/karpathy/llama2.c
+ #[arg(long)]
+ pretokenized_dir: String,
+
+ #[arg(long, default_value_t = 32)]
+ batch_size: usize,
+
+ #[arg(long, default_value_t = 0.001)]
+ learning_rate: f64,
+}
+
#[derive(Subcommand, Debug, Clone)]
enum Task {
Inference(InferenceCmd),
- Evaluation(EvaluationCmd),
- Training,
+ Eval(EvaluationCmd),
+ Train(TrainingCmd),
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
-struct Args {
+pub struct Args {
/// The task to be performed, inference, training or evaluation.
#[command(subcommand)]
- task: Task,
+ task: Option<Task>,
/// Run on CPU rather than on GPU.
#[arg(long)]
@@ -104,9 +119,19 @@ impl Args {
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match &args.task {
- Task::Inference(cmd) => run_inference(cmd, &args)?,
- Task::Evaluation(cmd) => run_eval(cmd, &args)?,
- Task::Training => todo!(),
+ None => {
+ let cmd = InferenceCmd {
+ temperature: None,
+ prompt: "".to_string(),
+ config: None,
+ model_id: "karpathy/tinyllamas".to_string(),
+ which_model: "stories15M.bin".to_string(),
+ };
+ run_inference(&cmd, &args)?
+ }
+ Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
+ Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
+ Some(Task::Train(cmd)) => training::run(cmd, &args)?,
}
Ok(())
}
@@ -202,6 +227,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
+ println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;