summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-book/src/inference/hub.md46
-rw-r--r--candle-core/src/safetensors.rs6
-rw-r--r--candle-examples/Cargo.toml4
-rw-r--r--candle-examples/src/lib.rs99
4 files changed, 143 insertions, 12 deletions
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index de514322..01492df1 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -25,6 +25,8 @@ let weights = candle::safetensors::load(weights, &Device::Cpu);
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
+You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true)
+
## Using async
@@ -35,17 +37,9 @@ cargo add hf-hub --features tokio
```
```rust,ignore
-# extern crate candle;
-# extern crate hf_hub;
-use hf_hub::api::tokio::Api;
-use candle::Device;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-
-let weights = repo.get("model.safetensors").await.unwrap();
-
-let weights = candle::safetensors::load(weights, &Device::Cpu);
+# This is tested directly in examples crate because it needs external dependencies unfortunately:
+# See [this](https://github.com/rust-lang/mdBook/issues/706)
+{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```
@@ -78,3 +72,33 @@ let output = linear.forward(&input_ids);
```
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
+
+## Memory mapping
+
+For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/)
+
+**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
+and will definitely be slower on network mounted disk, because it will issue more read calls.
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+```
+
+**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
+In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.
+
+
+## Tensor Parallel Sharding
+
+When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need.
+
+For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.
+
+```bash
+cargo add safetensors
+```
+
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+```
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 1880a041..132fb914 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
- let st = safetensors::SafeTensors::deserialize(&data)?;
+ load_buffer(&data[..], device)
+}
+
+pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
+ let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 0db960ca..d4544ef7 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -25,6 +25,7 @@ half = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
+hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
@@ -34,6 +35,9 @@ tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
+# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
+tokio = "1.29.1"
+memmap2.workspace = true
[build-dependencies]
anyhow = { workspace = true }
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 285aee04..3410026e 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}
+
+#[cfg(test)]
+mod tests {
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+.tensor("bert.encoder.layer.0.attention.self.query.weight")
+.unwrap();
+
+// We're going to load shard with rank 1, within a world_size of 4
+// We're going to split along dimension 0 doing VIEW[start..stop, :]
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+panic!("The dimension is not divisble by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
+
+// Everything is expressed in tensor dimension
+// bytes offsets is handled automatically for safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator to we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+}