diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-16 18:49:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-16 18:49:08 +0100 |
commit | c5f45887dc32bc6575c9d55135def391b949ce98 (patch) | |
tree | b738a79158c49e7968275d7a9a04423bd47681ee | |
parent | fa4590d7fd2b21ee811dba735851b6ec487f3cee (diff) | |
download | candle-c5f45887dc32bc6575c9d55135def391b949ce98.tar.gz candle-c5f45887dc32bc6575c9d55135def391b949ce98.tar.bz2 candle-c5f45887dc32bc6575c9d55135def391b949ce98.zip |
Add some tracing to the quantized example. (#473)
-rw-r--r-- | candle-examples/examples/ggml/main.rs | 66 | ||||
-rw-r--r-- | candle-examples/examples/llama/main.rs | 1 |
2 files changed, 63 insertions, 4 deletions
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs index 7d6ec2ca..54bc5f57 100644 --- a/candle-examples/examples/ggml/main.rs +++ b/candle-examples/examples/ggml/main.rs @@ -5,7 +5,7 @@ use std::io::Write; use tokenizers::Tokenizer; use candle::quantized::ggml_file::Content; -use candle::quantized::{QMatMul, QTensor}; +use candle::quantized::QTensor; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::Embedding; use candle_transformers::generation::LogitsProcessor; @@ -16,15 +16,22 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is "; struct RmsNorm { scale: Tensor, eps: f64, + span: tracing::Span, } impl RmsNorm { fn new(scale: QTensor) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); let scale = scale.dequantize(&Device::Cpu)?; - Ok(Self { scale, eps: 1e-5 }) + Ok(Self { + scale, + eps: 1e-5, + span, + }) } fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); let (b_sz, seq_len, hidden_size) = x.dims3()?; let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; @@ -39,6 +46,25 @@ impl RmsNorm { } } +// QMatMul wrapper adding some tracing. +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Self { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Self { inner, span } + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + struct LayerWeights { attention_wq: QMatMul, attention_wk: QMatMul, @@ -54,6 +80,9 @@ struct LayerWeights { cos: Tensor, sin: Tensor, kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, } fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { @@ -65,6 +94,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> impl LayerWeights { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); let (b_sz, _, seq_len, n_embd) = x.dims4()?; let cos = self.cos.narrow(0, index_pos, seq_len)?; let sin = self.sin.narrow(0, index_pos, seq_len)?; @@ -78,6 +108,7 @@ impl LayerWeights { } fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_attn.enter(); let (b_sz, seq_len, n_embd) = x.dims3()?; let q = self.attention_wq.forward(x)?; let k = self.attention_wk.forward(x)?; @@ -127,6 +158,8 @@ struct ModelWeights { // TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K. output: candle_nn::Linear, masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, } struct WeightMap(HashMap<String, QTensor>); @@ -177,6 +210,9 @@ impl ModelWeights { let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); layers.push(LayerWeights { attention_wq: QMatMul::from_qtensor(attention_wq), attention_wk: QMatMul::from_qtensor(attention_wk), @@ -192,14 +228,21 @@ impl ModelWeights { cos: cos.clone(), sin: sin.clone(), kv_cache: None, + span_attn, + span_rot, + span_mlp, }) } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); Ok(Self { tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), layers, norm, output, masks: HashMap::new(), + span, + span_output, }) } @@ -219,6 +262,7 @@ impl ModelWeights { fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { let (_b_sz, seq_len) = x.dims2()?; let mask = self.mask(seq_len)?; + let _enter = self.span.enter(); let mut layer_in = self.tok_embeddings.forward(x)?; for layer in self.layers.iter_mut() { let x = layer_in; @@ -228,6 +272,7 @@ impl ModelWeights { let x = (attn + residual)?; // MLP + let _enter = layer.span_mlp.enter(); let residual = &x; let x = layer.ffn_norm.forward(&x)?; let w1 = layer.feed_forward_w1.forward(&x)?; @@ -239,6 +284,7 @@ impl ModelWeights { } let x = self.norm.forward(&layer_in)?; let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); self.output.forward(&x) } } @@ -255,7 +301,7 @@ struct Args { prompt: Option<String>, /// The length of the sample to generate (in tokens). - #[arg(long, default_value_t = 100)] + #[arg(short = 'n', long, default_value_t = 100)] sample_len: usize, /// The tokenizer config in json format. @@ -269,6 +315,10 @@ struct Args { /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, } impl Args { @@ -298,7 +348,17 @@ impl Args { } fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; let mut file = std::fs::File::open(&args.model()?)?; let start = std::time::Instant::now(); diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index def5eb20..e3d2550e 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -89,7 +89,6 @@ fn main() -> Result<()> { let args = Args::parse(); let _guard = if args.tracing { - println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard) |