summaryrefslogtreecommitdiff
path: root/candle-examples/src
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-08-01 16:36:53 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-08-02 18:40:24 +0200
commita44471a305f2bc768c4f0dd0e7d23a7cfe3cb408 (patch)
treef2f51f7e58f0fd7bfb03bc67e4b7bac99278d340 /candle-examples/src
parent45642a8530fdfbd64fcac118aed59b7cb7dfaf45 (diff)
downloadcandle-a44471a305f2bc768c4f0dd0e7d23a7cfe3cb408.tar.gz
candle-a44471a305f2bc768c4f0dd0e7d23a7cfe3cb408.tar.bz2
candle-a44471a305f2bc768c4f0dd0e7d23a7cfe3cb408.zip
Adding more details on how to load things.
- Loading with memmap
- Loading a sharded tensor
- Moved some snippets to `candle-examples/src/lib.rs`. This is because managing book-specific dependencies is a pain: https://github.com/rust-lang/mdBook/issues/706
- This causes a non-aligned inclusion (https://github.com/rust-lang/mdBook/pull/1856) which we have to ignore fmt to remove.
mdbook might need some more love :)
Diffstat (limited to 'candle-examples/src')
-rw-r--r--candle-examples/src/lib.rs99
1 files changed, 99 insertions, 0 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 285aee04..3410026e 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}
+
+#[cfg(test)]
+mod tests {
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+.tensor("bert.encoder.layer.0.attention.self.query.weight")
+.unwrap();
+
+// We're going to load shard with rank 1, within a world_size of 4
+// We're going to split along dimension 0 doing VIEW[start..stop, :]
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+panic!("The dimension is not divisble by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
+
+// Everything is expressed in tensor dimension
+// bytes offsets is handled automatically for safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+}