summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mnist-training/main.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-08-29 13:10:05 +0200
committerGitHub <noreply@github.com>2023-08-29 13:10:05 +0200
commit14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f (patch)
tree11d5c84dedb610b9e4306030ec36929d1f03e980 /candle-examples/examples/mnist-training/main.rs
parent62ef494dc17c1f582b28c665e78f2aa78d846bb9 (diff)
parent2d5b7a735d2c9ccb890dae73862dc734ef0950ae (diff)
downloadcandle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.gz
candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.bz2
candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.zip
Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
Diffstat (limited to 'candle-examples/examples/mnist-training/main.rs')
-rw-r--r--candle-examples/examples/mnist-training/main.rs10
1 files changed, 9 insertions, 1 deletions
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index bcf8677d..ea4b4bce 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -138,12 +138,20 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option<String>,
+
+ /// The file where to load the trained weights from, in safetensors format.
+ #[arg(long)]
+ local_mnist: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_datasets::vision::mnist::load_dir("data")?;
+ let m = if let Some(directory) = args.local_mnist {
+ candle_datasets::vision::mnist::load_dir(directory)?
+ } else {
+ candle_datasets::vision::mnist::load()?
+ };
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());