Diffstat (limited to 'candle-datasets/src')
-rw-r--r--  candle-datasets/src/hub.rs            73
-rw-r--r--  candle-datasets/src/lib.rs             1
-rw-r--r--  candle-datasets/src/vision/mnist.rs   59
3 files changed, 132 insertions(+), 1 deletion(-)
diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs
new file mode 100644
index 00000000..b135e148
--- /dev/null
+++ b/candle-datasets/src/hub.rs
@@ -0,0 +1,73 @@
+use hf_hub::{
+ api::sync::{Api, ApiRepo},
+ Repo, RepoType,
+};
+use parquet::file::reader::SerializedFileReader;
+use std::fs::File;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+ #[error("ApiError : {0}")]
+ ApiError(#[from] hf_hub::api::sync::ApiError),
+
+ #[error("IoError : {0}")]
+ IoError(#[from] std::io::Error),
+
+ #[error("ParquetError : {0}")]
+ ParquetError(#[from] parquet::errors::ParquetError),
+}
+
+fn sibling_to_parquet(
+ rfilename: &str,
+ repo: &ApiRepo,
+) -> Result<SerializedFileReader<File>, Error> {
+ let local = repo.get(rfilename)?;
+ let file = File::open(local)?;
+ let reader = SerializedFileReader::new(file)?;
+ Ok(reader)
+}
+
+pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
+ let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+ );
+ let repo = api.repo(repo);
+ let info = repo.info()?;
+
+ let files: Result<Vec<_>, _> = info
+ .siblings
+ .into_iter()
+ .filter_map(|s| -> Option<Result<_, _>> {
+ let filename = s.rfilename;
+ if filename.ends_with(".parquet") {
+ let reader_result = sibling_to_parquet(&filename, &repo);
+ Some(reader_result)
+ } else {
+ None
+ }
+ })
+ .collect();
+ let files = files?;
+
+ Ok(files)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use parquet::file::reader::FileReader;
+
+ #[test]
+ fn test_dataset() {
+ let api = Api::new().unwrap();
+ let files = from_hub(
+ &api,
+ "hf-internal-testing/dummy_image_text_data".to_string(),
+ )
+ .unwrap();
+ assert_eq!(files.len(), 1);
+ assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
+ }
+}
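
A minimal usage sketch for the new `from_hub` helper (not part of the commit; the `.unwrap()` error handling and the dummy dataset id are borrowed from the test above):

    use candle_datasets::hub::from_hub;
    use hf_hub::api::sync::Api;

    fn main() {
        let api = Api::new().unwrap();
        // Fetch the parquet conversion of the dataset and open one
        // `SerializedFileReader` per shard.
        let files = from_hub(
            &api,
            "hf-internal-testing/dummy_image_text_data".to_string(),
        )
        .unwrap();
        for reader in files {
            // `SerializedFileReader` is `IntoIterator` over `Result<Row, _>`;
            // `.flatten()` skips unreadable rows, matching the pattern used
            // in mnist.rs below.
            for row in reader.into_iter().flatten() {
                for (name, field) in row.get_column_iter() {
                    println!("{name}: {field:?}");
                }
            }
        }
    }
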
diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs
index 42ad5d62..bfd77a99 100644
--- a/candle-datasets/src/lib.rs
+++ b/candle-datasets/src/lib.rs
@@ -1,5 +1,6 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
+pub mod hub;
pub mod nlp;
pub mod vision;
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 2267f9a0..1085edd6 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -2,7 +2,9 @@
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
@@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
labels: 10,
})
}
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+ let samples = parquet.metadata().file_metadata().num_rows() as usize;
+ let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+ let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+ for row in parquet.into_iter().flatten() {
+ for (_name, field) in row.get_column_iter() {
+ if let parquet::record::Field::Group(subrow) = field {
+ for (_name, field) in subrow.get_column_iter() {
+ if let parquet::record::Field::Bytes(value) = field {
+ let image = image::load_from_memory(value.data()).unwrap();
+ buffer_images.extend(image.to_luma8().as_raw());
+ }
+ }
+ } else if let parquet::record::Field::Long(label) = field {
+ buffer_labels.push(*label as u8);
+ }
+ }
+ }
+ let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+ .to_dtype(DType::F32)?
+ / 255.)?;
+ let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+ Ok((images, labels))
+}
+
+pub fn load() -> Result<crate::vision::Dataset> {
+ let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let dataset_id = "mnist".to_string();
+ let repo = Repo::with_revision(
+ dataset_id,
+ RepoType::Dataset,
+ "refs/convert/parquet".to_string(),
+ );
+ let repo = api.repo(repo);
+ let test_parquet_filename = repo
+ .get("mnist/mnist-test.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let train_parquet_filename = repo
+ .get("mnist/mnist-train.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+ let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+ let (test_images, test_labels) = load_parquet(test_parquet)?;
+ let (train_images, train_labels) = load_parquet(train_parquet)?;
+ Ok(crate::vision::Dataset {
+ train_images,
+ train_labels,
+ test_images,
+ test_labels,
+ labels: 10,
+ })
+}
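
A minimal sketch of calling the new parquet-backed entry point (not part of the commit; the crate names follow the workspace layout, and the expected shapes are standard MNIST):

    fn main() -> candle::Result<()> {
        let dataset = candle_datasets::vision::mnist::load()?;
        // 60k training samples and 10k test samples of 28x28 = 784 pixels,
        // scaled to f32 in [0, 1] by the division by 255 in load_parquet.
        println!("train images: {:?}", dataset.train_images.shape()); // (60000, 784)
        println!("train labels: {:?}", dataset.train_labels.shape()); // (60000,)
        println!("test images:  {:?}", dataset.test_images.shape());  // (10000, 784)
        println!("test labels:  {:?}", dataset.test_labels.shape());  // (10000,)
        Ok(())
    }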