author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-10 08:50:09 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2023-07-10 08:50:09 +0100
commit | 9ce0f1c01097b5c4dd199d236c2157f2e0ca8682 |
tree | 70c8f42941c3a3f041319314b6ea0a34593982be | /candle-examples/examples/falcon
parent | bc3be6f9b07442abd0ddeab4979e5cc5fedcee78 |
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
Diffstat (limited to 'candle-examples/examples/falcon')
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 113 |
1 file changed, 34 insertions, 79 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index e7c53e50..e22b7b47 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
+use candle_nn::{LayerNorm, Linear};
 use std::collections::HashMap;
 
 const MAX_SEQ_LEN: usize = 5000;
@@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
     }
 }
 
-#[derive(Debug)]
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
-        let bias = if bias {
-            Some(vb.get(size2, &format!("{p}.bias"))?)
-        } else {
-            None
-        };
-        Ok(Self { weight, bias })
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let (bsize, _, _) = x.shape().r3()?;
-        let w = self.weight.broadcast_left(bsize)?.t()?;
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
+fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+    let bias = if bias {
+        Some(vb.get(size2, &format!("{p}.bias"))?)
+    } else {
+        None
+    };
+    Ok(Linear::new(weight, bias))
 }
 
-#[derive(Debug)]
-struct LayerNorm {
-    weight: Tensor,
-    bias: Tensor,
-    eps: f64,
-}
-
-impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        Self { weight, bias, eps }
-    }
-
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let (weight, bias) = match (
-            vb.get(size, &format!("{p}.weight")),
-            vb.get(size, &format!("{p}.bias")),
-        ) {
-            (Ok(weight), Ok(bias)) => (weight, bias),
-            (Err(err), _) | (_, Err(err)) => {
-                if let (Ok(weight), Ok(bias)) = (
-                    vb.get(size, &format!("{p}.gamma")),
-                    vb.get(size, &format!("{p}.beta")),
-                ) {
-                    (weight, bias)
-                } else {
-                    return Err(err.into());
-                }
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (
+        vb.get(size, &format!("{p}.weight")),
+        vb.get(size, &format!("{p}.bias")),
+    ) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (
+                vb.get(size, &format!("{p}.gamma")),
+                vb.get(size, &format!("{p}.beta")),
+            ) {
+                (weight, bias)
+            } else {
+                return Err(err.into());
             }
-        };
-        Ok(Self { weight, bias, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let dtype = x.dtype();
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
-        let x = x.to_dtype(DType::F32)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
-        let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed
-            .to_dtype(dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?;
-        Ok(x)
-    }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
 }
 
 #[derive(Debug)]
@@ -378,14 +333,14 @@ impl FalconAttention {
         } else {
             3 * hidden_size
         };
-        let query_key_value = Linear::load(
+        let query_key_value = linear(
             hidden_size,
             qkv_out_dim,
             cfg.bias,
             &format!("{p}.query_key_value"),
             vb,
         )?;
-        let dense = Linear::load(
+        let dense = linear(
             hidden_size,
             hidden_size,
             cfg.bias,
@@ -497,8 +452,8 @@ impl FalconMlp {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let h = cfg.hidden_size;
         let b = cfg.bias;
-        let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
-        let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+        let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
+        let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
         let dropout = Dropout::new(cfg.hidden_dropout);
         Ok(Self {
             dense_h_to_4h,
@@ -526,7 +481,7 @@ struct FalconDecoderLayer {
 impl FalconDecoderLayer {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
-        let inp_layernorm = LayerNorm::load(
+        let inp_layernorm = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             &format!("{p}.input_layernorm"),
@@ -536,7 +491,7 @@ impl FalconDecoderLayer {
         let post_attention_layernorm = if cfg.parallel_attn {
             None
         } else {
-            let ln = LayerNorm::load(
+            let ln = layer_norm(
                 cfg.hidden_size,
                 cfg.layer_norm_epsilon,
                 &format!("{p}.post_attention_layernorm"),
@@ -617,13 +572,13 @@ impl Falcon {
         let blocks = (0..cfg.num_hidden_layers)
            .map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
            .collect::<Result<Vec<_>>>()?;
-        let ln_f = LayerNorm::load(
+        let ln_f = layer_norm(
             cfg.hidden_size,
             cfg.layer_norm_epsilon,
             "transformer.ln_f",
             vb,
         )?;
-        let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
         Ok(Self {
             word_embeddings,
             blocks,
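For context on what the new `linear` and `layer_norm` helpers return: after this change the falcon example only loads tensors and hands them to `candle_nn::Linear` and `candle_nn::LayerNorm` instead of keeping its own module structs. The sketch below is not part of the commit; it shows how those two modules are constructed and applied. The shapes are illustrative, and the direct `forward` calls assume the API as used around this commit (later candle-nn versions expose `forward` through the `Module` trait).

```rust
use candle::{DType, Device, Tensor};
use candle_nn::{LayerNorm, Linear};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;

    // `Linear::new` takes an (out_features, in_features) weight and an optional
    // bias, mirroring what the `linear` helper loads via `vb.get((size2, size1), ..)`.
    let weight = Tensor::zeros((4, 8), DType::F32, &dev)?;
    let bias = Some(Tensor::zeros(4, DType::F32, &dev)?);
    let lin = Linear::new(weight, bias);

    // `LayerNorm::new` takes a per-channel weight, bias, and epsilon, matching
    // the `layer_norm` helper (which also falls back to the gamma/beta names).
    let ln_w = Tensor::ones(8, DType::F32, &dev)?;
    let ln_b = Tensor::zeros(8, DType::F32, &dev)?;
    let ln = LayerNorm::new(ln_w, ln_b, 1e-5);

    // Apply to a (batch, seq, hidden) activation, as the Falcon blocks do.
    let xs = Tensor::zeros((1, 3, 8), DType::F32, &dev)?;
    let ys = lin.forward(&ln.forward(&xs)?)?;
    println!("{:?}", ys.shape()); // expected shape: [1, 3, 4]
    Ok(())
}
```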