diff options
-rw-r--r-- | .github/workflows/ci_cuda.yaml | 2 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | candle-book/Cargo.toml | 49 | ||||
-rw-r--r-- | candle-book/src/SUMMARY.md | 10 | ||||
-rw-r--r-- | candle-book/src/inference/hub.md | 6 | ||||
-rw-r--r-- | candle-book/src/lib.rs | 193 | ||||
-rw-r--r-- | candle-book/src/training/README.md | 38 | ||||
-rw-r--r-- | candle-book/src/training/mnist.md | 9 | ||||
-rw-r--r-- | candle-datasets/Cargo.toml | 3 | ||||
-rw-r--r-- | candle-datasets/src/hub.rs | 73 | ||||
-rw-r--r-- | candle-datasets/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-datasets/src/vision/mnist.rs | 59 | ||||
-rw-r--r-- | candle-examples/examples/mnist-training/main.rs | 10 | ||||
-rw-r--r-- | candle-examples/src/lib.rs | 99 |
14 files changed, 444 insertions, 110 deletions
diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml index 7c6cfa9b..8953c444 100644 --- a/.github/workflows/ci_cuda.yaml +++ b/.github/workflows/ci_cuda.yaml @@ -59,7 +59,7 @@ jobs: - name: Install Rust Stable run: curl https://sh.rustup.rs -sSf | sh -s -- -y - uses: Swatinem/rust-cache@v2 - - run: apt update -y && apt install libssl-dev -y + - run: apt-get update -y && apt-get install libssl-dev -y - name: Test (cuda) run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda stop-runner: @@ -3,6 +3,7 @@ members = [ "candle-core", "candle-datasets", "candle-examples", + "candle-book", "candle-nn", "candle-pyo3", "candle-transformers", @@ -57,6 +58,7 @@ tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" wav = "1.0.0" zip = { version = "0.6.6", default-features = false } +parquet = { version = "45.0.0" } [profile.release-with-debug] inherits = "release" diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml new file mode 100644 index 00000000..6cd0a487 --- /dev/null +++ b/candle-book/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "candle-book" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +accelerate-src = { workspace = true, optional = true } +candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.2.0" } +candle-nn = { path = "../candle-nn", version = "0.2.0" } +candle-transformers = { path = "../candle-transformers", version = "0.2.0" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.0", optional = true } +safetensors = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +num-traits = { workspace = true } +intel-mkl-src = { workspace = true, optional = true } +cudarc = { workspace = true, optional = true } +half = { workspace = true, optional = true } +image = { workspace = true, optional = true } + +[dev-dependencies] +anyhow = { workspace = true } +byteorder = { workspace = true } +hf-hub = { workspace = true, features=["tokio"]} +clap = { workspace = true } +memmap2 = { workspace = true } +rand = { workspace = true } +tokenizers = { workspace = true, features = ["onig"] } +tracing = { workspace = true } +tracing-chrome = { workspace = true } +tracing-subscriber = { workspace = true } +wav = { workspace = true } +# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 +tokio = "1.29.1" +parquet = { workspace = true } +image = { workspace = true } + +[build-dependencies] +anyhow = { workspace = true } + +[features] +default = [] diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md index 3432f66f..8228da22 100644 --- a/candle-book/src/SUMMARY.md +++ b/candle-book/src/SUMMARY.md @@ -12,7 +12,11 @@ - [Running a model](inference/README.md) - [Using the hub](inference/hub.md) -- [Error management]() +- [Error management](error_manage.md) +- [Training](training/README.md) + - [MNIST](training/mnist.md) + - [Fine-tuning]() + - [Serialization]() - [Advanced Cuda usage]() - [Writing a custom kernel]() - [Porting a custom kernel]() @@ -21,7 +25,3 @@ - [Creating a WASM app]() - [Creating a REST api webserver]() - [Creating a desktop Tauri app]() -- [Training]() - - [MNIST]() - - [Fine-tuning]() - - [Serialization]() diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index 4bd69c14..e8d8b267 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -39,7 +39,7 @@ cargo add hf-hub --features tokio ```rust,ignore # This is tested directly in examples crate because it needs external dependencies unfortunately: # See [this](https://github.com/rust-lang/mdBook/issues/706) -{{#include ../../../candle-examples/src/lib.rs:book_hub_1}} +{{#include ../lib.rs:book_hub_1}} ``` @@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2 and will definitely be slower on network mounted disk, because it will issue more read calls. ```rust,ignore -{{#include ../../../candle-examples/src/lib.rs:book_hub_2}} +{{#include ../lib.rs:book_hub_2}} ``` **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety). @@ -100,5 +100,5 @@ cargo add safetensors ```rust,ignore -{{#include ../../../candle-examples/src/lib.rs:book_hub_3}} +{{#include ../lib.rs:book_hub_3}} ``` diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs new file mode 100644 index 00000000..ef9b853e --- /dev/null +++ b/candle-book/src/lib.rs @@ -0,0 +1,193 @@ +#[cfg(test)] +mod tests { + use anyhow::Result; + use candle::{DType, Device, Tensor}; + use parquet::file::reader::SerializedFileReader; + + // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 + #[rustfmt::skip] + #[tokio::test] + async fn book_hub_1() { +// ANCHOR: book_hub_1 +use candle::Device; +use hf_hub::api::tokio::Api; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); + +let weights_filename = repo.get("model.safetensors").await.unwrap(); + +let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_1 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_2() { +// ANCHOR: book_hub_2 +use candle::Device; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; +let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_2 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_3() { +// ANCHOR: book_hub_3 +use candle::{DType, Device, Tensor}; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use safetensors::slice::IndexOp; +use safetensors::SafeTensors; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; + +// Use safetensors directly +let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); +let view = tensors + .tensor("bert.encoder.layer.0.attention.self.query.weight") + .unwrap(); + +// We're going to load shard with rank 1, within a world_size of 4 +// We're going to split along dimension 0 doing VIEW[start..stop, :] +let rank = 1; +let world_size = 4; +let dim = 0; +let dtype = view.dtype(); +let mut tp_shape = view.shape().to_vec(); +let size = tp_shape[0]; + +if size % world_size != 0 { + panic!("The dimension is not divisble by `world_size`"); +} +let block_size = size / world_size; +let start = rank * block_size; +let stop = (rank + 1) * block_size; + +// Everything is expressed in tensor dimension +// bytes offsets is handled automatically for safetensors. + +let iterator = view.slice(start..stop).unwrap(); + +tp_shape[dim] = block_size; + +// Convert safetensors Dtype to candle DType +let dtype: DType = dtype.try_into().unwrap(); + +// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. +let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); +let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_3 + assert_eq!(view.shape(), &[768, 768]); + assert_eq!(tp_tensor.dims(), &[192, 768]); + } + + #[rustfmt::skip] + #[test] + fn book_training_1() -> Result<()>{ +// ANCHOR: book_training_1 +use hf_hub::{api::sync::Api, Repo, RepoType}; + +let dataset_id = "mnist".to_string(); + +let api = Api::new()?; +let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), +); +let repo = api.repo(repo); +let test_parquet_filename = repo.get("mnist/test/0000.parquet")?; +let train_parquet_filename = repo.get("mnist/train/0000.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; +// ANCHOR_END: book_training_1 +// Ignore unused +let _train = train_parquet; +// ANCHOR: book_training_2 +for row in test_parquet { + for (idx, (name, field)) in row?.get_column_iter().enumerate() { + println!("Column id {idx}, name {name}, value {field}"); + } +} +// ANCHOR_END: book_training_2 +let test_parquet_filename = repo.get("mnist/test/0000.parquet")?; +let train_parquet_filename = repo.get("mnist/train/0000.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; +// ANCHOR: book_training_3 + +let test_samples = 10_000; +let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784); +let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples); +for row in test_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + test_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + test_buffer_labels.push(*label as u8); + } + } +} +let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?; +let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?; + +let train_samples = 60_000; +let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784); +let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples); +for row in train_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + train_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + train_buffer_labels.push(*label as u8); + } + } +} +let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?; +let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?; + +let mnist = candle_datasets::vision::Dataset { + train_images, + train_labels, + test_images, + test_labels, + labels: 10, +}; + +// ANCHOR_END: book_training_3 +assert_eq!(mnist.test_images.dims(), &[10_000, 784]); +assert_eq!(mnist.test_labels.dims(), &[10_000]); +assert_eq!(mnist.train_images.dims(), &[60_000, 784]); +assert_eq!(mnist.train_labels.dims(), &[60_000]); +Ok(()) + } +} diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md index 8977de34..d68a917e 100644 --- a/candle-book/src/training/README.md +++ b/candle-book/src/training/README.md @@ -1 +1,39 @@ # Training + + +Training starts with data. We're going to use the huggingface hub and +start with the Hello world dataset of machine learning, MNIST. + +Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist). + +This requires [`hf-hub`](https://github.com/huggingface/hf-hub). +```bash +cargo add hf-hub +``` + +This is going to be very hands-on for now. + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_1}} +``` + +This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset. +Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`]. + +We can inspect the content of the files with: + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_2}} +``` + +You should see something like: + +```bash +Column id 1, name label, value 6 +Column id 0, name image, value {bytes: [137, ....] +Column id 1, name label, value 8 +Column id 0, name image, value {bytes: [137, ....] +``` + +So each row contains 2 columns (image, label) with image being saved as bytes. +Let's put them into a useful struct. diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md index 642960a4..1394921b 100644 --- a/candle-book/src/training/mnist.md +++ b/candle-book/src/training/mnist.md @@ -1 +1,10 @@ # MNIST + +So we now have downloaded the MNIST parquet files, let's put them in a simple struct. + +```rust,ignore +{{#include ../lib.rs:book_training_3}} +``` + +The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory. +It is quite rudimentary, but simple enough for a small dataset like MNIST. diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 91fced54..f4472a08 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -18,3 +18,6 @@ intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } rand = { workspace = true } +thiserror = { workspace = true } +parquet = { workspace = true} +image = { workspace = true } diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs new file mode 100644 index 00000000..b135e148 --- /dev/null +++ b/candle-datasets/src/hub.rs @@ -0,0 +1,73 @@ +use hf_hub::{ + api::sync::{Api, ApiRepo}, + Repo, RepoType, +}; +use parquet::file::reader::SerializedFileReader; +use std::fs::File; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("ApiError : {0}")] + ApiError(#[from] hf_hub::api::sync::ApiError), + + #[error("IoError : {0}")] + IoError(#[from] std::io::Error), + + #[error("ParquetError : {0}")] + ParquetError(#[from] parquet::errors::ParquetError), +} + +fn sibling_to_parquet( + rfilename: &str, + repo: &ApiRepo, +) -> Result<SerializedFileReader<File>, Error> { + let local = repo.get(rfilename)?; + let file = File::open(local)?; + let reader = SerializedFileReader::new(file)?; + Ok(reader) +} + +pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> { + let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let repo = api.repo(repo); + let info = repo.info()?; + + let files: Result<Vec<_>, _> = info + .siblings + .into_iter() + .filter_map(|s| -> Option<Result<_, _>> { + let filename = s.rfilename; + if filename.ends_with(".parquet") { + let reader_result = sibling_to_parquet(&filename, &repo); + Some(reader_result) + } else { + None + } + }) + .collect(); + let files = files?; + + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + use parquet::file::reader::FileReader; + + #[test] + fn test_dataset() { + let api = Api::new().unwrap(); + let files = from_hub( + &api, + "hf-internal-testing/dummy_image_text_data".to_string(), + ) + .unwrap(); + assert_eq!(files.len(), 1); + assert_eq!(files[0].metadata().file_metadata().num_rows(), 20); + } +} diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs index 42ad5d62..bfd77a99 100644 --- a/candle-datasets/src/lib.rs +++ b/candle-datasets/src/lib.rs @@ -1,5 +1,6 @@ //! Datasets & Dataloaders for Candle pub mod batcher; +pub mod hub; pub mod nlp; pub mod vision; diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs index 2267f9a0..1085edd6 100644 --- a/candle-datasets/src/vision/mnist.rs +++ b/candle-datasets/src/vision/mnist.rs @@ -2,7 +2,9 @@ //! //! The files can be obtained from the following link: //! <http://yann.lecun.com/exdb/mnist/> -use candle::{DType, Device, Result, Tensor}; +use candle::{DType, Device, Error, Result, Tensor}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use parquet::file::reader::{FileReader, SerializedFileReader}; use std::fs::File; use std::io::{self, BufReader, Read}; @@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data labels: 10, }) } + +fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> { + let samples = parquet.metadata().file_metadata().num_rows() as usize; + let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784); + let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples); + for row in parquet.into_iter().flatten() { + for (_name, field) in row.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + buffer_images.extend(image.to_luma8().as_raw()); + } + } + } else if let parquet::record::Field::Long(label) = field { + buffer_labels.push(*label as u8); + } + } + } + let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)? + .to_dtype(DType::F32)? + / 255.)?; + let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?; + Ok((images, labels)) +} + +pub fn load() -> Result<crate::vision::Dataset> { + let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let dataset_id = "mnist".to_string(); + let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let repo = api.repo(repo); + let test_parquet_filename = repo + .get("mnist/mnist-test.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let train_parquet_filename = repo + .get("mnist/mnist-train.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; + let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; + let (test_images, test_labels) = load_parquet(test_parquet)?; + let (train_images, train_labels) = load_parquet(train_parquet)?; + Ok(crate::vision::Dataset { + train_images, + train_labels, + test_images, + test_labels, + labels: 10, + }) +} diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index bcf8677d..ea4b4bce 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -138,12 +138,20 @@ struct Args { /// The file where to load the trained weights from, in safetensors format. #[arg(long)] load: Option<String>, + + /// The file where to load the trained weights from, in safetensors format. + #[arg(long)] + local_mnist: Option<String>, } pub fn main() -> anyhow::Result<()> { let args = Args::parse(); // Load the dataset - let m = candle_datasets::vision::mnist::load_dir("data")?; + let m = if let Some(directory) = args.local_mnist { + candle_datasets::vision::mnist::load_dir(directory)? + } else { + candle_datasets::vision::mnist::load()? + }; println!("train-images: {:?}", m.train_images.shape()); println!("train-labels: {:?}", m.train_labels.shape()); println!("test-images: {:?}", m.test_images.shape()); diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 8bf94eb7..395162eb 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { image.save(p).map_err(candle::Error::wrap)?; Ok(()) } - -#[cfg(test)] -mod tests { - // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 - #[rustfmt::skip] - #[tokio::test] - async fn book_hub_1() { -// ANCHOR: book_hub_1 -use candle::Device; -use hf_hub::api::tokio::Api; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); - -let weights_filename = repo.get("model.safetensors").await.unwrap(); - -let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_1 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_2() { -// ANCHOR: book_hub_2 -use candle::Device; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; -let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_2 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_3() { -// ANCHOR: book_hub_3 -use candle::{DType, Device, Tensor}; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use safetensors::slice::IndexOp; -use safetensors::SafeTensors; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; - -// Use safetensors directly -let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); -let view = tensors - .tensor("bert.encoder.layer.0.attention.self.query.weight") - .unwrap(); - -// We're going to load shard with rank 1, within a world_size of 4 -// We're going to split along dimension 0 doing VIEW[start..stop, :] -let rank = 1; -let world_size = 4; -let dim = 0; -let dtype = view.dtype(); -let mut tp_shape = view.shape().to_vec(); -let size = tp_shape[0]; - -if size % world_size != 0 { - panic!("The dimension is not divisble by `world_size`"); -} -let block_size = size / world_size; -let start = rank * block_size; -let stop = (rank + 1) * block_size; - -// Everything is expressed in tensor dimension -// bytes offsets is handled automatically for safetensors. - -let iterator = view.slice(start..stop).unwrap(); - -tp_shape[dim] = block_size; - -// Convert safetensors Dtype to candle DType -let dtype: DType = dtype.try_into().unwrap(); - -// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. -let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); -let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_3 - assert_eq!(view.shape(), &[768, 768]); - assert_eq!(tp_tensor.dims(), &[192, 768]); - } -} |