Diffstat (limited to 'candle-examples/examples/simple-training/main.rs')
-rw-r--r-- | candle-examples/examples/simple-training/main.rs | 88
1 file changed, 64 insertions, 24 deletions
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 35b938e8..f15aa60c 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -2,8 +2,10 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+use clap::{Parser, ValueEnum};
+
 use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
 use std::sync::{Arc, Mutex};
 
 const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
         }
     }
 
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
         let shape = shape.into();
         let path = if self.path.is_empty() {
             tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
             }
             return Ok(tensor.as_tensor().clone());
         }
-        // TODO: Proper initialization using the `Init` enum.
-        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
         let tensor = var.as_tensor().clone();
         tensor_data.tensors.insert(path, var);
         Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
     }
 }
 
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((dim2, dim1), "weight")?;
-    let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get(out_dim, "bias", init_bs)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
-#[allow(unused)]
+trait Model: Sized {
+    fn new(vs: VarStore) -> Result<Self>;
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
 struct LinearModel {
     linear: Linear,
 }
 
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
     fn new(vs: VarStore) -> Result<Self> {
-        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
 
@@ -100,14 +116,12 @@ impl LinearModel {
     }
 }
 
-#[allow(unused)]
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
 }
 
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
     fn new(vs: VarStore) -> Result<Self> {
         let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
         let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    learning_rate: f64,
+) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
-    // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
-    println!("train-images: {:?}", m.train_images.shape());
-    println!("train-labels: {:?}", m.train_labels.shape());
-    println!("test-images: {:?}", m.test_images.shape());
-    println!("test-labels: {:?}", m.test_labels.shape());
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
     let vs = VarStore::new(DType::F32, dev);
-    let model = LinearModel::new(vs.clone())?;
-    // let model = Mlp::new(vs)?;
+    let model = M::new(vs.clone())?;
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+    Linear,
+    Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[clap(value_enum, default_value_t = WhichModel::Linear)]
+    model: WhichModel,
+
+    #[arg(long)]
+    learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    // Load the dataset
+    let m = candle_nn::vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.shape());
+    println!("train-labels: {:?}", m.train_labels.shape());
+    println!("test-images: {:?}", m.test_images.shape());
+    println!("test-labels: {:?}", m.test_labels.shape());
+
+    match args.model {
+        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+    }
+}
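
The new `linear` constructor mirrors the PyTorch `nn.Linear` defaults: Kaiming-normal weights plus biases drawn uniformly from [-1/sqrt(in_dim), 1/sqrt(in_dim)]. A minimal standalone sketch of that bias rule, assuming only the `rand` crate (the helper name `sample_bias` is hypothetical and not part of this patch):

// Sketch of the bias initialization the patch passes to `vs.get` as
// `Init::Uniform { lo: -bound, up: bound }`: each bias element is drawn
// uniformly from [-1/sqrt(in_dim), 1/sqrt(in_dim)].
use rand::Rng;

fn sample_bias(in_dim: usize, out_dim: usize) -> Vec<f64> {
    let bound = 1. / (in_dim as f64).sqrt();
    let mut rng = rand::thread_rng();
    (0..out_dim).map(|_| rng.gen_range(-bound..bound)).collect()
}

fn main() {
    // For the MLP's first layer (784 -> 100), the bound is 1/28, about 0.0357.
    let bias = sample_bias(784, 100);
    assert!(bias.iter().all(|b| b.abs() <= 1. / 28.));
    println!("first few bias values: {:?}", &bias[..4]);
}

Given the `Args` derive in the patch, an invocation along the lines of `cargo run --example simple-training -- mlp --learning-rate 0.05` (command assumed from the clap definitions, not shown in the patch) would train the MLP with a custom rate, while a bare run falls back to the linear model at its 1.0 default.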