Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 40
1 file changed, 33 insertions, 7 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index c02c65b9..8b64fdd2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -4,6 +4,7 @@ extern crate intel_mkl_src;
 mod model;
+mod training;
 mod weights;
 
 use clap::{Parser, Subcommand};
@@ -64,19 +65,33 @@ struct EvaluationCmd {
     which_model: String,
 }
 
+#[derive(Parser, Debug, Clone)]
+pub struct TrainingCmd {
+    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+    /// script from llama2.c https://github.com/karpathy/llama2.c
+    #[arg(long)]
+    pretokenized_dir: String,
+
+    #[arg(long, default_value_t = 32)]
+    batch_size: usize,
+
+    #[arg(long, default_value_t = 0.001)]
+    learning_rate: f64,
+}
+
 #[derive(Subcommand, Debug, Clone)]
 enum Task {
     Inference(InferenceCmd),
-    Evaluation(EvaluationCmd),
-    Training,
+    Eval(EvaluationCmd),
+    Train(TrainingCmd),
 }
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
-struct Args {
+pub struct Args {
     /// The task to be performed, inference, training or evaluation.
     #[command(subcommand)]
-    task: Task,
+    task: Option<Task>,
 
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
@@ -104,9 +119,19 @@ impl Args {
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match &args.task {
-        Task::Inference(cmd) => run_inference(cmd, &args)?,
-        Task::Evaluation(cmd) => run_eval(cmd, &args)?,
-        Task::Training => todo!(),
+        None => {
+            let cmd = InferenceCmd {
+                temperature: None,
+                prompt: "".to_string(),
+                config: None,
+                model_id: "karpathy/tinyllamas".to_string(),
+                which_model: "stories15M.bin".to_string(),
+            };
+            run_inference(&cmd, &args)?
+        }
+        Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
+        Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
+        Some(Task::Train(cmd)) => training::run(cmd, &args)?,
     }
     Ok(())
 }
@@ -202,6 +227,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
+    println!("{config:?}");
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
     let vb = weights.var_builder(&config, &device)?;
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
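
The body of the new `training` module is not shown in this diff; only the `mod training;` declaration and the `training::run(cmd, &args)?` call site appear above. A minimal sketch of the entry point that call site implies follows; the function body is a hypothetical placeholder, not the implementation from this commit:

// training.rs -- hypothetical skeleton matching the call site in main().
// Child modules can read the private fields of TrainingCmd declared in the
// crate root, so the fields need no extra `pub` for this to compile.
use crate::{Args, TrainingCmd};

pub fn run(args: &TrainingCmd, common_args: &Args) -> anyhow::Result<()> {
    // `pretokenized_dir` holds the dataset produced by llama2.c's
    // tinystories.py script, per the doc comment on TrainingCmd.
    println!(
        "training on {} (batch_size={}, learning_rate={})",
        args.pretokenized_dir, args.batch_size, args.learning_rate
    );
    let _ = common_args; // device selection (e.g. --cpu) would be read here
    // Dataset loading, model construction, and the optimizer loop go here.
    Ok(())
}

Note that clap maps the `Train(TrainingCmd)` variant and its fields to kebab-case, so the subcommand is invoked as e.g. `cargo run --example llama2-c -- train --pretokenized-dir <dir> --batch-size 32`.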
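With `task` now an `Option<Task>`, omitting the subcommand no longer fails to parse: `main()` maps `None` to a default `InferenceCmd` using the `stories15M.bin` checkpoint from `karpathy/tinyllamas`. A hypothetical unit test of that parsing behavior, assuming the `Args` fields elided from this diff are all optional flags like `cpu`:

#[cfg(test)]
mod tests {
    use super::Args;
    use clap::Parser;

    #[test]
    fn missing_subcommand_parses_to_none() {
        // No subcommand given: clap leaves `task` as None, which main()
        // turns into the default inference run.
        let args = Args::parse_from(["llama2-c"]);
        assert!(args.task.is_none());
    }
}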