summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/phi/main.rs17
-rw-r--r--candle-transformers/src/models/mixformer.rs40
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/stable_diffusion/resnet.rs2
-rw-r--r--candle-transformers/src/models/stable_diffusion/unet_2d.rs2
-rw-r--r--candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs2
-rw-r--r--candle-transformers/src/models/stable_diffusion/utils.rs27
-rw-r--r--candle-transformers/src/models/t5.rs71
-rw-r--r--candle-transformers/src/models/with_tracing.rs78
9 files changed, 140 insertions, 100 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 4b290cd8..25c7db98 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -70,7 +70,7 @@ impl TextGeneration {
}
let dt = start_gen.elapsed();
println!(
- "{sample_len} tokens generated ({:.3} token/s)",
+ "\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),
);
Ok(())
@@ -84,6 +84,10 @@ struct Args {
#[arg(long)]
cpu: bool,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
#[arg(long)]
prompt: String,
@@ -114,8 +118,19 @@ struct Args {
}
fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 028c3567..61eaea54 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -1,3 +1,4 @@
+use crate::models::with_tracing::{linear, Embedding as E, Linear};
/// MixFormer model.
/// https://huggingface.co/microsoft/phi-1_5
/// https://arxiv.org/abs/2309.05463
@@ -58,12 +59,12 @@ impl Config {
#[derive(Debug)]
struct Embedding {
- wte: candle_nn::Embedding,
+ wte: E,
}
impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+ let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
Ok(Self { wte })
}
}
@@ -143,16 +144,16 @@ impl RotaryEmbedding {
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
- fc1: candle_nn::Linear,
- fc2: candle_nn::Linear,
+ fc1: Linear,
+ fc2: Linear,
act: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
- let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
- let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+ let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+ let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
@@ -170,13 +171,13 @@ impl Module for MLP {
#[derive(Debug)]
struct CausalLMHead {
ln: candle_nn::LayerNorm,
- linear: candle_nn::Linear,
+ linear: Linear,
}
impl CausalLMHead {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
- let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+ let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
Ok(Self { ln, linear })
}
}
@@ -192,20 +193,21 @@ impl Module for CausalLMHead {
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MHA {
- wqkv: candle_nn::Linear,
- out_proj: candle_nn::Linear,
+ wqkv: Linear,
+ out_proj: Linear,
rotary_emb: RotaryEmbedding,
kv_cache: Option<(Tensor, Tensor)>,
head_dim: usize,
softmax_scale: f64,
+ span: tracing::Span,
}
impl MHA {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let head_dim = cfg.n_embd / cfg.n_head;
let op_size = cfg.n_embd;
- let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
- let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+ let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+ let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
@@ -215,10 +217,12 @@ impl MHA {
kv_cache: None,
rotary_emb,
softmax_scale,
+ span: tracing::span!(tracing::Level::TRACE, "mha"),
})
}
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let qkv = self
.wqkv
@@ -267,6 +271,7 @@ struct ParallelBlock {
ln: candle_nn::LayerNorm,
mixer: MHA,
mlp: MLP,
+ span: tracing::Span,
}
impl ParallelBlock {
@@ -274,10 +279,16 @@ impl ParallelBlock {
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
let mixer = MHA::new(cfg, vb.pp("mixer"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
- Ok(Self { ln, mixer, mlp })
+ Ok(Self {
+ ln,
+ mixer,
+ mlp,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
+ })
}
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let residual = xs;
let xs = xs.apply(&self.ln)?;
let attn_outputs = self.mixer.forward(&xs)?;
@@ -291,6 +302,7 @@ pub struct MixFormerSequentialForCausalLM {
embedding: Embedding,
blocks: Vec<ParallelBlock>,
head: CausalLMHead,
+ span: tracing::Span,
}
impl MixFormerSequentialForCausalLM {
@@ -307,10 +319,12 @@ impl MixFormerSequentialForCausalLM {
embedding,
blocks,
head,
+ span: tracing::span!(tracing::Level::TRACE, "mixformer"),
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (_b_size, seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embedding)?;
for block in self.blocks.iter_mut() {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 991ee201..0fbcaa07 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -11,4 +11,5 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod t5;
pub mod whisper;
+pub mod with_tracing;
pub mod wuerstchen;
diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs
index 0d818115..5df04a8b 100644
--- a/candle-transformers/src/models/stable_diffusion/resnet.rs
+++ b/candle-transformers/src/models/stable_diffusion/resnet.rs
@@ -4,7 +4,7 @@
//!
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
//! https://arxiv.org/abs/1512.03385
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
index a3ed136e..f23bd425 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -4,7 +4,7 @@
//! timestep and return a denoised version of the input.
use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
index 29510cef..18448427 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -4,7 +4,7 @@ use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index c62f17af..0c95cfef 100644
--- a/candle-transformers/src/models/stable_diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -1,5 +1,4 @@
use candle::{Device, Result, Tensor};
-use candle_nn::Module;
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
@@ -11,29 +10,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
-
-// Wrap the conv2d op to provide some tracing.
-#[derive(Debug)]
-pub struct Conv2d {
- inner: candle_nn::Conv2d,
- span: tracing::Span,
-}
-
-impl Conv2d {
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-pub fn conv2d(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- cfg: candle_nn::Conv2dConfig,
- vs: candle_nn::VarBuilder,
-) -> Result<Conv2d> {
- let span = tracing::span!(tracing::Level::TRACE, "conv2d");
- let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
- Ok(Conv2d { inner, span })
-}
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 539ae89b..c5d5724a 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,57 +1,12 @@
// T5 Text Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
-#[derive(Debug)]
-struct Embedding {
- inner: candle_nn::Embedding,
- span: tracing::Span,
-}
-
-impl Embedding {
- fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
- let inner = candle_nn::embedding(d1, d2, vb)?;
- let span = tracing::span!(tracing::Level::TRACE, "embedding");
- Ok(Self { inner, span })
- }
-
- fn embeddings(&self) -> &Tensor {
- self.inner.embeddings()
- }
-}
-
-impl Module for Embedding {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(xs)
- }
-}
-
-#[derive(Debug)]
-struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Linear {
- fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
- let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Ok(Self { inner, span })
- }
-}
-
-impl Module for Linear {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(xs)
- }
-}
-
fn default_relative_attention_max_distance() -> usize {
128
}
@@ -205,8 +160,8 @@ struct T5DenseActDense {
impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
- let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi,
wo,
@@ -237,9 +192,9 @@ struct T5DenseGatedActDense {
impl T5DenseGatedActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
- let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
- let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi_0,
wi_1,
@@ -334,10 +289,10 @@ impl T5Attention {
cfg: &Config,
) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
- let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
- let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
- let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
- let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+ let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if has_relative_attention_bias {
let emb = Embedding::new(
cfg.relative_attention_num_buckets,
@@ -772,7 +727,11 @@ impl T5ForConditionalGeneration {
let lm_head = if tie_word_embeddings {
None
} else {
- Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+ Some(linear_no_bias(
+ cfg.d_model,
+ cfg.vocab_size,
+ vb.pp("lm_head"),
+ )?)
};
Ok(Self {
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
new file mode 100644
index 00000000..0a2d65b9
--- /dev/null
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -0,0 +1,78 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::embedding(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ pub fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let inner = candle_nn::linear(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+ inner: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl Conv2d {
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: candle_nn::Conv2dConfig,
+ vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+ let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+ Ok(Conv2d { inner, span })
+}