From 45642a8530fdfbd64fcac118aed59b7cb7dfaf45 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Aug 2023 15:04:41 +0200 Subject: Fixing examples. --- candle-book/src/inference/hub.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'candle-book') diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index 8cf375d7..de514322 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -58,20 +58,20 @@ Now that we have our weights, we can use them in our bert architecture: # extern crate candle_nn; # extern crate hf_hub; # use hf_hub::api::sync::Api; -# use candle::Device; # # let api = Api::new().unwrap(); # let repo = api.model("bert-base-uncased".to_string()); # # let weights = repo.get("model.safetensors").unwrap(); +use candle::{Device, Tensor, DType}; use candle_nn::Linear; -let weights = candle::safetensors::load(weights, &Device::Cpu); +let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap(); let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap(); let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap(); -let linear = Linear::new(weight, Some(bias)); +let linear = Linear::new(weight.clone(), Some(bias.clone())); let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap(); let output = linear.forward(&input_ids); -- cgit v1.2.3