diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-29 13:10:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-29 13:10:05 +0200 |
commit | 14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f (patch) | |
tree | 11d5c84dedb610b9e4306030ec36929d1f03e980 /candle-examples/src | |
parent | 62ef494dc17c1f582b28c665e78f2aa78d846bb9 (diff) | |
parent | 2d5b7a735d2c9ccb890dae73862dc734ef0950ae (diff) | |
download | candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.gz candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.bz2 candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.zip |
Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
Diffstat (limited to 'candle-examples/src')
-rw-r--r-- | candle-examples/src/lib.rs | 99 |
1 file changed, 0 insertions, 99 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 8bf94eb7..395162eb 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { image.save(p).map_err(candle::Error::wrap)?; Ok(()) } - -#[cfg(test)] -mod tests { - // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 - #[rustfmt::skip] - #[tokio::test] - async fn book_hub_1() { -// ANCHOR: book_hub_1 -use candle::Device; -use hf_hub::api::tokio::Api; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); - -let weights_filename = repo.get("model.safetensors").await.unwrap(); - -let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_1 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_2() { -// ANCHOR: book_hub_2 -use candle::Device; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; -let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_2 - assert_eq!(weights.len(), 206); - } - - #[rustfmt::skip] - #[test] - fn book_hub_3() { -// ANCHOR: book_hub_3 -use candle::{DType, Device, Tensor}; -use hf_hub::api::sync::Api; -use memmap2::Mmap; -use safetensors::slice::IndexOp; -use safetensors::SafeTensors; -use std::fs; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); -let weights_filename = repo.get("model.safetensors").unwrap(); - -let file = fs::File::open(weights_filename).unwrap(); -let mmap = unsafe { Mmap::map(&file).unwrap() }; - -// Use safetensors directly -let tensors = 
SafeTensors::deserialize(&mmap[..]).unwrap(); -let view = tensors - .tensor("bert.encoder.layer.0.attention.self.query.weight") - .unwrap(); - -// We're going to load shard with rank 1, within a world_size of 4 -// We're going to split along dimension 0 doing VIEW[start..stop, :] -let rank = 1; -let world_size = 4; -let dim = 0; -let dtype = view.dtype(); -let mut tp_shape = view.shape().to_vec(); -let size = tp_shape[0]; - -if size % world_size != 0 { - panic!("The dimension is not divisble by `world_size`"); -} -let block_size = size / world_size; -let start = rank * block_size; -let stop = (rank + 1) * block_size; - -// Everything is expressed in tensor dimension -// bytes offsets is handled automatically for safetensors. - -let iterator = view.slice(start..stop).unwrap(); - -tp_shape[dim] = block_size; - -// Convert safetensors Dtype to candle DType -let dtype: DType = dtype.try_into().unwrap(); - -// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. -let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); -let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_3 - assert_eq!(view.shape(), &[768, 768]); - assert_eq!(tp_tensor.dims(), &[192, 768]); - } -} |