| author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-10-03 10:41:30 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-03 10:41:30 +0200 |
| commit | 7b06872f90bd12f660785d997cce47c12c0fffa1 (patch) | |
| tree | a30bb342bc1dc492c90257882498f524dd0002d2 /candle-book | |
| parent | 65825e724013304e4b4664a9edfce1b356cd0e40 (diff) | |
| parent | 638ccf9f46428bfb291603bffc2bf6ef4e6c094e (diff) | |
| download | candle-7b06872f90bd12f660785d997cce47c12c0fffa1.tar.gz candle-7b06872f90bd12f660785d997cce47c12c0fffa1.tar.bz2 candle-7b06872f90bd12f660785d997cce47c12c0fffa1.zip | |
Merge pull request #926 from evgenyigumnov/book-trainin-simplified
Book train simplified example
Diffstat (limited to 'candle-book')
| -rw-r--r-- | candle-book/Cargo.toml | 4 |
|---|---|---|
| -rw-r--r-- | candle-book/src/SUMMARY.md | 1 |
| -rw-r--r-- | candle-book/src/lib.rs | 3 |
| -rw-r--r-- | candle-book/src/simplified.rs | 196 |
| -rw-r--r-- | candle-book/src/training/simplified.md | 45 |

5 files changed, 247 insertions, 2 deletions
```diff
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index 8aec0822..a060a701 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true, optional = true }
+anyhow = { workspace = true }
+tokio = "1.29.1"
 
 [dev-dependencies]
-anyhow = { workspace = true }
 byteorder = { workspace = true }
 hf-hub = { workspace = true, features=["tokio"]}
 clap = { workspace = true }
@@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
 wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
-tokio = "1.29.1"
 parquet = { workspace = true }
 image = { workspace = true }
 
```

```diff
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index e92f298f..59831af2 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -14,6 +14,7 @@
 - [Using the hub](inference/hub.md)
 - [Error management](error_manage.md)
 - [Training](training/training.md)
+    - [Simplified](training/simplified.md)
     - [MNIST](training/mnist.md)
 - [Fine-tuning]()
 - [Serialization]()
```

```diff
diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs
index ef9b853e..a1ec1e94 100644
--- a/candle-book/src/lib.rs
+++ b/candle-book/src/lib.rs
@@ -1,4 +1,7 @@
 #[cfg(test)]
+pub mod simplified;
+
+#[cfg(test)]
 mod tests {
     use anyhow::Result;
     use candle::{DType, Device, Tensor};
```

```diff
diff --git a/candle-book/src/simplified.rs b/candle-book/src/simplified.rs
new file mode 100644
index 00000000..6427563f
--- /dev/null
+++ b/candle-book/src/simplified.rs
@@ -0,0 +1,196 @@
+//! # A simplified example in Rust of training a neural network and then using it, based on the Candle framework by Hugging Face.
+//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
+//! This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.
+//!
+//! ## Key points:
+//!
+//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
+//! The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
+//! The output is the number 0 or 1, where 1 means that the first candidate will win the second round and 0 means that they will lose.
+//! For training, samples with real data on the results of the first and second rounds of different elections are used.
+//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
+//! The model parameters (neuron weights) are initialized randomly and then optimized during training.
+//! After training, the model is evaluated on a held-out sample to measure its accuracy.
+//! If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.
+//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
+
+#[rustfmt::skip]
+mod tests {
+
+use candle::{DType, Result, Tensor, D, Device};
+use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
+
+// ANCHOR: book_training_simplified1
+const VOTE_DIM: usize = 2;
+const RESULTS: usize = 1;
+const EPOCHS: usize = 10;
+const LAYER1_OUT_SIZE: usize = 4;
+const LAYER2_OUT_SIZE: usize = 2;
+const LEARNING_RATE: f64 = 0.05;
+
+#[derive(Clone)]
+pub struct Dataset {
+    pub train_votes: Tensor,
+    pub train_results: Tensor,
+    pub test_votes: Tensor,
+    pub test_results: Tensor,
+}
+
+struct MultiLevelPerceptron {
+    ln1: Linear,
+    ln2: Linear,
+    ln3: Linear,
+}
+
+impl MultiLevelPerceptron {
+    fn new(vs: VarBuilder) -> Result<Self> {
+        let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
+        let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
+        let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
+        Ok(Self { ln1, ln2, ln3 })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.ln1.forward(xs)?;
+        let xs = xs.relu()?;
+        let xs = self.ln2.forward(&xs)?;
+        let xs = xs.relu()?;
+        self.ln3.forward(&xs)
+    }
+}
+
+// ANCHOR_END: book_training_simplified1
+
+#[tokio::test]
+// ANCHOR: book_training_simplified3
+async fn simplified() -> anyhow::Result<()> {
+
+    let dev = Device::cuda_if_available(0)?;
+
+    let train_votes_vec: Vec<u32> = vec![
+        15, 10,
+        10, 15,
+        5, 12,
+        30, 20,
+        16, 12,
+        13, 25,
+        6, 14,
+        31, 21,
+    ];
+    let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+    let train_results_vec: Vec<u32> = vec![
+        1,
+        0,
+        0,
+        1,
+        1,
+        0,
+        0,
+        1,
+    ];
+    let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
+
+    let test_votes_vec: Vec<u32> = vec![
+        13, 9,
+        8, 14,
+        3, 10,
+    ];
+    let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+    let test_results_vec: Vec<u32> = vec![
+        1,
+        0,
+        0,
+    ];
+    let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
+
+    let m = Dataset {
+        train_votes: train_votes_tensor,
+        train_results: train_results_tensor,
+        test_votes: test_votes_tensor,
+        test_results: test_results_tensor,
+    };
+
+    let trained_model: MultiLevelPerceptron;
+    loop {
+        println!("Trying to train neural network.");
+        match train(m.clone(), &dev) {
+            Ok(model) => {
+                trained_model = model;
+                break;
+            },
+            Err(e) => {
+                println!("Error: {}", e);
+                continue;
+            }
+        }
+    }
+
+    let real_world_votes: Vec<u32> = vec![
+        13, 22,
+    ];
+
+    let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
+
+    let final_result = trained_model.forward(&tensor_test_votes)?;
+
+    let result = final_result
+        .argmax(D::Minus1)?
+        .to_dtype(DType::F32)?
+        .get(0).map(|x| x.to_scalar::<f32>())??;
+    println!("real_life_votes: {:?}", real_world_votes);
+    println!("neural_network_prediction_result: {:?}", result);
+
+    Ok(())
+
+}
+// ANCHOR_END: book_training_simplified3
+
+// ANCHOR: book_training_simplified2
+fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
+    let train_results = m.train_results.to_device(dev)?;
+    let train_votes = m.train_votes.to_device(dev)?;
+    let varmap = VarMap::new();
+    let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
+    let model = MultiLevelPerceptron::new(vs.clone())?;
+    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
+    let test_votes = m.test_votes.to_device(dev)?;
+    let test_results = m.test_results.to_device(dev)?;
+    let mut final_accuracy: f32 = 0.0;
+    for epoch in 1..EPOCHS + 1 {
+        let logits = model.forward(&train_votes)?;
+        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+        let loss = loss::nll(&log_sm, &train_results)?;
+        sgd.backward_step(&loss)?;
+
+        let test_logits = model.forward(&test_votes)?;
+        let sum_ok = test_logits
+            .argmax(D::Minus1)?
+            .eq(&test_results)?
+            .to_dtype(DType::F32)?
+            .sum_all()?
+            .to_scalar::<f32>()?;
+        let test_accuracy = sum_ok / test_results.dims1()? as f32;
+        final_accuracy = 100. * test_accuracy;
+        println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
+            loss.to_scalar::<f32>()?,
+            final_accuracy
+        );
+        if final_accuracy == 100.0 {
+            break;
+        }
+    }
+    if final_accuracy < 100.0 {
+        Err(anyhow::Error::msg("The model is not trained well enough."))
+    } else {
+        Ok(model)
+    }
+}
+// ANCHOR_END: book_training_simplified2
+
+}
```

````diff
diff --git a/candle-book/src/training/simplified.md b/candle-book/src/training/simplified.md
new file mode 100644
index 00000000..a64f2da4
--- /dev/null
+++ b/candle-book/src/training/simplified.md
@@ -0,0 +1,45 @@
+# Simplified
+
+## How it works
+
+This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.
+
+Key points:
+
+1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
+2. The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
+3. The output is the number 0 or 1, where 1 means that the first candidate will win the second round and 0 means that they will lose.
+4. For training, samples with real data on the results of the first and second rounds of different elections are used.
+5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
+6. The model parameters (neuron weights) are initialized randomly and then optimized during training.
+7. After training, the model is evaluated on a held-out sample to measure its accuracy.
+8. If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.
+
+Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
+
+
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified1}}
+```
+
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified2}}
+```
+
+```rust,ignore
+{{#include ../simplified.rs:book_training_simplified3}}
+```
+
+
+## Example output
+
+```bash
+Trying to train neural network.
+Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
+Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
+Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
+Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
+Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
+real_life_votes: [13, 22]
+neural_network_prediction_result: 0.0
+```
````
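The module documentation and `training/simplified.md` above describe the model as being trained with a cross-entropy loss. In `train()` that loss is built from two calls, `ops::log_softmax` followed by `loss::nll`, which together compute cross-entropy over the two output logits. The following is a minimal, self-contained sketch (not part of this commit) of the same pairing on toy values; the function name and the numbers are made up for illustration, and if I read candle-nn's API correctly, its combined `loss::cross_entropy` helper would do both steps at once.

```rust
use candle::{Device, Result, Tensor, D};
use candle_nn::{loss, ops};

// Toy illustration of the loss used in `train()` above:
// cross-entropy = log_softmax over the class logits, then negative log-likelihood.
fn toy_cross_entropy() -> Result<f32> {
    let dev = Device::Cpu;
    // Logits for two samples and two classes, as a forward pass might produce them.
    let logits = Tensor::new(&[[2.0f32, 0.5], [0.1f32, 1.5]], &dev)?;
    // The true class index (0 or 1) for each sample.
    let targets = Tensor::new(&[0u32, 1], &dev)?;
    let log_sm = ops::log_softmax(&logits, D::Minus1)?;
    let loss = loss::nll(&log_sm, &targets)?;
    loss.to_scalar::<f32>()
}
```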
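Points 2 and 3 of the list in `simplified.md` describe the model's interface: a `(1, VOTE_DIM)` f32 tensor of first-round percentages goes in, and the argmax over the two output logits is read back as the 0/1 prediction. The sketch below factors those two conversions into helpers; `votes_to_input` and `winner_from_logits` are hypothetical names introduced here, not functions from the commit.

```rust
use candle::{DType, Device, Result, Tensor, D};

// Shape one first-round result into the (1, 2) f32 tensor the perceptron expects.
fn votes_to_input(first: u32, second: u32, dev: &Device) -> Result<Tensor> {
    Tensor::from_vec(vec![first, second], (1, 2), dev)?.to_dtype(DType::F32)
}

// Reduce the (1, 2) output logits to the predicted class:
// 1 if the first candidate is expected to win the second round, 0 otherwise.
fn winner_from_logits(logits: &Tensor) -> Result<u32> {
    logits.argmax(D::Minus1)?.squeeze(0)?.to_scalar::<u32>()
}
```

With the model returned by `train()`, the prediction at the end of `simplified()` would then amount to `winner_from_logits(&model.forward(&votes_to_input(13, 22, &dev)?)?)?`.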