summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-datasets/src/vision/mnist.rs4
-rw-r--r--candle-examples/examples/mnist-training/main.rs105
2 files changed, 106 insertions, 3 deletions
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 1085edd6..30b0d01f 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -101,10 +101,10 @@ pub fn load() -> Result<crate::vision::Dataset> {
);
let repo = api.repo(repo);
let test_parquet_filename = repo
- .get("mnist/mnist-test.parquet")
+ .get("mnist/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
- .get("mnist/mnist-train.parquet")
+ .get("mnist/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index ea4b4bce..5bbce31b 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -6,9 +6,10 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
+use rand::prelude::*;
use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@@ -58,6 +59,40 @@ impl Model for Mlp {
}
}
+#[derive(Debug)]
+struct ConvNet {
+ conv1: Conv2d,
+ conv2: Conv2d,
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl Model for ConvNet {
+ fn new(vs: VarBuilder) -> Result<Self> {
+ let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
+ let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
+ let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
+ Ok(Self {
+ conv1,
+ conv2,
+ fc1,
+ fc2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_sz, _img_dim) = xs.dims2()?;
+ let xs = xs.reshape((b_sz, 1, 28, 28))?;
+ let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
+ let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
+ let xs = xs.flatten_from(1)?;
+ let xs = self.fc1.forward(&xs)?;
+ let xs = xs.relu()?;
+ self.fc2.forward(&xs)
+ }
+}
+
struct TrainingArgs {
learning_rate: f64,
load: Option<String>,
@@ -65,6 +100,71 @@ struct TrainingArgs {
epochs: usize,
}
+fn training_loop_cnn(
+ m: candle_datasets::vision::Dataset,
+ args: &TrainingArgs,
+) -> anyhow::Result<()> {
+ const BSIZE: usize = 64;
+
+ let dev = candle::Device::cuda_if_available(0)?;
+
+ let train_labels = m.train_labels;
+ let train_images = m.train_images.to_device(&dev)?;
+ let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ let mut varmap = VarMap::new();
+ let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
+ let model = ConvNet::new(vs.clone())?;
+
+ if let Some(load) = &args.load {
+ println!("loading weights from {load}");
+ varmap.load(load)?
+ }
+
+ let adamw_params = candle_nn::ParamsAdamW {
+ lr: args.learning_rate,
+ ..Default::default()
+ };
+ let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
+ let test_images = m.test_images.to_device(&dev)?;
+ let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+ let n_batches = train_images.dim(0)? / BSIZE;
+ let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
+ for epoch in 1..args.epochs {
+ let mut sum_loss = 0f32;
+ batch_idxs.shuffle(&mut thread_rng());
+ for batch_idx in batch_idxs.iter() {
+ let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
+ let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
+ let logits = model.forward(&train_images)?;
+ let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+ let loss = loss::nll(&log_sm, &train_labels)?;
+ opt.backward_step(&loss)?;
+ sum_loss += loss.to_vec0::<f32>()?;
+ }
+ let avg_loss = sum_loss / n_batches as f32;
+
+ let test_logits = model.forward(&test_images)?;
+ let sum_ok = test_logits
+ .argmax(D::Minus1)?
+ .eq(&test_labels)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_labels.dims1()? as f32;
+ println!(
+ "{epoch:4} train loss {:8.5} test acc: {:5.2}%",
+ avg_loss,
+ 100. * test_accuracy
+ );
+ }
+ if let Some(save) = &args.save {
+ println!("saving trained weights in {save}");
+ varmap.save(save)?
+ }
+ Ok(())
+}
+
fn training_loop<M: Model>(
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
@@ -118,6 +218,7 @@ fn training_loop<M: Model>(
enum WhichModel {
Linear,
Mlp,
+ Cnn,
}
#[derive(Parser)]
@@ -160,6 +261,7 @@ pub fn main() -> anyhow::Result<()> {
let default_learning_rate = match args.model {
WhichModel::Linear => 1.,
WhichModel::Mlp => 0.05,
+ WhichModel::Cnn => 0.001,
};
let training_args = TrainingArgs {
epochs: args.epochs,
@@ -170,5 +272,6 @@ pub fn main() -> anyhow::Result<()> {
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
+ WhichModel::Cnn => training_loop_cnn(m, &training_args),
}
}