diff options
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/utils.rs')
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/utils.rs | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs new file mode 100644 index 00000000..c62f17af --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -0,0 +1,39 @@ +use candle::{Device, Result, Tensor}; +use candle_nn::Module; + +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) +} + +// Wrap the conv2d op to provide some tracing. +#[derive(Debug)] +pub struct Conv2d { + inner: candle_nn::Conv2d, + span: tracing::Span, +} + +impl Conv2d { + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: candle_nn::Conv2dConfig, + vs: candle_nn::VarBuilder, +) -> Result<Conv2d> { + let span = tracing::span!(tracing::Level::TRACE, "conv2d"); + let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?; + Ok(Conv2d { inner, span }) +} |