Diffstat (limited to 'candle-book/src')
 candle-book/src/SUMMARY.md         |  10 +-
 candle-book/src/inference/hub.md   |   6 +-
 candle-book/src/lib.rs             | 193 ++
 candle-book/src/training/README.md |  38 +
 candle-book/src/training/mnist.md  |   9 +
 5 files changed, 248 insertions(+), 8 deletions(-)
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 3432f66f..8228da22 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -12,7 +12,11 @@
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
-- [Error management]()
+- [Error management](error_manage.md)
+- [Training](training/README.md)
+ - [MNIST](training/mnist.md)
+ - [Fine-tuning]()
+ - [Serialization]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
@@ -21,7 +25,3 @@
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
-- [Training]()
- - [MNIST]()
- - [Fine-tuning]()
- - [Serialization]()
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 4bd69c14..e8d8b267 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
```rust,ignore
# This is tested directly in the examples crate because it unfortunately needs external dependencies:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
-{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+{{#include ../lib.rs:book_hub_1}}
```
@@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on a network-mounted disk, because it will issue more read calls.
```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+{{#include ../lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@@ -100,5 +100,5 @@ cargo add safetensors
```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+{{#include ../lib.rs:book_hub_3}}
```
diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs
new file mode 100644
index 00000000..ef9b853e
--- /dev/null
+++ b/candle-book/src/lib.rs
@@ -0,0 +1,193 @@
+#[cfg(test)]
+mod tests {
+ use anyhow::Result;
+ use candle::{DType, Device, Tensor};
+ use parquet::file::reader::SerializedFileReader;
+
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+ .tensor("bert.encoder.layer.0.attention.self.query.weight")
+ .unwrap();
+
+// We're going to load the shard with rank 1 within a world_size of 4.
+// We're going to split along dimension 0, doing VIEW[start..stop, :].
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+ panic!("The dimension is not divisible by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
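+// As a concrete example: for the [768, 768] weight sliced below with
+// world_size = 4, block_size = 192, so rank 1 keeps rows 192..384.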
+
+// Everything is expressed in tensor dimensions;
+// byte offsets are handled automatically by safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_training_1() -> Result<()>{
+// ANCHOR: book_training_1
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+let dataset_id = "mnist".to_string();
+
+let api = Api::new()?;
+let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+);
+let repo = api.repo(repo);
+let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
+let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR_END: book_training_1
+// Ignore unused
+let _train = train_parquet;
+// ANCHOR: book_training_2
+for row in test_parquet {
+ for (idx, (name, field)) in row?.get_column_iter().enumerate() {
+ println!("Column id {idx}, name {name}, value {field}");
+ }
+}
+// ANCHOR_END: book_training_2
+let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
+let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR: book_training_3
+
+let test_samples = 10_000;
+let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
+for row in test_parquet {
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ test_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ test_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
+
+let train_samples = 60_000;
+let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+for row in train_parquet {
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ train_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ train_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
+
+let mnist = candle_datasets::vision::Dataset {
+ train_images,
+ train_labels,
+ test_images,
+ test_labels,
+ labels: 10,
+};
+
+// ANCHOR_END: book_training_3
+assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+assert_eq!(mnist.test_labels.dims(), &[10_000]);
+assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+assert_eq!(mnist.train_labels.dims(), &[60_000]);
+Ok(())
+ }
+}
diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md
index 8977de34..d68a917e 100644
--- a/candle-book/src/training/README.md
+++ b/candle-book/src/training/README.md
@@ -1 +1,39 @@
# Training
+
+Training starts with data. We're going to use the Hugging Face hub and
+begin with the "Hello world" dataset of machine learning: MNIST.
+
+Let's download `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
+
+This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
+```bash
+cargo add hf-hub
+```
+
+This is going to be very hands-on for now.
+
+```rust,ignore
+{{#include ../lib.rs:book_training_1}}
+```
+
+This uses the standardized `parquet` files from the `refs/convert/parquet` branch, which exists on every dataset.
+Our handles are now values of type [`parquet::file::serialized_reader::SerializedFileReader`].
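+
+As a quick sanity check (a minimal sketch, not one of the book's tested snippets), the
+`FileReader` trait implemented by `SerializedFileReader` exposes the parquet metadata,
+so we can ask a split how many rows it contains:
+
+```rust,ignore
+use parquet::file::reader::FileReader;
+
+// The row count comes from the parquet footer; MNIST's test split has 10,000 rows.
+let rows = test_parquet.metadata().file_metadata().num_rows();
+println!("test split contains {rows} rows");
+```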
+
+We can inspect the content of the files with:
+
+```rust,ignore
+{{#include ../lib.rs:book_training_2}}
+```
+
+You should see something like:
+
+```bash
+Column id 1, name label, value 6
+Column id 0, name image, value {bytes: [137, ....]
+Column id 1, name label, value 8
+Column id 0, name image, value {bytes: [137, ....]
+```
+
+So each row contains 2 columns (image, label), with the image stored as encoded PNG bytes (note the leading `137`, the first byte of the PNG signature).
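+
+Here is a minimal sketch of how such bytes decode with the [`image`](https://crates.io/crates/image) crate
+(`bytes` is a hypothetical variable standing in for the content of one `image` cell). Each MNIST
+image is a 28x28 grayscale picture, i.e. 784 `u8` pixels once flattened:
+
+```rust,ignore
+// Decode the encoded image bytes into raw grayscale pixels.
+let img = image::load_from_memory(&bytes).unwrap();
+let pixels: Vec<u8> = img.to_luma8().into_raw();
+assert_eq!(pixels.len(), 28 * 28);
+```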
+Let's put them into a useful struct.
diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md
index 642960a4..1394921b 100644
--- a/candle-book/src/training/mnist.md
+++ b/candle-book/src/training/mnist.md
@@ -1 +1,10 @@
# MNIST
+
+Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
+
+```rust,ignore
+{{#include ../lib.rs:book_training_3}}
+```
+
+Parsing the files and loading them into single tensors requires the entire dataset to fit in memory.
+This approach is quite rudimentary, but good enough for a small dataset like MNIST.
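+
+Once the `Dataset` struct is built, batching is just slicing along dimension 0. Here is a
+minimal sketch using `Tensor::narrow` (the `batch_size` value and the loop body are
+placeholders, not the book's actual training loop):
+
+```rust,ignore
+let batch_size = 64;
+let n_batches = mnist.train_images.dim(0)? / batch_size;
+for i in 0..n_batches {
+    // Rows [i * batch_size, (i + 1) * batch_size) of the images and labels.
+    let images = mnist.train_images.narrow(0, i * batch_size, batch_size)?;
+    let labels = mnist.train_labels.narrow(0, i * batch_size, batch_size)?;
+    // A forward pass, loss computation, and backprop would go here.
+}
+```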