summaryrefslogtreecommitdiff
path: root/candle-examples/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/src/lib.rs')
-rw-r--r--candle-examples/src/lib.rs99
1 files changed, 0 insertions, 99 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 8bf94eb7..395162eb 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
-
-#[cfg(test)]
-mod tests {
- // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
- #[rustfmt::skip]
- #[tokio::test]
- async fn book_hub_1() {
-// ANCHOR: book_hub_1
-use candle::Device;
-use hf_hub::api::tokio::Api;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-
-let weights_filename = repo.get("model.safetensors").await.unwrap();
-
-let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_1
- assert_eq!(weights.len(), 206);
- }
-
- #[rustfmt::skip]
- #[test]
- fn book_hub_2() {
-// ANCHOR: book_hub_2
-use candle::Device;
-use hf_hub::api::sync::Api;
-use memmap2::Mmap;
-use std::fs;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-let weights_filename = repo.get("model.safetensors").unwrap();
-
-let file = fs::File::open(weights_filename).unwrap();
-let mmap = unsafe { Mmap::map(&file).unwrap() };
-let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_2
- assert_eq!(weights.len(), 206);
- }
-
- #[rustfmt::skip]
- #[test]
- fn book_hub_3() {
-// ANCHOR: book_hub_3
-use candle::{DType, Device, Tensor};
-use hf_hub::api::sync::Api;
-use memmap2::Mmap;
-use safetensors::slice::IndexOp;
-use safetensors::SafeTensors;
-use std::fs;
-
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-let weights_filename = repo.get("model.safetensors").unwrap();
-
-let file = fs::File::open(weights_filename).unwrap();
-let mmap = unsafe { Mmap::map(&file).unwrap() };
-
-// Use safetensors directly
-let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
-let view = tensors
- .tensor("bert.encoder.layer.0.attention.self.query.weight")
- .unwrap();
-
-// We're going to load shard with rank 1, within a world_size of 4
-// We're going to split along dimension 0 doing VIEW[start..stop, :]
-let rank = 1;
-let world_size = 4;
-let dim = 0;
-let dtype = view.dtype();
-let mut tp_shape = view.shape().to_vec();
-let size = tp_shape[0];
-
-if size % world_size != 0 {
- panic!("The dimension is not divisble by `world_size`");
-}
-let block_size = size / world_size;
-let start = rank * block_size;
-let stop = (rank + 1) * block_size;
-
-// Everything is expressed in tensor dimension
-// bytes offsets is handled automatically for safetensors.
-
-let iterator = view.slice(start..stop).unwrap();
-
-tp_shape[dim] = block_size;
-
-// Convert safetensors Dtype to candle DType
-let dtype: DType = dtype.try_into().unwrap();
-
-// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
-let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
-let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
-// ANCHOR_END: book_hub_3
- assert_eq!(view.shape(), &[768, 768]);
- assert_eq!(tp_tensor.dims(), &[192, 768]);
- }
-}