Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 16 |
1 file changed, 9 insertions, 7 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;
 
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(false)?;
 
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@@ -369,7 +370,7 @@
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            ModelWeights::from_gguf(model, &mut file, &device)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+            let model = ggml_file::Content::read(&mut file, &device)
+                .map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
 
         let start_prompt_processing = std::time::Instant::now();
         let mut next_token = {
-            let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, 0)?;
             let logits = logits.squeeze(0)?;
             logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
         let start_post_prompt = std::time::Instant::now();
         let mut sampled = 0;
         for index in 0..to_sample {
-            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
             let logits = logits.squeeze(0)?;
             let logits = if args.repeat_penalty == 1. {
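The new candle_examples::device(false) call replaces the hard-coded Device::Cpu, so the model weights and the prompt tensors all land on the same device. A rough sketch of such a selector (an assumption about the helper's behavior, not its exact body; the utility functions it calls do exist in candle):

    use candle::{Device, Result};

    // Hypothetical sketch of a device selector in the spirit of
    // candle_examples::device: prefer CUDA, then Metal, else CPU.
    fn device(cpu: bool) -> Result<Device> {
        if cpu {
            Ok(Device::Cpu)
        } else if candle::utils::cuda_is_available() {
            Device::new_cuda(0)
        } else if candle::utils::metal_is_available() {
            Device::new_metal(0)
        } else {
            Ok(Device::Cpu)
        }
    }

Passing false requests the accelerated path when one was compiled in, which is why the literal false appears at the call site above.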
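The renamed block_size accessor feeds the same size accounting as before: quantized GGML dtypes pack a fixed number of elements (block_size) into a fixed number of bytes (type_size), so a tensor's footprint is elem_count * type_size / block_size. A worked example, assuming Q4_0's usual layout of 32 elements per 18-byte block (the helper function is illustrative, not part of candle):

    // Bytes occupied by elem_count elements of a block-quantized dtype.
    fn quantized_size_in_bytes(elem_count: usize, type_size: usize, block_size: usize) -> usize {
        elem_count * type_size / block_size
    }

    fn main() {
        // Q4_0 stores an f16 scale plus 32 packed 4-bit values per block,
        // so a 4096 x 4096 tensor takes 4096 * 4096 * 18 / 32 = 9_437_184 bytes (9 MiB).
        println!("{}", quantized_size_in_bytes(4096 * 4096, 18, 32));
    }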