Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
 candle-examples/examples/quantized/main.rs | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
@@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
- elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+ elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
- ModelWeights::from_gguf(model, &mut file)?
+ ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None => {
- let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+ let model = ggml_file::Content::read(&mut file, &device)
+ .map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
- elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+ elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
- let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
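With the device threaded through, the whole prompt goes into a single forward pass on that device. A shape-annotated sketch of the block above, assuming a prompt of N tokens, a vocabulary of V entries, and that the quantized llama forward returns logits for the last position only:

    let input = Tensor::new(prompt_tokens.as_slice(), &device)? // shape (N,)
        .unsqueeze(0)?;                                         // shape (1, N): add batch dim
    let logits = model.forward(&input, 0)?;                     // shape (1, V): last position
    let logits = logits.squeeze(0)?;                            // shape (V,)
    let next_token = logits_processor.sample(&logits)?;
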
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
- let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
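The hunk ends mid-branch: a `repeat_penalty` of 1.0 leaves the logits untouched, while other values down-weight recently emitted tokens. A sketch of how the branch plausibly completes, assuming `all_tokens` holds the running token history and using `candle_transformers::utils::apply_repeat_penalty`:

    let logits = if args.repeat_penalty == 1. {
        logits
    } else {
        // Penalize only the last `repeat_last_n` tokens to discourage loops.
        let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
        candle_transformers::utils::apply_repeat_penalty(
            &logits,
            args.repeat_penalty,
            &all_tokens[start_at..],
        )?
    };
    next_token = logits_processor.sample(&logits)?;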