author    Laurent Mazare <laurent.mazare@gmail.com>    2023-07-10 08:50:09 +0100
committer GitHub <noreply@github.com>    2023-07-10 08:50:09 +0100
commit    9ce0f1c01097b5c4dd199d236c2157f2e0ca8682 (patch)
tree      70c8f42941c3a3f041319314b6ea0a34593982be /candle-examples/examples/falcon
parent    bc3be6f9b07442abd0ddeab4979e5cc5fedcee78 (diff)
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
Diffstat (limited to 'candle-examples/examples/falcon')
-rw-r--r--  candle-examples/examples/falcon/model.rs | 113
1 file changed, 34 insertions(+), 79 deletions(-)
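For context (not part of this commit): a minimal usage sketch of the candle_nn::Linear and candle_nn::LayerNorm layers that the diff below switches the falcon example to. The constructor signatures (Linear::new, LayerNorm::new) match the calls in the diff; the forward calls and the Tensor::randn / Tensor::ones constructors are assumptions about the usual candle API (in later candle-nn versions, forward comes from the candle_nn::Module trait).

// Minimal sketch of the candle-nn layers adopted below; constructor
// signatures come from the diff, the rest assumes the usual candle API.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{LayerNorm, Linear};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Linear layer: weight is (out_features, in_features), bias is optional.
    let weight = Tensor::randn(0f32, 1.0, (2, 4), &dev)?;
    let linear = Linear::new(weight, None);

    // Layer norm over a hidden size of 4, with weight, bias and an epsilon.
    let ln_weight = Tensor::ones(4, DType::F32, &dev)?;
    let ln_bias = Tensor::zeros(4, DType::F32, &dev)?;
    let ln = LayerNorm::new(ln_weight, ln_bias, 1e-5);

    // Apply to a (batch, seq, hidden) activation, as the falcon model does.
    let xs = Tensor::randn(0f32, 1.0, (1, 3, 4), &dev)?;
    let xs = ln.forward(&xs)?;
    let ys = linear.forward(&xs)?;
    println!("{:?}", ys.shape()); // (1, 3, 2)
    Ok(())
}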
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index e7c53e50..e22b7b47 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -1,5 +1,6 @@
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
+use candle_nn::{LayerNorm, Linear};
use std::collections::HashMap;
const MAX_SEQ_LEN: usize = 5000;
@@ -61,80 +62,34 @@ impl<'a> VarBuilder<'a> {
}
}
-#[derive(Debug)]
-struct Linear {
- weight: Tensor,
- bias: Option<Tensor>,
-}
-
-impl Linear {
- fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
- let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
- let bias = if bias {
- Some(vb.get(size2, &format!("{p}.bias"))?)
- } else {
- None
- };
- Ok(Self { weight, bias })
- }
-
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let (bsize, _, _) = x.shape().r3()?;
- let w = self.weight.broadcast_left(bsize)?.t()?;
- let x = x.matmul(&w)?;
- match &self.bias {
- None => Ok(x),
- Some(bias) => x.broadcast_add(bias),
- }
- }
+fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+ let bias = if bias {
+ Some(vb.get(size2, &format!("{p}.bias"))?)
+ } else {
+ None
+ };
+ Ok(Linear::new(weight, bias))
}
-#[derive(Debug)]
-struct LayerNorm {
- weight: Tensor,
- bias: Tensor,
- eps: f64,
-}
-
-impl LayerNorm {
- fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
- Self { weight, bias, eps }
- }
-
- fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
- let (weight, bias) = match (
- vb.get(size, &format!("{p}.weight")),
- vb.get(size, &format!("{p}.bias")),
- ) {
- (Ok(weight), Ok(bias)) => (weight, bias),
- (Err(err), _) | (_, Err(err)) => {
- if let (Ok(weight), Ok(bias)) = (
- vb.get(size, &format!("{p}.gamma")),
- vb.get(size, &format!("{p}.beta")),
- ) {
- (weight, bias)
- } else {
- return Err(err.into());
- }
+fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
+ let (weight, bias) = match (
+ vb.get(size, &format!("{p}.weight")),
+ vb.get(size, &format!("{p}.bias")),
+ ) {
+ (Ok(weight), Ok(bias)) => (weight, bias),
+ (Err(err), _) | (_, Err(err)) => {
+ if let (Ok(weight), Ok(bias)) = (
+ vb.get(size, &format!("{p}.gamma")),
+ vb.get(size, &format!("{p}.beta")),
+ ) {
+ (weight, bias)
+ } else {
+ return Err(err.into());
}
- };
- Ok(Self { weight, bias, eps })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let dtype = x.dtype();
- let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
- let x = x.to_dtype(DType::F32)?;
- let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
- let x = x.broadcast_sub(&mean_x)?;
- let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
- let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
- let x = x_normed
- .to_dtype(dtype)?
- .broadcast_mul(&self.weight)?
- .broadcast_add(&self.bias)?;
- Ok(x)
- }
+ }
+ };
+ Ok(LayerNorm::new(weight, bias, eps))
}
#[derive(Debug)]
@@ -378,14 +333,14 @@ impl FalconAttention {
} else {
3 * hidden_size
};
- let query_key_value = Linear::load(
+ let query_key_value = linear(
hidden_size,
qkv_out_dim,
cfg.bias,
&format!("{p}.query_key_value"),
vb,
)?;
- let dense = Linear::load(
+ let dense = linear(
hidden_size,
hidden_size,
cfg.bias,
@@ -497,8 +452,8 @@ impl FalconMlp {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let b = cfg.bias;
- let dense_h_to_4h = Linear::load(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
- let dense_4h_to_h = Linear::load(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
+ let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
+ let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
let dropout = Dropout::new(cfg.hidden_dropout);
Ok(Self {
dense_h_to_4h,
@@ -526,7 +481,7 @@ struct FalconDecoderLayer {
impl FalconDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
- let inp_layernorm = LayerNorm::load(
+ let inp_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
&format!("{p}.input_layernorm"),
@@ -536,7 +491,7 @@ impl FalconDecoderLayer {
let post_attention_layernorm = if cfg.parallel_attn {
None
} else {
- let ln = LayerNorm::load(
+ let ln = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
&format!("{p}.post_attention_layernorm"),
@@ -617,13 +572,13 @@ impl Falcon {
let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
.collect::<Result<Vec<_>>>()?;
- let ln_f = LayerNorm::load(
+ let ln_f = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
"transformer.ln_f",
vb,
)?;
- let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
+ let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
Ok(Self {
word_embeddings,
blocks,