-rw-r--r--  .github/workflows/ci_cuda.yaml                     2
-rw-r--r--  Cargo.toml                                         2
-rw-r--r--  candle-book/Cargo.toml                            49
-rw-r--r--  candle-book/src/SUMMARY.md                        10
-rw-r--r--  candle-book/src/inference/hub.md                   6
-rw-r--r--  candle-book/src/lib.rs                           193
-rw-r--r--  candle-book/src/training/README.md                38
-rw-r--r--  candle-book/src/training/mnist.md                  9
-rw-r--r--  candle-datasets/Cargo.toml                         3
-rw-r--r--  candle-datasets/src/hub.rs                        73
-rw-r--r--  candle-datasets/src/lib.rs                         1
-rw-r--r--  candle-datasets/src/vision/mnist.rs               59
-rw-r--r--  candle-examples/examples/mnist-training/main.rs   10
-rw-r--r--  candle-examples/src/lib.rs                        99
14 files changed, 444 insertions(+), 110 deletions(-)
diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml
index 7c6cfa9b..8953c444 100644
--- a/.github/workflows/ci_cuda.yaml
+++ b/.github/workflows/ci_cuda.yaml
@@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- - run: apt update -y && apt install libssl-dev -y
+ - run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
diff --git a/Cargo.toml b/Cargo.toml
index 0a54ea45..f60bde8c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,7 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
+ "candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
@@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
zip = { version = "0.6.6", default-features = false }
+parquet = { version = "45.0.0" }
[profile.release-with-debug]
inherits = "release"
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
new file mode 100644
index 00000000..6cd0a487
--- /dev/null
+++ b/candle-book/Cargo.toml
@@ -0,0 +1,49 @@
+[package]
+name = "candle-book"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+readme = "README.md"
+
+[dependencies]
+accelerate-src = { workspace = true, optional = true }
+candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.0" }
+candle-nn = { path = "../candle-nn", version = "0.2.0" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.0" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.0", optional = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+num-traits = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+cudarc = { workspace = true, optional = true }
+half = { workspace = true, optional = true }
+image = { workspace = true, optional = true }
+
+[dev-dependencies]
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+hf-hub = { workspace = true, features=["tokio"]}
+clap = { workspace = true }
+memmap2 = { workspace = true }
+rand = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
+tracing = { workspace = true }
+tracing-chrome = { workspace = true }
+tracing-subscriber = { workspace = true }
+wav = { workspace = true }
+# Necessary to disambiguate from the tokio version used in the wasm examples, which is 1.28.1
+tokio = "1.29.1"
+parquet = { workspace = true }
+image = { workspace = true }
+
+[build-dependencies]
+anyhow = { workspace = true }
+
+[features]
+default = []
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 3432f66f..8228da22 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -12,7 +12,11 @@
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
-- [Error management]()
+- [Error management](error_manage.md)
+- [Training](training/README.md)
+ - [MNIST](training/mnist.md)
+ - [Fine-tuning]()
+ - [Serialization]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
@@ -21,7 +25,3 @@
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
-- [Training]()
- - [MNIST]()
- - [Fine-tuning]()
- - [Serialization]()
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 4bd69c14..e8d8b267 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
-{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+{{#include ../lib.rs:book_hub_1}}
```
@@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on network mounted disk, because it will issue more read calls.
```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+{{#include ../lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@@ -100,5 +100,5 @@ cargo add safetensors
```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+{{#include ../lib.rs:book_hub_3}}
```
diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs
new file mode 100644
index 00000000..ef9b853e
--- /dev/null
+++ b/candle-book/src/lib.rs
@@ -0,0 +1,193 @@
+#[cfg(test)]
+mod tests {
+ use anyhow::Result;
+ use candle::{DType, Device, Tensor};
+ use parquet::file::reader::SerializedFileReader;
+
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+ .tensor("bert.encoder.layer.0.attention.self.query.weight")
+ .unwrap();
+
+// We're going to load the shard for rank 1 within a world_size of 4.
+// We're going to split along dimension 0, doing VIEW[start..stop, :]
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+ panic!("The dimension is not divisble by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
+
+// Everything is expressed in tensor dimensions;
+// byte offsets are handled automatically by safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_training_1() -> Result<()> {
+// ANCHOR: book_training_1
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+let dataset_id = "mnist".to_string();
+
+let api = Api::new()?;
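+// The `refs/convert/parquet` revision holds the dataset auto-converted to parquet files.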
+let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+);
+let repo = api.repo(repo);
+let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
+let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR_END: book_training_1
+// Ignore unused
+let _train = train_parquet;
+// ANCHOR: book_training_2
+for row in test_parquet {
+ for (idx, (name, field)) in row?.get_column_iter().enumerate() {
+ println!("Column id {idx}, name {name}, value {field}");
+ }
+}
+// ANCHOR_END: book_training_2
+let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
+let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR: book_training_3
+
+let test_samples = 10_000;
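+// Each MNIST image is 28x28 = 784 grayscale pixels, stored flattened.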
+let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
+for row in test_parquet {
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ test_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ test_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
+
+let train_samples = 60_000;
+let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+for row in train_parquet {
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ train_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ train_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
+
+let mnist = candle_datasets::vision::Dataset {
+ train_images,
+ train_labels,
+ test_images,
+ test_labels,
+ labels: 10,
+};
+
+// ANCHOR_END: book_training_3
+assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+assert_eq!(mnist.test_labels.dims(), &[10_000]);
+assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+assert_eq!(mnist.train_labels.dims(), &[60_000]);
+Ok(())
+ }
+}
diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md
index 8977de34..d68a917e 100644
--- a/candle-book/src/training/README.md
+++ b/candle-book/src/training/README.md
@@ -1 +1,39 @@
# Training
+
+
+Training starts with data. We're going to use the Hugging Face hub and
+start with the "Hello world" dataset of machine learning, MNIST.
+
+Let's start by downloading `MNIST` from [Hugging Face](https://huggingface.co/datasets/mnist).
+
+This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
+```bash
+cargo add hf-hub
+```
+
+This is going to be very hands-on for now.
+
+```rust,ignore
+{{#include ../lib.rs:book_training_1}}
+```
+
+This uses the standardized `parquet` files available on the `refs/convert/parquet` branch of every dataset.
+Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`] instances.
+
+We can inspect the content of the files with:
+
+```rust,ignore
+{{#include ../lib.rs:book_training_2}}
+```
+
+You should see something like:
+
+```bash
+Column id 1, name label, value 6
+Column id 0, name image, value {bytes: [137, ....]
+Column id 1, name label, value 8
+Column id 0, name image, value {bytes: [137, ....]
+```
+
+So each row contains two columns (image, label), with the image stored as bytes.
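+
+As a minimal sketch (using a hypothetical `decode_mnist_image` helper, and
+assuming the `image` and `anyhow` crates as used elsewhere in this book),
+decoding one such byte field into raw grayscale pixels could look like:
+
+```rust,ignore
+// Hypothetical helper: decode one parquet `bytes` field (a PNG image)
+// into a flat vector of 784 grayscale pixel values.
+fn decode_mnist_image(bytes: &[u8]) -> anyhow::Result<Vec<u8>> {
+    let image = image::load_from_memory(bytes)?;
+    // MNIST images are 28x28, so this yields 784 values in [0, 255].
+    Ok(image.to_luma8().into_raw())
+}
+```
+
+Let's put them into a useful struct.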
diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md
index 642960a4..1394921b 100644
--- a/candle-book/src/training/mnist.md
+++ b/candle-book/src/training/mnist.md
@@ -1 +1,10 @@
# MNIST
+
+Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
+
+```rust,ignore
+{{#include ../lib.rs:book_training_3}}
+```
+
+Parsing the files and packing them into single tensors requires the whole dataset to fit in memory.
+This is quite rudimentary, but simple enough for a small dataset like MNIST.
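+
+As a usage sketch, with the `candle_datasets::vision::mnist::load` helper added
+in this commit, loading the dataset then becomes a one-liner:
+
+```rust,ignore
+// Downloads the MNIST parquet files from the hub and builds the dataset.
+let mnist = candle_datasets::vision::mnist::load()?;
+println!("train-images: {:?}", mnist.train_images.shape()); // (60000, 784)
+println!("test-images: {:?}", mnist.test_images.shape()); // (10000, 784)
+```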
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index 91fced54..f4472a08 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -18,3 +18,6 @@ intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
+thiserror = { workspace = true }
+parquet = { workspace = true}
+image = { workspace = true }
diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs
new file mode 100644
index 00000000..b135e148
--- /dev/null
+++ b/candle-datasets/src/hub.rs
@@ -0,0 +1,73 @@
+use hf_hub::{
+ api::sync::{Api, ApiRepo},
+ Repo, RepoType,
+};
+use parquet::file::reader::SerializedFileReader;
+use std::fs::File;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+ #[error("ApiError : {0}")]
+ ApiError(#[from] hf_hub::api::sync::ApiError),
+
+ #[error("IoError : {0}")]
+ IoError(#[from] std::io::Error),
+
+ #[error("ParquetError : {0}")]
+ ParquetError(#[from] parquet::errors::ParquetError),
+}
+
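+/// Downloads a single repo file and opens it as a parquet reader.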
+fn sibling_to_parquet(
+ rfilename: &str,
+ repo: &ApiRepo,
+) -> Result<SerializedFileReader<File>, Error> {
+ let local = repo.get(rfilename)?;
+ let file = File::open(local)?;
+ let reader = SerializedFileReader::new(file)?;
+ Ok(reader)
+}
+
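+/// Fetches all `.parquet` files of a hub dataset, using its
+/// `refs/convert/parquet` revision, and opens each one as a reader.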
+pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
+ let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+ );
+ let repo = api.repo(repo);
+ let info = repo.info()?;
+
+ let files: Result<Vec<_>, _> = info
+ .siblings
+ .into_iter()
+ .filter_map(|s| -> Option<Result<_, _>> {
+ let filename = s.rfilename;
+ if filename.ends_with(".parquet") {
+ let reader_result = sibling_to_parquet(&filename, &repo);
+ Some(reader_result)
+ } else {
+ None
+ }
+ })
+ .collect();
+ let files = files?;
+
+ Ok(files)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use parquet::file::reader::FileReader;
+
+ #[test]
+ fn test_dataset() {
+ let api = Api::new().unwrap();
+ let files = from_hub(
+ &api,
+ "hf-internal-testing/dummy_image_text_data".to_string(),
+ )
+ .unwrap();
+ assert_eq!(files.len(), 1);
+ assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
+ }
+}
diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs
index 42ad5d62..bfd77a99 100644
--- a/candle-datasets/src/lib.rs
+++ b/candle-datasets/src/lib.rs
@@ -1,5 +1,6 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
+pub mod hub;
pub mod nlp;
pub mod vision;
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 2267f9a0..1085edd6 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -2,7 +2,9 @@
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
@@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
labels: 10,
})
}
+
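+/// Reads one MNIST parquet file into an `(images, labels)` pair of tensors,
+/// with images of shape (rows, 784) as f32 values scaled to [0, 1].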
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+ let samples = parquet.metadata().file_metadata().num_rows() as usize;
+ let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+ let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+ for row in parquet.into_iter().flatten() {
+ for (_name, field) in row.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ buffer_labels.push(*label as u8);
+ }
+ }
+ }
+ let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+ .to_dtype(DType::F32)?
+ / 255.)?;
+ let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+ Ok((images, labels))
+}
+
+pub fn load() -> Result<crate::vision::Dataset> {
+ let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let dataset_id = "mnist".to_string();
+ let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+ );
+ let repo = api.repo(repo);
+ let test_parquet_filename = repo
+ .get("mnist/mnist-test.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let train_parquet_filename = repo
+ .get("mnist/mnist-train.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+ let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+ let (test_images, test_labels) = load_parquet(test_parquet)?;
+ let (train_images, train_labels) = load_parquet(train_parquet)?;
+ Ok(crate::vision::Dataset {
+ train_images,
+ train_labels,
+ test_images,
+ test_labels,
+ labels: 10,
+ })
+}
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index bcf8677d..ea4b4bce 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -138,12 +138,20 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option<String>,
+
+ /// The directory where to load the MNIST dataset from, instead of downloading it from the hub.
+ #[arg(long)]
+ local_mnist: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_datasets::vision::mnist::load_dir("data")?;
+ let m = if let Some(directory) = args.local_mnist {
+ candle_datasets::vision::mnist::load_dir(directory)?
+ } else {
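+ // Default: download MNIST from the Hugging Face hub (parquet files).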
+ candle_datasets::vision::mnist::load()?
+ };
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 8bf94eb7..395162eb 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
-
-#[cfg(test)]
-mod tests {
- // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
- #[rustfmt::skip]
- #[tokio::test]
- async fn book_hub_1() {
-// ANCHOR: book_hub_1
-use candle::Device;
-use hf_hub::api::tokio::Api;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-
-let weights_filename = repo.get("model.safetensors").await.unwrap();
-
-let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_1
- assert_eq!(weights.len(), 206);
- }
-
- #[rustfmt::skip]
- #[test]
- fn book_hub_2() {
-// ANCHOR: book_hub_2
-use candle::Device;
-use hf_hub::api::sync::Api;
-use memmap2::Mmap;
-use std::fs;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-let weights_filename = repo.get("model.safetensors").unwrap();
-
-let file = fs::File::open(weights_filename).unwrap();
-let mmap = unsafe { Mmap::map(&file).unwrap() };
-let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_2
- assert_eq!(weights.len(), 206);
- }
-
- #[rustfmt::skip]
- #[test]
- fn book_hub_3() {
-// ANCHOR: book_hub_3
-use candle::{DType, Device, Tensor};
-use hf_hub::api::sync::Api;
-use memmap2::Mmap;
-use safetensors::slice::IndexOp;
-use safetensors::SafeTensors;
-use std::fs;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-let weights_filename = repo.get("model.safetensors").unwrap();
-
-let file = fs::File::open(weights_filename).unwrap();
-let mmap = unsafe { Mmap::map(&file).unwrap() };
-
-// Use safetensors directly
-let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
-let view = tensors
- .tensor("bert.encoder.layer.0.attention.self.query.weight")
- .unwrap();
-
-// We're going to load shard with rank 1, within a world_size of 4
-// We're going to split along dimension 0 doing VIEW[start..stop, :]
-let rank = 1;
-let world_size = 4;
-let dim = 0;
-let dtype = view.dtype();
-let mut tp_shape = view.shape().to_vec();
-let size = tp_shape[0];
-
-if size % world_size != 0 {
- panic!("The dimension is not divisble by `world_size`");
-}
-let block_size = size / world_size;
-let start = rank * block_size;
-let stop = (rank + 1) * block_size;
-
-// Everything is expressed in tensor dimension
-// bytes offsets is handled automatically for safetensors.
-
-let iterator = view.slice(start..stop).unwrap();
-
-tp_shape[dim] = block_size;
-
-// Convert safetensors Dtype to candle DType
-let dtype: DType = dtype.try_into().unwrap();
-
-// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
-let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
-let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_3
- assert_eq!(view.shape(), &[768, 768]);
- assert_eq!(tp_tensor.dims(), &[192, 768]);
- }
-}