# Using the hub

Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:

```bash
cargo add hf-hub
```

Then download the [model file](https://huggingface.co/bert-base-uncased/tree/main) from the hub:


```rust
# extern crate candle;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle::Device;

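// Create a hub client and select the model repository.
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

// Download the file (or reuse the locally cached copy) and get its path.
let weights = repo.get("model.safetensors").unwrap();

// Load every tensor in the file onto the CPU, keyed by tensor name.
let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();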
```

We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
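
For intuition about what such a file contains, here is a minimal sketch that extracts the JSON header from a safetensors buffer using only the standard library. It assumes the published safetensors layout (an 8-byte little-endian header length, then a JSON header describing each tensor, then the raw data). The function name `safetensors_header` is hypothetical and is not part of `candle` or `hf-hub`; in practice `candle::safetensors::load` does this work for you.

```rust
use std::convert::{TryFrom, TryInto};

/// Returns the JSON header of a safetensors buffer, or `None` if the
/// buffer is too short or the header is not valid UTF-8.
/// Hypothetical helper for illustration only.
fn safetensors_header(bytes: &[u8]) -> Option<&str> {
    // The first 8 bytes hold the header length as a little-endian u64.
    let len_bytes: [u8; 8] = bytes.get(..8)?.try_into().ok()?;
    let header_len = usize::try_from(u64::from_le_bytes(len_bytes)).ok()?;

    // The header itself follows immediately after the length prefix.
    let end = 8usize.checked_add(header_len)?;
    let header = bytes.get(8..end)?;
    std::str::from_utf8(header).ok()
}
```

The returned string is JSON that maps each tensor name to its dtype, shape, and byte offsets, so any JSON parser can list the tensor names without loading the data.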


## Using async

`hf-hub` also provides an async API, enabled through the `tokio` feature:

```bash
cargo add hf-hub --features tokio
```

```rust,ignore
# extern crate candle;
# extern crate hf_hub;
use hf_hub::api::tokio::Api;
use candle::Device;

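// This snippet must run inside an async context (for example a tokio runtime).
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

// Await the download; the result is the path to the local file.
let weights = repo.get("model.safetensors").await.unwrap();

let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();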
```


## Using in a real model

Now that we have our weights, we can use them in our BERT architecture:

```rust
# extern crate candle;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
# 
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
# 
# let weights = repo.get("model.safetensors").unwrap();
use candle::{DType, Device, Tensor};
use candle_nn::{Linear, Module};

// Load the tensors and unwrap the result so we can look them up by name.
let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();

let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

// `Linear::new` takes owned tensors, so clone the borrowed ones.
let linear = Linear::new(weight.clone(), Some(bias.clone()));

// The last dimension must match the layer's input size (768 for bert-base).
let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
```
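
To make the shape requirement concrete, here is a plain-Rust sketch of what a linear layer computes, `y = x * W^T + b`. It assumes the weight is stored as `(out_features, in_features)`, the convention used by PyTorch-style checkpoints. The function `linear_forward` is a hypothetical helper written for illustration and is not how `candle_nn::Linear` is implemented.

```rust
/// Computes `y = x * W^T + b` on nested vectors.
/// `x` is (batch, in_features), `w` is (out_features, in_features),
/// `b` is (out_features). Hypothetical helper for illustration only.
fn linear_forward(
    x: &[Vec<f32>],
    w: &[Vec<f32>],
    b: &[f32],
) -> Result<Vec<Vec<f32>>, String> {
    if w.len() != b.len() {
        return Err(format!(
            "bias has {} entries but weight has {} rows",
            b.len(),
            w.len()
        ));
    }
    let mut out = Vec::with_capacity(x.len());
    for row in x {
        let mut y = Vec::with_capacity(w.len());
        for (w_row, bias) in w.iter().zip(b) {
            // Each weight row must be as wide as the input row.
            if w_row.len() != row.len() {
                return Err(format!(
                    "shape mismatch: input has {} features, weight expects {}",
                    row.len(),
                    w_row.len()
                ));
            }
            // Dot product of the input row with one weight row, plus bias.
            let dot: f32 = row.iter().zip(w_row).map(|(a, c)| a * c).sum();
            y.push(dot + bias);
        }
        out.push(y);
    }
    Ok(out)
}
```

Assuming the query projection of `bert-base-uncased` has a `(768, 768)` weight, an input of shape `(3, 768)` produces an output of shape `(3, 768)`, while any other input width is rejected as a shape mismatch.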

For a complete implementation, see the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.