summaryrefslogtreecommitdiff
path: root/candle-book
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-08-01 15:04:41 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-08-02 18:40:24 +0200
commit45642a8530fdfbd64fcac118aed59b7cb7dfaf45 (patch)
tree16e6aa12a193fff89f4b000f39e4cfe0a4ac0d25 /candle-book
parent82464166e4d947a717509922a566e7ceaf4b3f2f (diff)
downloadcandle-45642a8530fdfbd64fcac118aed59b7cb7dfaf45.tar.gz
candle-45642a8530fdfbd64fcac118aed59b7cb7dfaf45.tar.bz2
candle-45642a8530fdfbd64fcac118aed59b7cb7dfaf45.zip
Fixing examples.
Diffstat (limited to 'candle-book')
-rw-r--r--candle-book/src/inference/hub.md6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 8cf375d7..de514322 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -58,20 +58,20 @@ Now that we have our weights, we can use them in our bert architecture:
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
-# use candle::Device;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
+use candle::{Device, Tensor, DType};
use candle_nn::Linear;
-let weights = candle::safetensors::load(weights, &Device::Cpu);
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
-let linear = Linear::new(weight, Some(bias));
+let linear = Linear::new(weight.clone(), Some(bias.clone()));
let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids);