summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-03 11:10:58 +0100
committerGitHub <noreply@github.com>2023-08-03 11:10:58 +0100
commita79286885caaf453821dcc8a1328eba0cf573092 (patch)
tree86d792085ff898f2c0ff09cdc949eb3c2e65bbe2 /candle-examples/examples/llama2-c
parent74845a4dcdc2985bef2c0a7dd7c60c2938ad419d (diff)
downloadcandle-a79286885caaf453821dcc8a1328eba0cf573092.tar.gz
candle-a79286885caaf453821dcc8a1328eba0cf573092.tar.bz2
candle-a79286885caaf453821dcc8a1328eba0cf573092.zip
Support safetensors weights in llama2.c inference. (#317)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs23
-rw-r--r--candle-examples/examples/llama2-c/weights.rs2
2 files changed, 18 insertions, 7 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 8b64fdd2..612dc358 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,7 +27,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
- /// Config file in binary format.
+ /// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@@ -225,11 +225,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
- let mut file = std::fs::File::open(config_path)?;
- let config = Config::from_reader(&mut file)?;
- println!("{config:?}");
- let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
- let vb = weights.var_builder(&config, &device)?;
+ let is_safetensors = config_path
+ .extension()
+ .map_or(false, |v| v == "safetensors");
+ let (vb, config) = if is_safetensors {
+ let config = Config::tiny();
+ let tensors = candle::safetensors::load(config_path, &device)?;
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+ (vb, config)
+ } else {
+ let mut file = std::fs::File::open(config_path)?;
+ let config = Config::from_reader(&mut file)?;
+ println!("{config:?}");
+ let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+ let vb = weights.var_builder(&config, &device)?;
+ (vb, config)
+ };
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index ae1fd6d9..2daed057 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -104,7 +104,7 @@ impl TransformerWeights {
})
}
- pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);