diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-20 13:04:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-20 13:04:36 +0100 |
commit | 455c42aa729d8019fcb496106478e75dd3246c08 (patch) | |
tree | e240ed2766639456e52d1fcea82989633f7fb0a5 /candle-examples/examples/stable-lm/main.rs | |
parent | 2a8679509eb55232b37378442c4366343f6dcb11 (diff) | |
download | candle-455c42aa729d8019fcb496106478e75dd3246c08.tar.gz candle-455c42aa729d8019fcb496106478e75dd3246c08.tar.bz2 candle-455c42aa729d8019fcb496106478e75dd3246c08.zip |
Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
Diffstat (limited to 'candle-examples/examples/stable-lm/main.rs')
-rw-r--r-- | candle-examples/examples/stable-lm/main.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs index f467903a..f0707010 100644 --- a/candle-examples/examples/stable-lm/main.rs +++ b/candle-examples/examples/stable-lm/main.rs @@ -288,12 +288,12 @@ fn main() -> Result<()> { }; let device = candle_examples::device(args.cpu)?; - let (model, device) = if args.quantized { + let model = if args.quantized { let filename = &filenames[0]; let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QStableLM::new(&config, vb)?; - (Model::Quantized(model), Device::Cpu) + Model::Quantized(model) } else { let dtype = if device.is_cuda() { DType::BF16 @@ -302,7 +302,7 @@ fn main() -> Result<()> { }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let model = StableLM::new(&config, vb)?; - (Model::StableLM(model), device) + Model::StableLM(model) }; println!("loaded the model in {:?}", start.elapsed()); |