diff options
Diffstat (limited to 'candle-datasets/src/hub.rs')
-rw-r--r-- | candle-datasets/src/hub.rs | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs new file mode 100644 index 00000000..b135e148 --- /dev/null +++ b/candle-datasets/src/hub.rs @@ -0,0 +1,73 @@ +use hf_hub::{ + api::sync::{Api, ApiRepo}, + Repo, RepoType, +}; +use parquet::file::reader::SerializedFileReader; +use std::fs::File; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("ApiError : {0}")] + ApiError(#[from] hf_hub::api::sync::ApiError), + + #[error("IoError : {0}")] + IoError(#[from] std::io::Error), + + #[error("ParquetError : {0}")] + ParquetError(#[from] parquet::errors::ParquetError), +} + +fn sibling_to_parquet( + rfilename: &str, + repo: &ApiRepo, +) -> Result<SerializedFileReader<File>, Error> { + let local = repo.get(rfilename)?; + let file = File::open(local)?; + let reader = SerializedFileReader::new(file)?; + Ok(reader) +} + +pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> { + let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let repo = api.repo(repo); + let info = repo.info()?; + + let files: Result<Vec<_>, _> = info + .siblings + .into_iter() + .filter_map(|s| -> Option<Result<_, _>> { + let filename = s.rfilename; + if filename.ends_with(".parquet") { + let reader_result = sibling_to_parquet(&filename, &repo); + Some(reader_result) + } else { + None + } + }) + .collect(); + let files = files?; + + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + use parquet::file::reader::FileReader; + + #[test] + fn test_dataset() { + let api = Api::new().unwrap(); + let files = from_hub( + &api, + "hf-internal-testing/dummy_image_text_data".to_string(), + ) + .unwrap(); + assert_eq!(files.len(), 1); + assert_eq!(files[0].metadata().file_metadata().num_rows(), 20); + } +} |