summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-07 19:31:45 +0200
committerGitHub <noreply@github.com>2023-08-07 18:31:45 +0100
commitfc265d9dcfc13ee0b03f4a09537a9e7156b29231 (patch)
treec94485d92647d340cb6e2a1eebc2c3aad4692d60 /candle-examples
parent2345b8ce3f8ebab6e04d6ea25f7c809efb037995 (diff)
downloadcandle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.gz
candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.bz2
candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.zip
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion. * Add the avg-pool2d operation on cpu.
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs10
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs14
2 files changed, 10 insertions, 14 deletions
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
- token_embedding + position_embedding
+ token_embedding.broadcast_add(&position_embedding)
}
}
@@ -161,9 +161,9 @@ impl ClipAttention {
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
- let attn_weights =
- (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
- + causal_attention_mask)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
- .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2203b03a..d8327c0e 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -57,13 +57,9 @@ struct Args {
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,
- #[arg(
- long,
- value_name = "FILE",
- default_value = "data/bpe_simple_vocab_16e6.txt"
- )]
- /// The file specifying the vocabulary to used for tokenization.
- vocab_file: String,
+ #[arg(long, value_name = "FILE")]
+ /// The file specifying the tokenizer to used for tokenization.
+ tokenizer: String,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
@@ -165,7 +161,7 @@ fn run(args: Args) -> Result<()> {
height,
width,
n_steps,
- vocab_file,
+ tokenizer,
final_image,
sliced_attention_size,
num_samples,
@@ -184,7 +180,7 @@ fn run(args: Args) -> Result<()> {
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
println!("Running with prompt \"{prompt}\".");
let tokens = tokenizer
.encode(prompt, true)