diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/embeddings.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/embeddings.rs | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs new file mode 100644 index 00000000..f8a4f351 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +pub struct TimestepEmbedding { + linear_1: nn::Linear, + linear_2: nn::Linear, +} + +impl TimestepEmbedding { + // act_fn: "silu" + pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> { + let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?; + let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?; + Ok(Self { linear_1, linear_2 }) + } +} + +impl TimestepEmbedding { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; + self.linear_2.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Timesteps { + num_channels: usize, + flip_sin_to_cos: bool, + downscale_freq_shift: f64, +} + +impl Timesteps { + pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { + Self { + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + } + } +} + +impl Timesteps { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let half_dim = (self.num_channels / 2) as u32; + let exponent = + (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?; + let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; + let emb = exponent.exp()?; + // emb = timesteps[:, None].float() * emb[None, :] + let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?; + let (cos, sin) = (emb.cos()?, emb.sin()?); + let emb = if self.flip_sin_to_cos { + Tensor::cat(&[&cos, &sin], D::Minus1)? + } else { + Tensor::cat(&[&sin, &cos], D::Minus1)? + }; + if self.num_channels % 2 == 1 { + crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None) + } else { + Ok(emb) + } + } +} |