diff options
Diffstat (limited to 'candle-examples/src')
-rw-r--r-- | candle-examples/src/lib.rs | 84 |
1 files changed, 78 insertions, 6 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 3fdd4cc9..0b716e4f 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { #[cfg(test)] mod tests { use anyhow::Result; + use candle::{DType, Device, Tensor}; + use parquet::file::reader::SerializedFileReader; // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 #[rustfmt::skip] @@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un #[test] fn book_training_1() -> Result<()>{ // ANCHOR: book_training_1 -use candle_datasets::hub::from_hub; -use hf_hub::api::sync::Api; +use hf_hub::{api::sync::Api, Repo, RepoType}; + +let dataset_id = "mnist".to_string(); let api = Api::new()?; -let files = from_hub(&api, "mnist".to_string())?; +let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), +); +let repo = api.repo(repo); +let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?; +let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; // ANCHOR_END: book_training_1 +// Ignore unused +let _train = train_parquet; // ANCHOR: book_training_2 -let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten(); -for row in rows { - for (idx, (name, field)) in row.get_column_iter().enumerate() { +for row in test_parquet { + for (idx, (name, field)) in row?.get_column_iter().enumerate() { println!("Column id {idx}, name {name}, value {field}"); } } // ANCHOR_END: book_training_2 +let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?; +let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?; +let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?; +let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?; +// ANCHOR: book_training_3 + +let test_samples = 10_000; +let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784); +let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples); +for row in test_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + test_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + test_buffer_labels.push(*label as u8); + } + } +} +let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?; +let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?; + +let train_samples = 60_000; +let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784); +let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples); +for row in train_parquet{ + for (_name, field) in row?.get_column_iter() { + if let parquet::record::Field::Group(subrow) = field { + for (_name, field) in subrow.get_column_iter() { + if let parquet::record::Field::Bytes(value) = field { + let image = image::load_from_memory(value.data()).unwrap(); + train_buffer_images.extend(image.to_luma8().as_raw()); + } + } + }else if let parquet::record::Field::Long(label) = field { + train_buffer_labels.push(*label as u8); + } + } +} +let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?; +let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?; + +let mnist = candle_datasets::vision::Dataset { + train_images, + train_labels, + test_images, + test_labels, + labels: 10, +}; + +// ANCHOR_END: book_training_3 +assert_eq!(mnist.test_images.dims(), &[10_000, 784]); +assert_eq!(mnist.test_labels.dims(), &[10_000]); +assert_eq!(mnist.train_images.dims(), &[60_000, 784]); +assert_eq!(mnist.train_labels.dims(), &[60_000]); Ok(()) } } |