summaryrefslogtreecommitdiff
path: root/candle-examples/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/src/lib.rs')
-rw-r--r--candle-examples/src/lib.rs84
1 files changed, 78 insertions, 6 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 3fdd4cc9..0b716e4f 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
#[cfg(test)]
mod tests {
use anyhow::Result;
+ use candle::{DType, Device, Tensor};
+ use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
@@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
-use candle_datasets::hub::from_hub;
-use hf_hub::api::sync::Api;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+let dataset_id = "mnist".to_string();
let api = Api::new()?;
-let files = from_hub(&api, "mnist".to_string())?;
+let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+);
+let repo = api.repo(repo);
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
+// Ignore unused
+let _train = train_parquet;
// ANCHOR: book_training_2
-let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
-for row in rows {
- for (idx, (name, field)) in row.get_column_iter().enumerate() {
+for row in test_parquet {
+ for (idx, (name, field)) in row?.get_column_iter().enumerate() {
println!("Column id {idx}, name {name}, value {field}");
}
}
// ANCHOR_END: book_training_2
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR: book_training_3
+
+let test_samples = 10_000;
+let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
+for row in test_parquet{
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ test_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ }else if let parquet::record::Field::Long(label) = field {
+ test_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
+
+let train_samples = 60_000;
+let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+for row in train_parquet{
+ for (_name, field) in row?.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ train_buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ }else if let parquet::record::Field::Long(label) = field {
+ train_buffer_labels.push(*label as u8);
+ }
+ }
+}
+let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
+
+let mnist = candle_datasets::vision::Dataset {
+ train_images,
+ train_labels,
+ test_images,
+ test_labels,
+ labels: 10,
+};
+
+// ANCHOR_END: book_training_3
+assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+assert_eq!(mnist.test_labels.dims(), &[10_000]);
+assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}