diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-15 15:52:37 +0200 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-28 15:15:27 +0200 |
commit | d726484a6dfed53236a8d3d72c23a3d26371a90d (patch) | |
tree | ac2bde8ccccde47c9fc25363df2b20848cedba28 /candle-datasets/src/vision/mnist.rs | |
parent | dd06d93d0ba0fdc71d1f64057e7eca4276b99ee9 (diff) | |
download | candle-d726484a6dfed53236a8d3d72c23a3d26371a90d.tar.gz candle-d726484a6dfed53236a8d3d72c23a3d26371a90d.tar.bz2 candle-d726484a6dfed53236a8d3d72c23a3d26371a90d.zip |
Re-enable local dir for mnist.
Diffstat (limited to 'candle-datasets/src/vision/mnist.rs')
-rw-r--r-- | candle-datasets/src/vision/mnist.rs | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs index c908412c..1085edd6 100644 --- a/candle-datasets/src/vision/mnist.rs +++ b/candle-datasets/src/vision/mnist.rs @@ -2,7 +2,7 @@ //! //! The files can be obtained from the following link: //! <http://yann.lecun.com/exdb/mnist/> -use candle::{DType, Device, Result, Tensor}; +use candle::{DType, Device, Error, Result, Tensor}; use hf_hub::{api::sync::Api, Repo, RepoType}; use parquet::file::reader::{FileReader, SerializedFileReader}; use std::fs::File; @@ -92,7 +92,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, } pub fn load() -> Result<crate::vision::Dataset> { - let api = Api::new().unwrap(); + let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?; let dataset_id = "mnist".to_string(); let repo = Repo::with_revision( dataset_id, @@ -100,12 +100,16 @@ pub fn load() -> Result<crate::vision::Dataset> { "refs/convert/parquet".to_string(), ); let repo = api.repo(repo); - let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap(); - let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap(); - let test_parquet = - SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap(); - let train_parquet = - SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap(); + let test_parquet_filename = repo + .get("mnist/mnist-test.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let train_parquet_filename = repo + .get("mnist/mnist-train.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; + let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; let (test_images, test_labels) = load_parquet(test_parquet)?; let (train_images, train_labels) = load_parquet(train_parquet)?; Ok(crate::vision::Dataset { |