diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-29 13:10:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-29 13:10:05 +0200 |
commit | 14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f (patch) | |
tree | 11d5c84dedb610b9e4306030ec36929d1f03e980 /candle-datasets/src/hub.rs | |
parent | 62ef494dc17c1f582b28c665e78f2aa78d846bb9 (diff) | |
parent | 2d5b7a735d2c9ccb890dae73862dc734ef0950ae (diff) | |
download | candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.gz candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.tar.bz2 candle-14b4d456e80a6fb218c6e3c16b4e5aeffb0c2c6f.zip |
Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
Diffstat (limited to 'candle-datasets/src/hub.rs')
-rw-r--r-- | candle-datasets/src/hub.rs | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs new file mode 100644 index 00000000..b135e148 --- /dev/null +++ b/candle-datasets/src/hub.rs @@ -0,0 +1,73 @@ +use hf_hub::{ + api::sync::{Api, ApiRepo}, + Repo, RepoType, +}; +use parquet::file::reader::SerializedFileReader; +use std::fs::File; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("ApiError : {0}")] + ApiError(#[from] hf_hub::api::sync::ApiError), + + #[error("IoError : {0}")] + IoError(#[from] std::io::Error), + + #[error("ParquetError : {0}")] + ParquetError(#[from] parquet::errors::ParquetError), +} + +fn sibling_to_parquet( + rfilename: &str, + repo: &ApiRepo, +) -> Result<SerializedFileReader<File>, Error> { + let local = repo.get(rfilename)?; + let file = File::open(local)?; + let reader = SerializedFileReader::new(file)?; + Ok(reader) +} + +pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> { + let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let repo = api.repo(repo); + let info = repo.info()?; + + let files: Result<Vec<_>, _> = info + .siblings + .into_iter() + .filter_map(|s| -> Option<Result<_, _>> { + let filename = s.rfilename; + if filename.ends_with(".parquet") { + let reader_result = sibling_to_parquet(&filename, &repo); + Some(reader_result) + } else { + None + } + }) + .collect(); + let files = files?; + + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + use parquet::file::reader::FileReader; + + #[test] + fn test_dataset() { + let api = Api::new().unwrap(); + let files = from_hub( + &api, + "hf-internal-testing/dummy_image_text_data".to_string(), + ) + .unwrap(); + assert_eq!(files.len(), 1); + assert_eq!(files[0].metadata().file_metadata().num_rows(), 20); + } +} |