summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/model.rs')
-rw-r--r--candle-examples/examples/whisper/model.rs98
1 files changed, 58 insertions, 40 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 4d80c0c8..7015199d 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,8 +1,5 @@
-// We use anyhow rather than candle errors as it provides better support for getting the backtrace
-// back when using RUST_LIB_BACKTRACE=1.
-use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle::{Device, Result, Tensor};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -22,6 +19,7 @@ pub struct Config {
}
impl Config {
+ #[allow(dead_code)]
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
@@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
+//
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = vb.get(size2, "bias")?;
- Ok(Linear::new(weight, Some(bias)))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn conv1d(
@@ -66,32 +80,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
-fn conv1d_no_bias(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- config: Conv1dConfig,
- vb: VarBuilder,
-) -> Result<Conv1d> {
- let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
- Ok(Conv1d::new(weight, None, config))
-}
-
-struct Dropout {
- pr: f64,
-}
-
-impl Dropout {
- fn new(pr: f64) -> Self {
- Self { pr }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- // TODO
- Ok(x.clone())
- }
-}
-
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
@@ -105,10 +93,12 @@ struct MultiHeadAttention {
value: Linear,
out: Linear,
n_head: usize,
+ span: tracing::Span,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -119,10 +109,12 @@ impl MultiHeadAttention {
value,
out,
n_head,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
@@ -134,7 +126,7 @@ impl MultiHeadAttention {
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
- Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+ x.reshape(target_dims)?.transpose(1, 2)
}
fn qkv_attention(
@@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
+ span: tracing::Span,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
@@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
mlp_linear1,
mlp_linear2,
mlp_ln,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
@@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
- Ok((x + mlp)?)
+ x + mlp
}
}
@@ -234,10 +230,16 @@ pub struct AudioEncoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
+ span: tracing::Span,
+ conv1_span: tracing::Span,
+ conv2_span: tracing::Span,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+ let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+ let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@@ -264,11 +266,22 @@ impl AudioEncoder {
positional_embedding,
blocks,
ln_post,
+ conv1_span,
+ conv2_span,
+ span,
})
}
+
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = self.conv1.forward(x)?.gelu()?;
- let x = self.conv2.forward(&x)?.gelu()?;
+ let _enter = self.span.enter();
+ let x = {
+ let _enter = self.conv1_span.enter();
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _enter = self.conv2_span.enter();
+ self.conv2.forward(&x)?.gelu()?
+ };
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
@@ -288,10 +301,12 @@ pub struct TextDecoder {
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
+ span: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@@ -314,10 +329,12 @@ impl TextDecoder {
blocks,
ln,
mask,
+ span,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
@@ -354,6 +371,7 @@ impl Whisper {
})
}
+ #[allow(dead_code)]
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;