summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-10 22:37:34 +0100
committerGitHub <noreply@github.com>2023-07-10 22:37:34 +0100
commitb46c28a2ac2a88387590c65a2efef028f010b29e (patch)
treec43a08c46a537c6041f55fbe6cbbb105299aee9e /candle-examples/examples/bert
parent1aa7fbbc33ecd0ad3ce8698220d9eae434db50ba (diff)
downloadcandle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.gz
candle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.bz2
candle-b46c28a2ac2a88387590c65a2efef028f010b29e.zip
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing. * Group the path and the var-builder together. * Fix for the empty path case.
Diffstat (limited to 'candle-examples/examples/bert')
-rw-r--r--candle-examples/examples/bert/main.rs114
1 files changed, 42 insertions, 72 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 3871c752..d0d600ee 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -109,14 +109,14 @@ impl Config {
}
}
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
-fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
- let bias = vb.get(size2, &format!("{p}.bias"))?;
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
@@ -135,17 +135,11 @@ impl Dropout {
}
}
-fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
- let (weight, bias) = match (
- vb.get(size, &format!("{p}.weight")),
- vb.get(size, &format!("{p}.bias")),
- ) {
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+ let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
- if let (Ok(weight), Ok(bias)) = (
- vb.get(size, &format!("{p}.gamma")),
- vb.get(size, &format!("{p}.beta")),
- ) {
+ if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err.into());
@@ -167,33 +161,29 @@ struct BertEmbeddings {
}
impl BertEmbeddings {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
- &format!("{p}.word_embeddings"),
- vb,
+ vb.pp("word_embeddings"),
)?;
let position_embeddings = embedding(
config.max_position_embeddings,
config.hidden_size,
- &format!("{p}.position_embeddings"),
- vb,
+ vb.pp("position_embeddings"),
)?;
let token_type_embeddings = embedding(
config.type_vocab_size,
config.hidden_size,
- &format!("{p}.token_type_embeddings"),
- vb,
+ vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
- &format!("{p}.LayerNorm"),
- vb,
+ vb.pp("LayerNorm"),
)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
- let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
+ let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?;
Ok(Self {
word_embeddings,
@@ -233,14 +223,14 @@ struct BertSelfAttention {
}
impl BertSelfAttention {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob);
let hidden_size = config.hidden_size;
- let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
- let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
- let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
+ let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+ let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+ let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
Ok(Self {
query,
key,
@@ -289,18 +279,12 @@ struct BertSelfOutput {
}
impl BertSelfOutput {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(
- config.hidden_size,
- config.hidden_size,
- &format!("{p}.dense"),
- vb,
- )?;
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
- &format!("{p}.LayerNorm"),
- vb,
+ vb.pp("LayerNorm"),
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
@@ -324,9 +308,9 @@ struct BertAttention {
}
impl BertAttention {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
- let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
- let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+ let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
Ok(Self {
self_attention,
self_output,
@@ -347,13 +331,8 @@ struct BertIntermediate {
}
impl BertIntermediate {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(
- config.hidden_size,
- config.intermediate_size,
- &format!("{p}.dense"),
- vb,
- )?;
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
Ok(Self {
dense,
intermediate_act: config.hidden_act,
@@ -375,18 +354,12 @@ struct BertOutput {
}
impl BertOutput {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(
- config.intermediate_size,
- config.hidden_size,
- &format!("{p}.dense"),
- vb,
- )?;
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
- &format!("{p}.LayerNorm"),
- vb,
+ vb.pp("LayerNorm"),
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
@@ -411,10 +384,10 @@ struct BertLayer {
}
impl BertLayer {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
- let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
- let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
- let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention = BertAttention::load(vb.pp("attention"), config)?;
+ let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+ let output = BertOutput::load(vb.pp("output"), config)?;
Ok(Self {
attention,
intermediate,
@@ -441,12 +414,9 @@ struct BertEncoder {
}
impl BertEncoder {
- fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
- .map(|index| {
- let p = format!("{p}.layer.{index}");
- BertLayer::load(&p, vb, config)
- })
+ .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
Ok(BertEncoder { layers })
}
@@ -469,17 +439,17 @@ struct BertModel {
}
impl BertModel {
- fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let (embeddings, encoder) = match (
- BertEmbeddings::load("embeddings", vb, config),
- BertEncoder::load("encoder", vb, config),
+ BertEmbeddings::load(vb.pp("embeddings"), config),
+ BertEncoder::load(vb.pp("encoder"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
- BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
+ BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+ BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
) {
(embeddings, encoder)
} else {
@@ -493,7 +463,7 @@ impl BertModel {
Ok(Self {
embeddings,
encoder,
- device: vb.device.clone(),
+ device: vb.device().clone(),
})
}
@@ -576,7 +546,7 @@ impl Args {
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
- let model = BertModel::load(&vb, &config)?;
+ let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}