author    Nicolas Patry <patry.nicolas@protonmail.com>  2023-10-03 10:41:30 +0200
committer GitHub <noreply@github.com>  2023-10-03 10:41:30 +0200
commit    7b06872f90bd12f660785d997cce47c12c0fffa1 (patch)
tree      a30bb342bc1dc492c90257882498f524dd0002d2 /candle-book
parent    65825e724013304e4b4664a9edfce1b356cd0e40 (diff)
parent    638ccf9f46428bfb291603bffc2bf6ef4e6c094e (diff)
Merge pull request #926 from evgenyigumnov/book-trainin-simplified
Book train simplified example
Diffstat (limited to 'candle-book')
-rw-r--r--  candle-book/Cargo.toml                    4
-rw-r--r--  candle-book/src/SUMMARY.md                1
-rw-r--r--  candle-book/src/lib.rs                    3
-rw-r--r--  candle-book/src/simplified.rs           196
-rw-r--r--  candle-book/src/training/simplified.md   45
5 files changed, 247 insertions(+), 2 deletions(-)
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index 8aec0822..a060a701 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
+anyhow = { workspace = true }
+tokio = "1.29.1"
[dev-dependencies]
-anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
@@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
-tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index e92f298f..59831af2 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -14,6 +14,7 @@
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/training.md)
+ - [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs
index ef9b853e..a1ec1e94 100644
--- a/candle-book/src/lib.rs
+++ b/candle-book/src/lib.rs
@@ -1,4 +1,7 @@
#[cfg(test)]
+pub mod simplified;
+
+#[cfg(test)]
mod tests {
use anyhow::Result;
use candle::{DType, Device, Tensor};
diff --git a/candle-book/src/simplified.rs b/candle-book/src/simplified.rs
new file mode 100644
index 00000000..6427563f
--- /dev/null
+++ b/candle-book/src/simplified.rs
@@ -0,0 +1,196 @@
+//! # A simplified example of training a neural network in Rust and then using it, built on the Candle framework by Hugging Face
+//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
+//! This program implements a neural network that predicts the winner of the second round of an election from the results of the first round.
+//!
+//! ## Key points:
+//!
+//! A multilayer perceptron with two hidden layers is used: the first hidden layer has 4 neurons and the second has 2.
+//! The input is a vector of 2 numbers: the percentages of votes for the first and second candidates in the first round.
+//! The output is 0 or 1, where 1 means the first candidate wins the second round and 0 means they lose.
+//! For training, samples with real data from the first and second rounds of several elections are used.
+//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
+//! The model parameters (neuron weights) are initialized randomly and then optimized during training.
+//! After training, the model is evaluated on a held-out test set to measure its accuracy.
+//! If the accuracy on the test set is below 100%, the model is considered underfit and training is restarted.
+//! In this way, the network learns the hidden relationship between first-round and second-round voting results and can make predictions for new data.
+
+#[rustfmt::skip]
+mod tests {
+
+use candle::{DType, Result, Tensor, D, Device};
+use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
+
+// ANCHOR: book_training_simplified1
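+// Hyper-parameters: each sample carries VOTE_DIM = 2 features (the two candidates' first-round votes),
+// and the network emits RESULTS + 1 = 2 class logits (first candidate loses / wins).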
+const VOTE_DIM: usize = 2;
+const RESULTS: usize = 1;
+const EPOCHS: usize = 10;
+const LAYER1_OUT_SIZE: usize = 4;
+const LAYER2_OUT_SIZE: usize = 2;
+const LEARNING_RATE: f64 = 0.05;
+
+#[derive(Clone)]
+pub struct Dataset {
+ pub train_votes: Tensor,
+ pub train_results: Tensor,
+ pub test_votes: Tensor,
+ pub test_results: Tensor,
+}
+
+struct MultiLevelPerceptron {
+ ln1: Linear,
+ ln2: Linear,
+ ln3: Linear,
+}
+
+impl MultiLevelPerceptron {
+ fn new(vs: VarBuilder) -> Result<Self> {
+ let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
+ let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
+ let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
+ Ok(Self { ln1, ln2, ln3 })
+ }
+
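+ // Forward pass: two ReLU-activated hidden layers followed by a final linear layer that produces the raw class logits.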
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.ln1.forward(xs)?;
+ let xs = xs.relu()?;
+ let xs = self.ln2.forward(&xs)?;
+ let xs = xs.relu()?;
+ self.ln3.forward(&xs)
+ }
+}
+
+// ANCHOR_END: book_training_simplified1
+
+
+
+#[tokio::test]
+// ANCHOR: book_training_simplified3
+async fn simplified() -> anyhow::Result<()> {
+
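+ // Run on the first CUDA GPU if one is available, otherwise fall back to the CPU.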
+ let dev = Device::cuda_if_available(0)?;
+
+ let train_votes_vec: Vec<u32> = vec![
+ 15, 10,
+ 10, 15,
+ 5, 12,
+ 30, 20,
+ 16, 12,
+ 13, 25,
+ 6, 14,
+ 31, 21,
+ ];
+ let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+ let train_results_vec: Vec<u32> = vec![
+ 1,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 0,
+ 1,
+ ];
+ let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
+
+ let test_votes_vec: Vec<u32> = vec![
+ 13, 9,
+ 8, 14,
+ 3, 10,
+ ];
+ let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+ let test_results_vec: Vec<u32> = vec![
+ 1,
+ 0,
+ 0,
+ ];
+ let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
+
+ let m = Dataset {
+ train_votes: train_votes_tensor,
+ train_results: train_results_tensor,
+ test_votes: test_votes_tensor,
+ test_results: test_results_tensor,
+ };
+
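+ // Retrain from fresh random weights until `train` reports 100% test accuracy.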
+ let trained_model: MultiLevelPerceptron;
+ loop {
+ println!("Trying to train neural network.");
+ match train(m.clone(), &dev) {
+ Ok(model) => {
+ trained_model = model;
+ break;
+ },
+ Err(e) => {
+ println!("Error: {}", e);
+ continue;
+ }
+ }
+
+ }
+
+ let real_world_votes: Vec<u32> = vec![
+ 13, 22,
+ ];
+
+ let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+ let final_result = trained_model.forward(&tensor_test_votes)?;
+
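+ // The index of the largest logit is the predicted class: 1 means the first candidate wins, 0 means they lose.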
+ let result = final_result
+ .argmax(D::Minus1)?
+ .to_dtype(DType::F32)?
+ .get(0).map(|x| x.to_scalar::<f32>())??;
+ println!("real_life_votes: {:?}", real_world_votes);
+ println!("neural_network_prediction_result: {:?}", result);
+
+ Ok(())
+
+}
+// ANCHOR_END: book_training_simplified3
+
+// ANCHOR: book_training_simplified2
+fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
+ let train_results = m.train_results.to_device(dev)?;
+ let train_votes = m.train_votes.to_device(dev)?;
+ let varmap = VarMap::new();
+ let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
+ let model = MultiLevelPerceptron::new(vs.clone())?;
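+ // Plain stochastic gradient descent over every variable registered in the varmap.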
+ let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
+ let test_votes = m.test_votes.to_device(dev)?;
+ let test_results = m.test_results.to_device(dev)?;
+ let mut final_accuracy: f32 = 0.0;
+ for epoch in 1..EPOCHS + 1 {
+ let logits = model.forward(&train_votes)?;
+ let log_sm = ops::log_softmax(&logits, D::Minus1)?;
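+ // log-softmax followed by negative log-likelihood amounts to the cross-entropy loss.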
+ let loss = loss::nll(&log_sm, &train_results)?;
+ sgd.backward_step(&loss)?;
+
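+ // Measure accuracy on the held-out test set after each optimization step.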
+ let test_logits = model.forward(&test_votes)?;
+ let sum_ok = test_logits
+ .argmax(D::Minus1)?
+ .eq(&test_results)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_results.dims1()? as f32;
+ final_accuracy = 100. * test_accuracy;
+ println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
+ loss.to_scalar::<f32>()?,
+ final_accuracy
+ );
+ if final_accuracy == 100.0 {
+ break;
+ }
+ }
+ if final_accuracy < 100.0 {
+ Err(anyhow::Error::msg("The model is not trained well enough."))
+ } else {
+ Ok(model)
+ }
+}
+// ANCHOR_END: book_training_simplified2
+
+
+}
diff --git a/candle-book/src/training/simplified.md b/candle-book/src/training/simplified.md
new file mode 100644
index 00000000..a64f2da4
--- /dev/null
+++ b/candle-book/src/training/simplified.md
@@ -0,0 +1,45 @@
+# Simplified
+
+## How it works
+
+This program implements a neural network that predicts the winner of the second round of an election from the results of the first round.
+
+Key points:
+
+1. A multilayer perceptron with two hidden layers is used: the first hidden layer has 4 neurons and the second has 2.
+2. The input is a vector of 2 numbers: the percentages of votes for the first and second candidates in the first round.
+3. The output is 0 or 1, where 1 means the first candidate wins the second round and 0 means they lose.
+4. For training, samples with real data from the first and second rounds of several elections are used.
+5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
+6. The model parameters (neuron weights) are initialized randomly and then optimized during training.
+7. After training, the model is evaluated on a held-out test set to measure its accuracy.
+8. If the accuracy on the test set is below 100%, the model is considered underfit and training is restarted.
+
+In this way, the network learns the hidden relationship between first-round and second-round voting results and can make predictions for new data. The three listings below show the model, the training loop, and the test that ties them together.
+
+
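+The model definition and its hyper-parameters:
+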
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified1}}
+```
+
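+The training loop:
+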
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified2}}
+```
+
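+The test that trains the model and then asks it to predict a fresh first-round result:
+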
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified3}}
+```
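+
+Since the example is written as a test, it can be run with `cargo test simplified -- --nocapture` from the `candle-book` crate.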
+
+
+## Example output
+
+```bash
+Trying to train neural network.
+Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
+Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
+Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
+Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
+Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
+real_life_votes: [13, 22]
+neural_network_prediction_result: 0.0
+```