author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-06 18:49:43 +0200
committer GitHub <noreply@github.com>  2023-08-06 17:49:43 +0100
commit    d34039e35267b3f4de83770f8da4ea31491bcec5 (patch)
tree      6efc4796859a04223d2211303cc12b3a97321fcc /candle-examples/examples/stable-diffusion/embeddings.rs
parent    93cfe5642f473889d1df62ccb8f1740f77523dd3 (diff)
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.
* Proper computation of the causal mask.
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* And add the variational auto-encoder.
* Use the pad function from utils.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/embeddings.rs')
-rw-r--r--  candle-examples/examples/stable-diffusion/embeddings.rs | 65
1 file changed, 65 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
new file mode 100644
index 00000000..f8a4f351
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
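+/// Two-layer MLP (linear -> SiLU -> linear) applied to the sinusoidal timestep
+/// features; mirrors the `TimestepEmbedding` module in the Python diffusers library.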
+#[derive(Debug)]
+pub struct TimestepEmbedding {
+    linear_1: nn::Linear,
+    linear_2: nn::Linear,
+}
+
+impl TimestepEmbedding {
+    // act_fn: "silu"
+    pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
+        let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
+        let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
+        Ok(Self { linear_1, linear_2 })
+    }
+}
+
+impl TimestepEmbedding {
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
+        self.linear_2.forward(&xs)
+    }
+}
+
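+/// Computes sinusoidal timestep features: sin/cos of the input timesteps at
+/// `num_channels / 2` geometrically spaced frequencies.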
+#[derive(Debug)]
+pub struct Timesteps {
+    num_channels: usize,
+    flip_sin_to_cos: bool,
+    downscale_freq_shift: f64,
+}
+
+impl Timesteps {
+    pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
+        Self {
+            num_channels,
+            flip_sin_to_cos,
+            downscale_freq_shift,
+        }
+    }
+}
+
+impl Timesteps {
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let half_dim = (self.num_channels / 2) as u32;
+        // Geometrically spaced frequencies: exp(-ln(10000) * i / (half_dim - shift))
+        // for i in 0..half_dim, as in the standard sinusoidal position embedding.
+        let exponent =
+            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+        let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
+        let emb = exponent.exp()?;
+        // emb = timesteps[:, None].float() * emb[None, :]
+        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+        let (cos, sin) = (emb.cos()?, emb.sin()?);
+        let emb = if self.flip_sin_to_cos {
+            Tensor::cat(&[&cos, &sin], D::Minus1)?
+        } else {
+            Tensor::cat(&[&sin, &cos], D::Minus1)?
+        };
+        if self.num_channels % 2 == 1 {
+            // Zero-pad the last dim on the right for odd channel counts,
+            // torch equivalent: F.pad(emb, [0, 1, 0, 0], 'constant', None)
+            crate::utils::pad(&emb)
+        } else {
+            Ok(emb)
+        }
+    }
+}
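
A minimal usage sketch of the two modules above (not part of the commit): it assumes a `VarBuilder` already loaded with stable-diffusion weights, a hypothetical "time_embedding" weight prefix, and dimensions (320 base channels, 1280 time-embed dim) that are illustrative of the SD 1.x UNet configuration.

// Sketch only: the function name, weight prefix, and dimensions below are
// illustrative assumptions, not taken from the commit.
use candle::{Device, Result, Tensor};

fn time_embedding_sketch(vs: candle_nn::VarBuilder) -> Result<Tensor> {
    let timesteps = Timesteps::new(320, true, 0.);
    let embedding = TimestepEmbedding::new(vs.pp("time_embedding"), 320, 1280)?;
    // A single diffusion timestep, shape (1,).
    let t = Tensor::new(&[500f32], &Device::Cpu)?;
    // (1,) -> (1, 320) sinusoidal features -> (1, 1280) learned embedding.
    let t_emb = timesteps.forward(&t)?;
    embedding.forward(&t_emb)
}

In the UNet this projection runs once per sampling step, and the resulting embedding is then fed to every resnet block.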