summaryrefslogtreecommitdiff
path: root/candle-examples/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-25 21:49:21 +0100
committerGitHub <noreply@github.com>2023-12-25 21:49:21 +0100
commit37c539f2b7dfc8aa67a10b611dc12e5e0428be00 (patch)
tree48980a31058157808fc777f7fbe50239931afc54 /candle-examples/src
parenteae3a20d43a855acbaa7afd4494acdef9648d2b3 (diff)
downloadcandle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.tar.gz
candle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.tar.bz2
candle-37c539f2b7dfc8aa67a10b611dc12e5e0428be00.zip
Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example. * Add a helper function to load sharded safetensors weights. * Use the sharded loader.
Diffstat (limited to 'candle-examples/src')
-rw-r--r--candle-examples/src/lib.rs27
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index dff31b85..d6dce4a3 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
+
+/// Loads the safetensors files for a model from the hub based on a json index file.
+pub fn hub_load_safetensors(
+ repo: &hf_hub::api::sync::ApiRepo,
+ json_file: &str,
+) -> Result<Vec<std::path::PathBuf>> {
+ let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
+ let json_file = std::fs::File::open(json_file)?;
+ let json: serde_json::Value =
+ serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+ let weight_map = match json.get("weight_map") {
+ None => candle::bail!("no weight map in {json_file:?}"),
+ Some(serde_json::Value::Object(map)) => map,
+ Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
+ };
+ let mut safetensors_files = std::collections::HashSet::new();
+ for value in weight_map.values() {
+ if let Some(file) = value.as_str() {
+ safetensors_files.insert(file.to_string());
+ }
+ }
+ let safetensors_files = safetensors_files
+ .iter()
+ .map(|v| repo.get(v).map_err(candle::Error::wrap))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(safetensors_files)
+}