summary refs log tree commit diff
path: root/candle-examples/examples/replit-code/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/replit-code/main.rs')
-rw-r--r-- candle-examples/examples/replit-code/main.rs 13
1 files changed, 6 insertions, 7 deletions
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 0f72b862..b7f767b9 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -236,16 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
- let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
- (model, Device::Cpu)
+ let model = if args.quantized {
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+ Model::Q(Q::new(&config, vb.pp("transformer"))?)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = Model::M(M::new(&config, vb.pp("transformer"))?);
- (model, device)
+ Model::M(M::new(&config, vb.pp("transformer"))?)
};
println!("loaded the model in {:?}", start.elapsed());