diff options
Diffstat (limited to 'candle-book/src')
-rw-r--r-- | candle-book/src/SUMMARY.md | 10 | ||||
-rw-r--r-- | candle-book/src/inference/hub.md | 6 | ||||
-rw-r--r-- | candle-book/src/lib.rs | 193 | ||||
-rw-r--r-- | candle-book/src/training/README.md | 38 | ||||
-rw-r--r-- | candle-book/src/training/mnist.md | 9 |
5 files changed, 248 insertions, 8 deletions
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md index 3432f66f..8228da22 100644 --- a/candle-book/src/SUMMARY.md +++ b/candle-book/src/SUMMARY.md @@ -12,7 +12,11 @@ - [Running a model](inference/README.md) - [Using the hub](inference/hub.md) -- [Error management]() +- [Error management](error_manage.md) +- [Training](training/README.md) + - [MNIST](training/mnist.md) + - [Fine-tuning]() + - [Serialization]() - [Advanced Cuda usage]() - [Writing a custom kernel]() - [Porting a custom kernel]() @@ -21,7 +25,3 @@ - [Creating a WASM app]() - [Creating a REST api webserver]() - [Creating a desktop Tauri app]() -- [Training]() - - [MNIST]() - - [Fine-tuning]() - - [Serialization]() diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index 4bd69c14..e8d8b267 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -39,7 +39,7 @@ cargo add hf-hub --features tokio ```rust,ignore # This is tested directly in examples crate because it needs external dependencies unfortunately: # See [this](https://github.com/rust-lang/mdBook/issues/706) -{{#include ../../../candle-examples/src/lib.rs:book_hub_1}} +{{#include ../lib.rs:book_hub_1}} ``` @@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2 and will definitely be slower on network mounted disk, because it will issue more read calls. ```rust,ignore -{{#include ../../../candle-examples/src/lib.rs:book_hub_2}} +{{#include ../lib.rs:book_hub_2}} ``` **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety). 
@@ -100,5 +100,5 @@ cargo add safetensors ```rust,ignore -{{#include ../../../candle-examples/src/lib.rs:book_hub_3}} +{{#include ../lib.rs:book_hub_3}} ``` diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs new file mode 100644 index 00000000..ef9b853e --- /dev/null +++ b/candle-book/src/lib.rs @@ -0,0 +1,193 @@ +#[cfg(test)] +mod tests { + use anyhow::Result; + use candle::{DType, Device, Tensor}; + use parquet::file::reader::SerializedFileReader; + + // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 + #[rustfmt::skip] + #[tokio::test] + async fn book_hub_1() { +// ANCHOR: book_hub_1 +use candle::Device; +use hf_hub::api::tokio::Api; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); + +let weights_filename = repo.get("model.safetensors").await.unwrap(); + +let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_1 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_2() { +// ANCHOR: book_hub_2 +use candle::Device; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; +let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_2 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_3() { +// ANCHOR: book_hub_3 +use candle::{DType, Device, Tensor}; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use safetensors::slice::IndexOp; +use safetensors::SafeTensors; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); 
+let mmap = unsafe { Mmap::map(&file).unwrap() }; + +// Use safetensors directly +let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); +let view = tensors + .tensor("bert.encoder.layer.0.attention.self.query.weight") + .unwrap(); + +// We're going to load shard with rank 1, within a world_size of 4 +// We're going to split along dimension 0 doing VIEW[start..stop, :] +let rank = 1; +let world_size = 4; +let dim = 0; +let dtype = view.dtype(); +let mut tp_shape = view.shape().to_vec(); +let size = tp_shape[0]; + +if size % world_size != 0 { + panic!("The dimension is not divisble by `world_size`"); +} +let block_size = size / world_size; +let start = rank * block_size; +let stop = (rank + 1) * block_size; + +// Everything is expressed in tensor dimension +// bytes offsets is handled automatically for safetensors. + +let iterator = view.slice(start..stop).unwrap(); + +tp_shape[dim] = block_size; + +// Convert safetensors Dtype to candle DType +let dtype: DType = dtype.try_into().unwrap(); + +// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. 
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); +let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_3 + assert_eq!(view.shape(), &[768, 768]); + assert_eq!(tp_tensor.dims(), &[192, 768]); + } + + #[rustfmt::skip] + #[test] + fn book_training_1() -> Result<()>{ +// ANCHOR: book_training_1 +use hf_hub::{api::sync::Api, Repo, RepoType}; + +let dataset_id = "mnist".to_string(); + +let api = Api::new()?; +let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), +); +let repo = api.repo(repo); +let test_parquet_filename = repo.get("mnist/test/0000.parquet")?; +let train_parquet_filename = repo.get("mnist/train/0000.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; +// ANCHOR_END: book_training_1 +// Ignore unused +let _train = train_parquet; +// ANCHOR: book_training_2 +for row in test_parquet { + for (idx, (name, field)) in row?.get_column_iter().enumerate() { + println!("Column id {idx}, name {name}, value {field}"); + } +} +// ANCHOR_END: book_training_2 +let test_parquet_filename = repo.get("mnist/test/0000.parquet")?; +let train_parquet_filename = repo.get("mnist/train/0000.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; +// ANCHOR: book_training_3 + +let test_samples = 10_000; +let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784); +let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples); +for row in test_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let 
parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + test_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + test_buffer_labels.push(*label as u8); + } + } +} +let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?; +let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?; + +let train_samples = 60_000; +let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784); +let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples); +for row in train_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + train_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + train_buffer_labels.push(*label as u8); + } + } +} +let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? 
/ 255.)?; +let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?; + +let mnist = candle_datasets::vision::Dataset { + train_images, + train_labels, + test_images, + test_labels, + labels: 10, +}; + +// ANCHOR_END: book_training_3 +assert_eq!(mnist.test_images.dims(), &[10_000, 784]); +assert_eq!(mnist.test_labels.dims(), &[10_000]); +assert_eq!(mnist.train_images.dims(), &[60_000, 784]); +assert_eq!(mnist.train_labels.dims(), &[60_000]); +Ok(()) + } +} diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md index 8977de34..d68a917e 100644 --- a/candle-book/src/training/README.md +++ b/candle-book/src/training/README.md @@ -1 +1,39 @@ # Training + + +Training starts with data. We're going to use the huggingface hub and +start with the Hello world dataset of machine learning, MNIST. + +Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist). + +This requires [`hf-hub`](https://github.com/huggingface/hf-hub). +```bash +cargo add hf-hub +``` + +This is going to be very hands-on for now. + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_1}} +``` + +This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset. +Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`]. + +We can inspect the content of the files with: + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_2}} +``` + +You should see something like: + +```bash +Column id 1, name label, value 6 +Column id 0, name image, value {bytes: [137, ....] +Column id 1, name label, value 8 +Column id 0, name image, value {bytes: [137, ....] +``` + +So each row contains 2 columns (image, label) with image being saved as bytes. +Let's put them into a useful struct. 
diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md index 642960a4..1394921b 100644 --- a/candle-book/src/training/mnist.md +++ b/candle-book/src/training/mnist.md @@ -1 +1,10 @@ # MNIST + +Now that we have downloaded the MNIST parquet files, let's put them in a simple struct. + +```rust,ignore +{{#include ../lib.rs:book_training_3}} +``` + +The parsing of the file and putting it into single tensors requires the dataset to fit entirely in memory. +It is quite rudimentary, but simple enough for a small dataset like MNIST. |