diff options
Diffstat (limited to 'candle-examples/src/lib.rs')
-rw-r--r-- | candle-examples/src/lib.rs | 99 |
1 files changed, 0 insertions, 99 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 8bf94eb7..395162eb 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { image.save(p).map_err(candle::Error::wrap)?; Ok(()) } - -#[cfg(test)] -mod tests { - // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 - #[rustfmt::skip] - #[tokio::test] - async fn book_hub_1() { -// ANCHOR: book_hub_1 -use candle::Device; -use hf_hub::api::tokio::Api; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); - -let weights_filename = repo.get("model.safetensors").await.unwrap(); - -let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_1 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_2() { -// ANCHOR: book_hub_2 -use candle::Device; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; -let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_2 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_3() { -// ANCHOR: book_hub_3 -use candle::{DType, Device, Tensor}; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use safetensors::slice::IndexOp; -use safetensors::SafeTensors; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; - -// Use safetensors directly -let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); -let view = tensors - .tensor("bert.encoder.layer.0.attention.self.query.weight") - .unwrap(); - -// We're going to load shard with rank 1, within a world_size of 4 -// We're going to split along dimension 0 doing VIEW[start..stop, :] -let rank = 1; -let world_size = 4; -let dim = 0; -let dtype = view.dtype(); -let mut tp_shape = view.shape().to_vec(); -let size = tp_shape[0]; - -if size % world_size != 0 { - panic!("The dimension is not divisble by `world_size`"); -} -let block_size = size / world_size; -let start = rank * block_size; -let stop = (rank + 1) * block_size; - -// Everything is expressed in tensor dimension -// bytes offsets is handled automatically for safetensors. - -let iterator = view.slice(start..stop).unwrap(); - -tp_shape[dim] = block_size; - -// Convert safetensors Dtype to candle DType -let dtype: DType = dtype.try_into().unwrap(); - -// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. -let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); -let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_3 - assert_eq!(view.shape(), &[768, 768]); - assert_eq!(tp_tensor.dims(), &[192, 768]); - } -} |