diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-25 21:49:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-25 21:49:21 +0100 |
commit | 37c539f2b7dfc8aa67a10b611dc12e5e0428be00 (patch) | |
tree | 48980a31058157808fc777f7fbe50239931afc54 /candle-examples/src | |
parent | eae3a20d43a855acbaa7afd4494acdef9648d2b3 (diff) | |
download | candle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.tar.gz candle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.tar.bz2 candle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.zip |
Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example.
* Add a helper function to load sharded safetensors weights.
* Use the sharded loader.
Diffstat (limited to 'candle-examples/src')
-rw-r--r-- | candle-examples/src/lib.rs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index dff31b85..d6dce4a3 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>( image.save(p).map_err(candle::Error::wrap)?; Ok(()) } + +/// Loads the safetensors files for a model from the hub based on a json index file. +pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> Result<Vec<std::path::PathBuf>> { + let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| repo.get(v).map_err(candle::Error::wrap)) + .collect::<Result<Vec<_>>>()?; + Ok(safetensors_files) +} |