diff options
author | Evgeny Igumnov <igumnovnsk@gmail.com> | 2023-09-22 11:01:23 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-22 11:01:23 +0600 |
commit | 4ac6039a42b8125f7888709fb718bfd41a73f2ac (patch) | |
tree | f4bc165de51f258a9bf58cac4150c99e512fba01 /candle-transformers/src/models/stable_diffusion/utils.rs | |
parent | 52a60ca3ad3f7e7b6da8e915a5a052d5bef10999 (diff) | |
parent | a96878f2357fbcebf9db8747dcbb55bc8200d8ab (diff) | |
download | candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.gz candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.bz2 candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.zip |
Merge branch 'main' into book-trainin-simplified
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/utils.rs')
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/utils.rs | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs new file mode 100644 index 00000000..c62f17af --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -0,0 +1,39 @@ +use candle::{Device, Result, Tensor}; +use candle_nn::Module; + +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) +} + +// Wrap the conv2d op to provide some tracing. +#[derive(Debug)] +pub struct Conv2d { + inner: candle_nn::Conv2d, + span: tracing::Span, +} + +impl Conv2d { + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: candle_nn::Conv2dConfig, + vs: candle_nn::VarBuilder, +) -> Result<Conv2d> { + let span = tracing::span!(tracing::Level::TRACE, "conv2d"); + let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?; + Ok(Conv2d { inner, span }) +} |