summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/stable_diffusion/utils.rs
diff options
context:
space:
mode:
authorEvgeny Igumnov <igumnovnsk@gmail.com>2023-09-22 11:01:23 +0600
committerGitHub <noreply@github.com>2023-09-22 11:01:23 +0600
commit4ac6039a42b8125f7888709fb718bfd41a73f2ac (patch)
treef4bc165de51f258a9bf58cac4150c99e512fba01 /candle-transformers/src/models/stable_diffusion/utils.rs
parent52a60ca3ad3f7e7b6da8e915a5a052d5bef10999 (diff)
parenta96878f2357fbcebf9db8747dcbb55bc8200d8ab (diff)
downloadcandle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.gz
candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.bz2
candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.zip
Merge branch 'main' into book-trainin-simplified
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/utils.rs')
-rw-r--r--candle-transformers/src/models/stable_diffusion/utils.rs39
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
new file mode 100644
index 00000000..c62f17af
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -0,0 +1,39 @@
+use candle::{Device, Result, Tensor};
+use candle_nn::Module;
+
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+ if steps < 1 {
+ candle::bail!("cannot use linspace with steps {steps} <= 1")
+ }
+ let delta = (stop - start) / (steps - 1) as f64;
+ let vs = (0..steps)
+ .map(|step| start + step as f64 * delta)
+ .collect::<Vec<_>>();
+ Tensor::from_vec(vs, steps, &Device::Cpu)
+}
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+ inner: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl Conv2d {
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: candle_nn::Conv2dConfig,
+ vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+ let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+ Ok(Conv2d { inner, span })
+}