summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/phi/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- candle-examples/examples/phi/main.rs | 16
1 files changed, 9 insertions, 7 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 69eed84f..39f4fd58 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let device = candle_examples::device(args.cpu)?;
+ let model = if args.quantized {
let config = config();
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ &filenames[0],
+ &device,
+ )?;
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
- (Model::Quantized(model), Device::Cpu)
+ Model::Quantized(model)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
- let model = match args.model {
+ match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
- };
- (model, device)
+ }
};
println!("loaded the model in {:?}", start.elapsed());