summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-16 18:49:08 +0100
committerGitHub <noreply@github.com>2023-08-16 18:49:08 +0100
commitc5f45887dc32bc6575c9d55135def391b949ce98 (patch)
treeb738a79158c49e7968275d7a9a04423bd47681ee
parentfa4590d7fd2b21ee811dba735851b6ec487f3cee (diff)
downloadcandle-c5f45887dc32bc6575c9d55135def391b949ce98.tar.gz
candle-c5f45887dc32bc6575c9d55135def391b949ce98.tar.bz2
candle-c5f45887dc32bc6575c9d55135def391b949ce98.zip
Add some tracing to the quantized example. (#473)
-rw-r--r--candle-examples/examples/ggml/main.rs66
-rw-r--r--candle-examples/examples/llama/main.rs1
2 files changed, 63 insertions, 4 deletions
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 7d6ec2ca..54bc5f57 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -5,7 +5,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::ggml_file::Content;
-use candle::quantized::{QMatMul, QTensor};
+use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::Embedding;
use candle_transformers::generation::LogitsProcessor;
@@ -16,15 +16,22 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
struct RmsNorm {
scale: Tensor,
eps: f64,
+ span: tracing::Span,
}
impl RmsNorm {
fn new(scale: QTensor) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
- Ok(Self { scale, eps: 1e-5 })
+ Ok(Self {
+ scale,
+ eps: 1e-5,
+ span,
+ })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
@@ -39,6 +46,25 @@ impl RmsNorm {
}
}
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+ inner: candle::quantized::QMatMul,
+ span: tracing::Span,
+}
+
+impl QMatMul {
+ fn from_qtensor(qtensor: QTensor) -> Self {
+ let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
+ let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+ Self { inner, span }
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
@@ -54,6 +80,9 @@ struct LayerWeights {
cos: Tensor,
sin: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
+ span_attn: tracing::Span,
+ span_rot: tracing::Span,
+ span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -65,6 +94,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
@@ -78,6 +108,7 @@ impl LayerWeights {
}
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
@@ -127,6 +158,8 @@ struct ModelWeights {
// TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
output: candle_nn::Linear,
masks: HashMap<usize, Tensor>,
+ span: tracing::Span,
+ span_output: tracing::Span,
}
struct WeightMap(HashMap<String, QTensor>);
@@ -177,6 +210,9 @@ impl ModelWeights {
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
+ let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+ let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
@@ -192,14 +228,21 @@ impl ModelWeights {
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
+ span_attn,
+ span_rot,
+ span_mlp,
})
}
+ let span = tracing::span!(tracing::Level::TRACE, "model");
+ let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output,
masks: HashMap::new(),
+ span,
+ span_output,
})
}
@@ -219,6 +262,7 @@ impl ModelWeights {
fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
+ let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {
let x = layer_in;
@@ -228,6 +272,7 @@ impl ModelWeights {
let x = (attn + residual)?;
// MLP
+ let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let w1 = layer.feed_forward_w1.forward(&x)?;
@@ -239,6 +284,7 @@ impl ModelWeights {
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
+ let _enter = self.span_output.enter();
self.output.forward(&x)
}
}
@@ -255,7 +301,7 @@ struct Args {
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
- #[arg(long, default_value_t = 100)]
+ #[arg(short = 'n', long, default_value_t = 100)]
sample_len: usize,
/// The tokenizer config in json format.
@@ -269,6 +315,10 @@ struct Args {
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
}
impl Args {
@@ -298,7 +348,17 @@ impl Args {
}
fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let mut file = std::fs::File::open(&args.model()?)?;
let start = std::time::Instant::now();
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index def5eb20..e3d2550e 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -89,7 +89,6 @@ fn main() -> Result<()> {
let args = Args::parse();
let _guard = if args.tracing {
- println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)