Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r--   candle-examples/examples/falcon/model.rs   76
1 file changed, 27 insertions(+), 49 deletions(-)
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 1300e7cb..631ff280 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -4,27 +4,21 @@ use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
-fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
     let bias = if bias {
-        Some(vb.get(size2, &format!("{p}.bias"))?)
+        Some(vb.get(size2, "bias")?)
     } else {
         None
     };
     Ok(Linear::new(weight, bias))
 }
 
-fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (
-        vb.get(size, &format!("{p}.weight")),
-        vb.get(size, &format!("{p}.bias")),
-    ) {
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
         (Ok(weight), Ok(bias)) => (weight, bias),
         (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (
-                vb.get(size, &format!("{p}.gamma")),
-                vb.get(size, &format!("{p}.beta")),
-            ) {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                 (weight, bias)
             } else {
                 return Err(err.into());
@@ -50,8 +44,8 @@ impl Dropout {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
@@ -164,14 +158,14 @@ struct FalconRotaryEmbedding {
 }
 
 impl FalconRotaryEmbedding {
-    fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(device: &Device, cfg: &Config) -> Result<Self> {
         let head_dim = cfg.head_dim();
         let inv_freq: Vec<_> = (0..head_dim)
             .step_by(2)
             .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
             .collect();
         Ok(Self {
-            inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
+            inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
             cache: None,
         })
     }
@@ -237,9 +231,9 @@ struct FalconAttention {
 }
 
 impl FalconAttention {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let maybe_rotary = if cfg.rotary() {
-            let rotary = FalconRotaryEmbedding::load(vb, cfg)?;
+            let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
             Some(rotary)
         } else {
             None
@@ -251,20 +245,8 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = linear(
-            hidden_size,
-            qkv_out_dim,
-            cfg.bias,
-            &format!("{p}.query_key_value"),
-            vb,
-        )?;
-        let dense = linear(
-            hidden_size,
-            hidden_size,
-            cfg.bias,
-            &format!("{p}.dense"),
-            vb,
-        )?;
+        let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
+        let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
         Ok(Self {
             query_key_value,
             dense,
@@ -367,11 +349,11 @@ struct FalconMlp {
 }
 
 impl FalconMlp {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
+        let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
@@ -397,23 +379,21 @@ struct FalconDecoderLayer {
 }
 
 impl FalconDecoderLayer {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
         let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
-            &format!("{p}.input_layernorm"),
-            vb,
+            vb.pp("input_layernorm"),
         )?;
-        let self_attention = FalconAttention::load(&format!("{p}.self_attention"), vb, cfg)?;
+        let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
             let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
-                &format!("{p}.post_attention_layernorm"),
-                vb,
+                vb.pp("post_attention_layernorm"),
             )?;
             Some(ln)
         };
@@ -480,23 +460,21 @@ impl Falcon {
         &self.config
     }
 
-    pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
         let word_embeddings = embedding(
             cfg.vocab_size,
             cfg.hidden_size,
-            "transformer.word_embeddings",
-            vb,
+            vb.pp("transformer.word_embeddings"),
         )?;
         let blocks = (0..cfg.num_hidden_layers)
-            .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
+            .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
            .collect::<Result<Vec<_>>>()?;
         let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
-            "transformer.ln_f",
-            vb,
+            vb.pp("transformer.ln_f"),
         )?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
         Ok(Self {
             word_embeddings,
             blocks,
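For context, a minimal sketch (not part of the commit) of the calling convention this refactor adopts: VarBuilder::pp ("push prefix") scopes the builder to a sub-path, so each loader looks tensors up by a relative name such as "weight" instead of formatting a full path from a threaded p: &str prefix. The main function, the zero-initialized builder, the toy dimensions, and the example path are illustrative assumptions rather than code from this file; crate paths assume the published candle-core/candle-nn crates.

// Sketch only: illustrates the VarBuilder::pp scoping used throughout the diff above.
use candle_core::{DType, Device, Result};
use candle_nn::{Linear, VarBuilder};

// Mirrors the refactored `linear` above (bias handling elided for brevity):
// the name is relative, and the full path is assembled by the chain of `pp`
// calls made by the callers, not by `format!` at each call site.
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    // Resolves to e.g. "transformer.h.0.mlp.dense_h_to_4h.weight".
    let weight = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(weight, None))
}

fn main() -> Result<()> {
    // A zero-initialized builder keeps the sketch runnable without checkpoint files.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    // Old style: linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), &vb)?
    // New style: scope the builder once per sub-module, then pass it down.
    let _mlp_up = linear(8, 32, vb.pp("transformer.h.0").pp("mlp").pp("dense_h_to_4h"))?;
    Ok(())
}

Since pp returns a fresh builder scoped one level deeper, each load function only needs to know its own sub-module names, which is what lets the diff delete every &format!("{p}....") call site.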