author    Nicolas Patry <patry.nicolas@protonmail.com>  2023-07-28 12:07:39 +0200
committer Nicolas Patry <patry.nicolas@protonmail.com>  2023-08-02 18:40:24 +0200
commit    82464166e4d947a717509922a566e7ceaf4b3f2f (patch)
tree      4e18c57dffe18c843e9f8de478095cdcd01127c1 /candle-book/src/inference
parent    52414ba5c853a2b39b393677a89d07a73fdc7a15 (diff)
3rd phase.
Diffstat (limited to 'candle-book/src/inference')
-rw-r--r--  candle-book/src/inference/README.md         |  6
-rw-r--r--  candle-book/src/inference/hub.md            | 79
-rw-r--r--  candle-book/src/inference/serialization.md  |  2
3 files changed, 87 insertions, 0 deletions
diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/README.md
index c82f85e1..1b75a310 100644
--- a/candle-book/src/inference/README.md
+++ b/candle-book/src/inference/README.md
@@ -1 +1,7 @@
 # Running a model
+
+
+In order to run an existing model, you will need to download and use existing weights.
+Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
+
+Let's get started by running an old model: `bert-base-uncased`.
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 6242c070..8cf375d7 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -1 +1,80 @@
 # Using the hub
+
+Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
+
+```bash
+cargo add hf-hub
+```
+
+Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
+
+
+```rust
+# extern crate candle;
+# extern crate hf_hub;
+use hf_hub::api::sync::Api;
+use candle::Device;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights = repo.get("model.safetensors").unwrap();
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+```
+
+We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
+
+
+## Using async
+
+`hf-hub` comes with an async API.
+
+```bash
+cargo add hf-hub --features tokio
+```
+
+```rust,ignore
+# extern crate candle;
+# extern crate hf_hub;
+use hf_hub::api::tokio::Api;
+use candle::Device;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+```
+
+
+## Using in a real model
+
+Now that we have our weights, we can use them in our bert architecture:
+
+```rust
+# extern crate candle;
+# extern crate candle_nn;
+# extern crate hf_hub;
+# use hf_hub::api::sync::Api;
+# use candle::{Device, DType, Tensor};
+#
+# let api = Api::new().unwrap();
+# let repo = api.model("bert-base-uncased".to_string());
+#
+# let weights = repo.get("model.safetensors").unwrap();
use candle_nn::Linear;
+
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
+
+let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap().clone();
+let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap().clone();
+
+let linear = Linear::new(weight, Some(bias));
+
+let input = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
+let output = linear.forward(&input).unwrap();
+```
+
+For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
diff --git a/candle-book/src/inference/serialization.md b/candle-book/src/inference/serialization.md
index 0dfc62d3..133ff025 100644
--- a/candle-book/src/inference/serialization.md
+++ b/candle-book/src/inference/serialization.md
@@ -1 +1,3 @@
 # Serialization
+
+Once you have a r
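The `safetensors` format this patch relies on has a deliberately simple container layout: an 8-byte little-endian `u64` giving the length of a JSON header, then the header itself (dtype, shape, and byte offsets per tensor), then the raw tensor bytes. The sketch below illustrates just that framing with stdlib-only Rust; the `encode`/`decode` helpers are hypothetical and are not the API of the `safetensors` or `candle` crates.

```rust
/// Hypothetical helper: frame a JSON header and raw tensor bytes the way a
/// safetensors file does (8-byte LE length prefix + header + data).
fn encode(header_json: &str, data: &[u8]) -> Vec<u8> {
    let mut out = Vec::new();
    // Length of the JSON header, as a little-endian u64.
    out.extend_from_slice(&(header_json.len() as u64).to_le_bytes());
    out.extend_from_slice(header_json.as_bytes());
    out.extend_from_slice(data);
    out
}

/// Hypothetical helper: split a framed buffer back into (header_json, data).
fn decode(buf: &[u8]) -> (String, &[u8]) {
    let mut len_bytes = [0u8; 8];
    len_bytes.copy_from_slice(&buf[..8]);
    let header_len = u64::from_le_bytes(len_bytes) as usize;
    let header = String::from_utf8(buf[8..8 + header_len].to_vec()).unwrap();
    (header, &buf[8 + header_len..])
}

fn main() {
    // A minimal header describing one 2x2 F32 tensor occupying 16 bytes.
    let header = r#"{"weight":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;
    let data = [0u8; 16];

    let buf = encode(header, &data);
    let (parsed_header, parsed_data) = decode(&buf);
    assert_eq!(parsed_header, header);
    assert_eq!(parsed_data.len(), 16);
    println!("header: {} bytes, data: {} bytes", parsed_header.len(), parsed_data.len());
}
```

Because the header sits at a fixed, length-prefixed position, a loader like `candle::safetensors::load` can read tensor metadata without touching the (potentially very large) data section first.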