author    Nicolas Patry <patry.nicolas@protonmail.com>    2023-07-28 12:07:39 +0200
committer Nicolas Patry <patry.nicolas@protonmail.com>    2023-08-02 18:40:24 +0200
commit    82464166e4d947a717509922a566e7ceaf4b3f2f (patch)
tree      4e18c57dffe18c843e9f8de478095cdcd01127c1 /candle-book/src/inference
parent    52414ba5c853a2b39b393677a89d07a73fdc7a15 (diff)
3rd phase.
Diffstat (limited to 'candle-book/src/inference')
-rw-r--r--  candle-book/src/inference/README.md        |  6
-rw-r--r--  candle-book/src/inference/hub.md           | 79
-rw-r--r--  candle-book/src/inference/serialization.md |  2
3 files changed, 87 insertions, 0 deletions
diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/README.md
index c82f85e1..1b75a310 100644
--- a/candle-book/src/inference/README.md
+++ b/candle-book/src/inference/README.md
@@ -1 +1,7 @@
# Running a model
+
+
+In order to run an existing model, you will need to download and use its pre-trained weights.
+Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
+
+Let's get started by running an old model: `bert-base-uncased`.
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 6242c070..8cf375d7 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -1 +1,80 @@
# Using the hub
+
+Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
+
+```bash
+cargo add hf-hub
+```
+
+Then we can download the [model file](https://huggingface.co/bert-base-uncased/tree/main):
+
+
+```rust
+# extern crate candle;
+# extern crate hf_hub;
+use hf_hub::api::sync::Api;
+use candle::Device;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights = repo.get("model.safetensors").unwrap();
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+```
+
+We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
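As an aside, the `safetensors` format that `load` parses is deliberately simple: an 8-byte little-endian header length, a JSON header describing each tensor's dtype, shape, and byte offsets, then the raw tensor bytes. A dependency-free sketch of locating that header (the `header_json` helper is hypothetical and for illustration only; real code should go through the `safetensors` crate):

```rust
// Sketch of the safetensors on-disk layout:
// [ 8-byte little-endian u64 = N ][ N bytes of JSON header ][ raw tensor data ]
fn header_json(bytes: &[u8]) -> Option<&str> {
    // First 8 bytes: length of the JSON header, little-endian.
    let len = u64::from_le_bytes(bytes.get(..8)?.try_into().ok()?) as usize;
    std::str::from_utf8(bytes.get(8..8 + len)?).ok()
}

fn main() {
    // Build a tiny synthetic file: one F32 tensor named "w" with a single element.
    let json = br#"{"w":{"dtype":"F32","shape":[1],"data_offsets":[0,4]}}"#;
    let mut file = (json.len() as u64).to_le_bytes().to_vec();
    file.extend_from_slice(json);
    file.extend_from_slice(&1.0f32.to_le_bytes());

    let header = header_json(&file).unwrap();
    assert!(header.contains("\"dtype\":\"F32\""));
    println!("{header}");
}
```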
+
+
+## Using async
+
+`hf-hub` also comes with an async API, enabled through the `tokio` feature:
+
+```bash
+cargo add hf-hub --features tokio
+```
+
+```rust,ignore
+# extern crate candle;
+# extern crate hf_hub;
+use hf_hub::api::tokio::Api;
+use candle::Device;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+```
+
+
+## Using in a real model
+
+Now that we have our weights, we can use them in our BERT architecture:
+
+```rust
+# extern crate candle;
+# extern crate candle_nn;
+# extern crate hf_hub;
+# use hf_hub::api::sync::Api;
+# use candle::{DType, Device, Tensor};
+#
+# let api = Api::new().unwrap();
+# let repo = api.model("bert-base-uncased".to_string());
+#
+# let weights = repo.get("model.safetensors").unwrap();
+use candle_nn::Linear;
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+
+let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
+let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
+
+// `get` returns references, so clone the tensors to hand ownership to `Linear`.
+let linear = Linear::new(weight.clone(), Some(bias.clone()));
+
+// BERT's hidden size is 768, so the input's last dimension must be 768.
+let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
+let output = linear.forward(&input_ids).unwrap();
+```
+
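For intuition, a linear layer with a bias computes `y = x·Wᵀ + b`. A dependency-free sketch of that same arithmetic on plain slices (the `linear_forward` helper is hypothetical, not candle's API):

```rust
// Sketch: what a linear layer computes, y = x * W^T + b.
// `w` is [out][in], `b` is [out], `x` is [in]; result is [out]. Not candle's API.
fn linear_forward(w: &[Vec<f32>], b: &[f32], x: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
        .collect()
}

fn main() {
    let w = vec![vec![1.0, 2.0], vec![3.0, 4.0]]; // 2 outputs, 2 inputs
    let b = vec![0.5, -0.5];
    let y = linear_forward(&w, &b, &[1.0, 1.0]);
    assert_eq!(y, vec![3.5, 6.5]);
    println!("{y:?}");
}
```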
+For a full reference, check out the complete [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
diff --git a/candle-book/src/inference/serialization.md b/candle-book/src/inference/serialization.md
index 0dfc62d3..133ff025 100644
--- a/candle-book/src/inference/serialization.md
+++ b/candle-book/src/inference/serialization.md
@@ -1 +1,3 @@
# Serialization
+
+Once you have a r