diff options
Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 0ceb27af..9d42dcc8 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let (_vocab_size, dim) = vb .get_no_shape("model.embed_tokens.weight")? .shape() @@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { (config.seq_len, config.head_size() / 2), "rot.freq_cis_real", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let freq_cis_imag = vb .get( (config.seq_len, config.head_size() / 2), "rot.freq_cis_imag", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let fake_vb = candle_nn::VarBuilder::from_tensors( [ @@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .into_iter() .collect(), candle::DType::F32, - &candle::Device::Cpu, + &device, ); let cache = model::Cache::new(true, &config, fake_vb)?; let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); |