summary refs log tree commit diff
path: root/candle-examples/examples/mistral/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r-- candle-examples/examples/mistral/main.rs 7
1 files changed, 4 insertions, 3 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 5ed5e5cb..bad86098 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?;
- (Model::Quantized(model), Device::Cpu)
+ (Model::Quantized(model), device)
} else {
- let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {