path: root/candle-examples
author     Laurent Mazare <laurent.mazare@gmail.com>   2023-09-09 10:46:09 +0100
committer  GitHub <noreply@github.com>                 2023-09-09 10:46:09 +0100
commit     3cd7e7b51dc7bf49215b136702f5bc3cd4642144 (patch)
tree       d024ff2edebb0e7af84d471d6daaa2be90b12c9a /candle-examples
parent     722c50bb0ce18d749edcf86268238e5d9c9ee57e (diff)
Fuse the rel-pos additions via a custom-op. (#786)
* Fuse the rel-pos additions via a custom-op.
* Run with rayon.
* Add more tracing.
Diffstat (limited to 'candle-examples')
-rw-r--r--  candle-examples/Cargo.toml                                        |  1
-rw-r--r--  candle-examples/examples/segment-anything/model_image_encoder.rs  | 91
2 files changed, 86 insertions(+), 6 deletions(-)
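
For readers unfamiliar with candle's custom-op mechanism, here is a minimal, self-contained sketch of the pattern this patch relies on: a `CustomOp3` with a CPU-only forward pass, applied through `apply_op3_no_bwd`. The `FusedAdd` name and the toy tensors are illustrative only and not part of the commit; the crate is assumed to be imported as `candle`, as candle-examples does.

use candle::{CpuStorage, Layout, Result, Shape, Tensor};

// Hypothetical fused element-wise a + b + c over three same-shape f32 CPU tensors.
struct FusedAdd;

impl candle::CustomOp3 for FusedAdd {
    fn name(&self) -> &'static str {
        "fused-add"
    }

    // Only the contiguous f32 CPU case is handled, mirroring the Add3 op in this patch.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        let a = match l1.contiguous_offsets() {
            Some((o1, o2)) => &s1.as_slice::<f32>()?[o1..o2],
            None => candle::bail!("input1 has to be contiguous"),
        };
        let b = match l2.contiguous_offsets() {
            Some((o1, o2)) => &s2.as_slice::<f32>()?[o1..o2],
            None => candle::bail!("input2 has to be contiguous"),
        };
        let c = match l3.contiguous_offsets() {
            Some((o1, o2)) => &s3.as_slice::<f32>()?[o1..o2],
            None => candle::bail!("input3 has to be contiguous"),
        };
        if a.len() != b.len() || a.len() != c.len() {
            candle::bail!("fused-add expects three tensors of identical shape")
        }
        let dst: Vec<f32> = a.iter().zip(b).zip(c).map(|((a, b), c)| a + b + c).collect();
        Ok((candle::WithDType::to_cpu_storage_owned(dst), l1.shape().clone()))
    }
}

fn main() -> Result<()> {
    let dev = candle::Device::Cpu;
    let a = Tensor::ones((2, 3), candle::DType::F32, &dev)?;
    let b = Tensor::ones((2, 3), candle::DType::F32, &dev)?;
    let c = Tensor::ones((2, 3), candle::DType::F32, &dev)?;
    // No backward pass is registered, matching how the patch applies its Add3 op.
    let out = a.apply_op3_no_bwd(&b, &c, &FusedAdd)?;
    println!("{out}");
    Ok(())
}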
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 9035eae0..6f8792a3 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -24,6 +24,7 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
+rayon = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
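
The rayon dependency added above is used for `par_chunks_exact_mut` in the custom op of the next file: the output buffer is split into equally sized chunks that are filled in parallel. A standalone sketch of that primitive (toy sizes, illustrative only, not from the patch):

use rayon::prelude::*;

fn main() {
    let chunk = 4;
    let mut dst = vec![0f32; 3 * chunk];
    // Each chunk of `dst` is handed to a worker together with its chunk index.
    dst.par_chunks_exact_mut(chunk)
        .enumerate()
        .for_each(|(c_idx, chunk_slice)| {
            for (i, v) in chunk_slice.iter_mut().enumerate() {
                *v = (c_idx * 10 + i) as f32;
            }
        });
    println!("{dst:?}");
}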
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f997170d..76cd15d0 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -34,6 +34,70 @@ impl Module for PatchEmbed {
}
}
+// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
+// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
+// (attn.reshape((b, q_h, q_w, k_h, k_w))?
+// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+// .reshape((b, q_h * q_w, k_h * k_w))
+// Ideally we would perform this operation in place but this is not supported in candle at the
+// moment. We should also investigate using f16 rather than f32.
+struct Add3(usize, usize, usize, usize, usize);
+impl candle::CustomOp3 for Add3 {
+ fn name(&self) -> &'static str {
+ "add3"
+ }
+
+ fn cpu_fwd(
+ &self,
+ s1: &candle::CpuStorage,
+ l1: &candle::Layout,
+ s2: &candle::CpuStorage,
+ l2: &candle::Layout,
+ s3: &candle::CpuStorage,
+ l3: &candle::Layout,
+ ) -> Result<(candle::CpuStorage, candle::Shape)> {
+ use rayon::prelude::*;
+
+ let Add3(b, q_h, q_w, k_h, k_w) = *self;
+ let s1 = s1.as_slice::<f32>()?;
+ let s1 = match l1.contiguous_offsets() {
+ None => candle::bail!("input1 has to be contiguous"),
+ Some((o1, o2)) => &s1[o1..o2],
+ };
+ let s2 = s2.as_slice::<f32>()?;
+ let s2 = match l2.contiguous_offsets() {
+ None => candle::bail!("input2 has to be contiguous"),
+ Some((o1, o2)) => &s2[o1..o2],
+ };
+ let s3 = s3.as_slice::<f32>()?;
+ let s3 = match l3.contiguous_offsets() {
+ None => candle::bail!("input3 has to be contiguous"),
+ Some((o1, o2)) => &s3[o1..o2],
+ };
+ let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
+ dst.par_chunks_exact_mut(k_h * k_w)
+ .enumerate()
+ .for_each(|(b_idx, dst)| {
+ let s1_idx = b_idx * k_h * k_w;
+ let s2_idx = b_idx * k_h;
+ let s3_idx = b_idx * k_w;
+ for h_idx in 0..k_h {
+ let s1_idx = s1_idx + h_idx * k_w;
+ let s2_idx = s2_idx + h_idx;
+ let dst_idx = h_idx * k_w;
+ for w_idx in 0..k_w {
+ let s1_idx = s1_idx + w_idx;
+ let s3_idx = s3_idx + w_idx;
+ let dst_idx = dst_idx + w_idx;
+ dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
+ }
+ }
+ });
+ let dst = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
+ }
+}
+
#[derive(Debug)]
struct Attention {
qkv: crate::Linear,
@@ -42,6 +106,7 @@ struct Attention {
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
span: tracing::Span,
+ span_matmul: tracing::Span,
span_rel_pos: tracing::Span,
span_softmax: tracing::Span,
}
@@ -56,6 +121,7 @@ impl Attention {
vb: VarBuilder,
) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -76,6 +142,7 @@ impl Attention {
scale,
rel_pos_hw,
span,
+ span_matmul,
span_rel_pos,
span_softmax,
})
@@ -101,10 +168,16 @@ impl Attention {
.transpose(1, 2)? // -> bwhc
.contiguous()?
.matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
- .transpose(1, 2)?;
- (attn.reshape((b, q_h, q_w, k_h, k_w))?
- + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
- .reshape((b, q_h * q_w, k_h * k_w))
+ .transpose(1, 2)?
+ .contiguous()?;
+ if attn.device().is_cpu() {
+ let op = Add3(b, q_h, q_w, k_h, k_w);
+ attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
+ } else {
+ (attn.reshape((b, q_h, q_w, k_h, k_w))?
+ + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+ .reshape((b, q_h * q_w, k_h * k_w))
+ }
}
None => Ok(attn),
}
@@ -149,7 +222,10 @@ impl Module for Attention {
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
- let attn = (&q * self.scale)?.matmul(&k.t()?)?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ (&q * self.scale)?.matmul(&k.t()?)?
+ };
let attn = {
let _enter = self.span_rel_pos.enter();
self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
@@ -158,7 +234,10 @@ impl Module for Attention {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn)?
};
- let attn = attn.matmul(&v)?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ attn.matmul(&v)?
+ };
let attn = attn
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
.permute((0, 2, 3, 1, 4))?