commit    fc265d9dcfc13ee0b03f4a09537a9e7156b29231
author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-07 19:31:45 +0200
committer GitHub <noreply@github.com>  2023-08-07 18:31:45 +0100
tree      c94485d92647d340cb6e2a1eebc2c3aad4692d60 /candle-examples
parent    2345b8ce3f8ebab6e04d6ea25f7c809efb037995
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.
* Add the avg-pool2d operation on cpu.
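
The avg-pool2d kernel itself lives in candle-core, so it does not appear in the diff below, which is limited to candle-examples. As a usage sketch only: this assumes the op is exposed on Tensor as avg_pool2d taking a kernel size (the exact signature at the time of this commit may differ) and that candle-core is imported under the candle name, as in the examples crate.

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // A 1x1x4x4 input in (batch, channel, height, width) layout.
        let xs = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
        // Average non-overlapping 2x2 windows, giving a 1x1x2x2 output
        // holding the mean of each window: [[2.5, 4.5], [10.5, 12.5]].
        let ys = xs.avg_pool2d((2, 2))?;
        println!("{ys}");
        Ok(())
    }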
Diffstat (limited to 'candle-examples')
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs | 10
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs | 14
2 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-        token_embedding + position_embedding
+        token_embedding.broadcast_add(&position_embedding)
     }
 }
 
@@ -161,9 +161,9 @@ impl ClipAttention {
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
-        let attn_weights =
-            (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
-                + causal_attention_mask)?;
+        let attn_weights = attn_weights
+            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+            .broadcast_add(causal_attention_mask)?;
         let attn_weights =
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
 // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
 fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..seq_len)
-        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
         .collect();
     let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
     mask.broadcast_as((bsz, seq_len, seq_len))
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2203b03a..d8327c0e 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -57,13 +57,9 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     vae_weights: Option<String>,
 
-    #[arg(
-        long,
-        value_name = "FILE",
-        default_value = "data/bpe_simple_vocab_16e6.txt"
-    )]
-    /// The file specifying the vocabulary to used for tokenization.
-    vocab_file: String,
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to used for tokenization.
+    tokenizer: String,
 
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
@@ -165,7 +161,7 @@ fn run(args: Args) -> Result<()> {
         height,
         width,
         n_steps,
-        vocab_file,
+        tokenizer,
         final_image,
         sliced_attention_size,
         num_samples,
@@ -184,7 +180,7 @@ fn run(args: Args) -> Result<()> {
     let scheduler = sd_config.build_scheduler(n_steps)?;
 
     let device = candle_examples::device(cpu)?;
-    let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     println!("Running with prompt \"{prompt}\".");
     let tokens = tokenizer
         .encode(prompt, true)
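
Two details of the clip.rs changes are worth spelling out. Plain + on candle tensors is shape-checked and does not broadcast, so adding the (1, seq_len, dim) position embeddings to a batched (bsz, seq_len, dim) tensor needs the explicit broadcast_add. And an additive causal mask has to hold a large negative value (f32::MIN) rather than u8 0/1 flags, so that softmax drives the weight of future positions to zero instead of merely nudging their logits by one. A minimal sketch of both, again assuming candle-core under the candle name plus candle-nn; the tensor shapes here are illustrative, not taken from the model.

    use candle::{DType, Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let dev = Device::Cpu;

        // broadcast_add expands the leading 1-sized batch dim; a plain `+`
        // on these shapes would return a shape-mismatch error instead.
        let token = Tensor::zeros((2, 4, 8), DType::F32, &dev)?; // (bsz, seq, dim)
        let pos = Tensor::ones((1, 4, 8), DType::F32, &dev)?; // (1, seq, dim)
        let embeddings = token.broadcast_add(&pos)?;
        assert_eq!(embeddings.dims(), &[2, 4, 8]);

        // An additive causal mask: 0 on and below the diagonal, f32::MIN
        // above it, mirroring the build_causal_attention_mask fix.
        let seq_len = 4usize;
        let mask: Vec<f32> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &dev)?;
        // With uniform logits, each row of the softmax output spreads its
        // weight evenly over the visible prefix and gives ~0 to the future.
        let logits = Tensor::zeros((seq_len, seq_len), DType::F32, &dev)?;
        let weights = candle_nn::ops::softmax(&(logits + mask)?, D::Minus1)?;
        println!("{weights}");
        Ok(())
    }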