summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/Cargo.toml3
-rw-r--r--candle-examples/examples/wuerstchen/main.rs14
-rw-r--r--candle-transformers/Cargo.toml2
-rw-r--r--candle-transformers/src/models/wuerstchen/attention_processor.rs50
-rw-r--r--candle-transformers/src/models/wuerstchen/common.rs3
-rw-r--r--candle-transformers/src/models/wuerstchen/diffnext.rs22
-rw-r--r--candle-transformers/src/models/wuerstchen/prior.rs3
7 files changed, 85 insertions, 12 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index cf8f0021..0e2e8093 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -13,7 +13,6 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
@@ -51,7 +50,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index aaa9b78a..95f3b8f4 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -41,6 +41,9 @@ struct Args {
#[arg(long)]
tracing: bool,
+ #[arg(long)]
+ use_flash_attn: bool,
+
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
@@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
- /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
- /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+ /* c_in */ PRIOR_CIN,
+ /* c */ 1536,
+ /* c_cond */ 1280,
+ /* c_r */ 64,
+ /* depth */ 32,
+ /* nhead */ 24,
+ args.use_flash_attn,
+ vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
@@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
+ args.use_flash_attn,
vb,
)?
};
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 2faadad9..a3115c2b 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -12,6 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
@@ -26,4 +27,5 @@ wav = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs
index 3f1a72eb..0b90cb9d 100644
--- a/candle-transformers/src/models/wuerstchen/attention_processor.rs
+++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs
@@ -11,10 +11,33 @@ pub struct Attention {
to_out: Linear,
heads: usize,
scale: f64,
+ use_flash_attn: bool,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
}
impl Attention {
- pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
+ pub fn new(
+ query_dim: usize,
+ heads: usize,
+ dim_head: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
let inner_dim = dim_head * heads;
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
@@ -28,6 +51,7 @@ impl Attention {
to_out,
scale,
heads,
+ use_flash_attn,
})
}
@@ -62,8 +86,28 @@ impl Attention {
let key = self.head_to_batch_dim(&key)?;
let value = self.head_to_batch_dim(&value)?;
- let attn_prs = self.get_attention_scores(&query, &key)?;
- let xs = attn_prs.matmul(&value)?;
+ let xs = if self.use_flash_attn {
+ let init_dtype = query.dtype();
+ let q = query
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let k = key
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let v = value
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ flash_attn(&q, &k, &v, self.scale as f32, false)?
+ .transpose(1, 2)?
+ .squeeze(0)?
+ .to_dtype(init_dtype)?
+ } else {
+ let attn_prs = self.get_attention_scores(&query, &key)?;
+ attn_prs.matmul(&value)?
+ };
let xs = self.batch_to_head_dim(&xs)?;
self.to_out
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 8416a1f1..c89ec919 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -174,10 +174,11 @@ impl AttnBlock {
c_cond: usize,
nhead: usize,
self_attn: bool,
+ use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = WLayerNorm::new(c)?;
- let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
+ let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
Ok(Self {
self_attn,
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 501a2776..64a48c8a 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -88,6 +88,7 @@ pub struct WDiffNeXt {
}
impl WDiffNeXt {
+ #[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c_out: usize,
@@ -95,6 +96,7 @@ impl WDiffNeXt {
c_cond: usize,
clip_embd: usize,
patch_size: usize,
+ use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
@@ -169,8 +171,14 @@ impl WDiffNeXt {
let attn_block = if i == 0 {
None
} else {
- let attn_block =
- AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
layer_i += 1;
Some(attn_block)
};
@@ -208,8 +216,14 @@ impl WDiffNeXt {
let attn_block = if i == 0 {
None
} else {
- let attn_block =
- AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
layer_i += 1;
Some(attn_block)
};
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index 168b70a6..97ccf0e2 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -21,6 +21,7 @@ pub struct WPrior {
}
impl WPrior {
+ #[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c: usize,
@@ -28,6 +29,7 @@ impl WPrior {
c_r: usize,
depth: usize,
nhead: usize,
+ use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -44,6 +46,7 @@ impl WPrior {
c,
nhead,
true,
+ use_flash_attn,
vb.pp(format!("blocks.{}", 3 * index + 2)),
)?;
blocks.push(Block {