summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-lm/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-20 13:04:36 +0100
committerGitHub <noreply@github.com>2024-03-20 13:04:36 +0100
commit455c42aa729d8019fcb496106478e75dd3246c08 (patch)
treee240ed2766639456e52d1fcea82989633f7fb0a5 /candle-examples/examples/stable-lm/main.rs
parent2a8679509eb55232b37378442c4366343f6dcb11 (diff)
downloadcandle-455c42aa729d8019fcb496106478e75dd3246c08.tar.gz
candle-455c42aa729d8019fcb496106478e75dd3246c08.tar.bz2
candle-455c42aa729d8019fcb496106478e75dd3246c08.zip
Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze. * Fix the quantized llama example. * Unrelated fix for the quantized stable-lm example on cuda. * Fix for mamba on cuda (unrelated to the PR).
Diffstat (limited to 'candle-examples/examples/stable-lm/main.rs')
-rw-r--r--candle-examples/examples/stable-lm/main.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index f467903a..f0707010 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -288,12 +288,12 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
- let (model, device) = if args.quantized {
+ let model = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
- (Model::Quantized(model), Device::Cpu)
+ Model::Quantized(model)
} else {
let dtype = if device.is_cuda() {
DType::BF16
@@ -302,7 +302,7 @@ fn main() -> Result<()> {
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
- (Model::StableLM(model), device)
+ Model::StableLM(model)
};
println!("loaded the model in {:?}", start.elapsed());