Diffstat (limited to 'candle-examples/examples/simple-training/main.rs')
-rw-r--r--   candle-examples/examples/simple-training/main.rs   88
1 file changed, 64 insertions(+), 24 deletions(-)
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 35b938e8..f15aa60c 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -2,8 +2,10 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+use clap::{Parser, ValueEnum};
+
 use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
 use std::sync::{Arc, Mutex};
 
 const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
         }
     }
 
-    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
         let shape = shape.into();
         let path = if self.path.is_empty() {
             tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
             }
             return Ok(tensor.as_tensor().clone());
         }
-        // TODO: Proper initialization using the `Init` enum.
-        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
         let tensor = var.as_tensor().clone();
         tensor_data.tensors.insert(path, var);
         Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
     }
 }
 
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
-    let ws = vs.get((dim2, dim1), "weight")?;
-    let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get(out_dim, "bias", init_bs)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
-#[allow(unused)]
+trait Model: Sized {
+    fn new(vs: VarStore) -> Result<Self>;
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
 struct LinearModel {
     linear: Linear,
 }
 
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
     fn new(vs: VarStore) -> Result<Self> {
-        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
         Ok(Self { linear })
     }
 
@@ -100,14 +116,12 @@ impl LinearModel {
     }
 }
 
-#[allow(unused)]
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
 }
 
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
     fn new(vs: VarStore) -> Result<Self> {
         let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
         let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
     }
 }
 
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    learning_rate: f64,
+) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
-    // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
-    println!("train-images: {:?}", m.train_images.shape());
-    println!("train-labels: {:?}", m.train_labels.shape());
-    println!("test-images: {:?}", m.test_images.shape());
-    println!("test-labels: {:?}", m.test_labels.shape());
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
 
     let vs = VarStore::new(DType::F32, dev);
-    let model = LinearModel::new(vs.clone())?;
-    // let model = Mlp::new(vs)?;
+    let model = M::new(vs.clone())?;
 
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
     }
     Ok(())
 }
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+    Linear,
+    Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+    #[clap(value_enum, default_value_t = WhichModel::Linear)]
+    model: WhichModel,
+
+    #[arg(long)]
+    learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    // Load the dataset
+    let m = candle_nn::vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.shape());
+    println!("train-labels: {:?}", m.train_labels.shape());
+    println!("test-images: {:?}", m.test_images.shape());
+    println!("test-labels: {:?}", m.test_labels.shape());
+
+    match args.model {
+        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+    }
+}