diff options
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 69eed84f..39f4fd58 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+    let device = candle_examples::device(args.cpu)?;
+    let model = if args.quantized {
         let config = config();
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &filenames[0],
+            &device,
+        )?;
         let model = match args.model {
             WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
             _ => QMixFormer::new(&config, vb)?,
         };
-        (Model::Quantized(model), Device::Cpu)
+        Model::Quantized(model)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-        let model = match args.model {
+        match args.model {
             WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
                 let config_filename = repo.get("config.json")?;
                 let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
                 let config = config();
                 Model::MixFormer(MixFormer::new(&config, vb)?)
             }
-        };
-        (model, device)
+        }
     };
     println!("loaded the model in {:?}", start.elapsed());