diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 22:14:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 21:14:52 +0100 |
commit | 166bfd5847144abec227836e497b509625470535 (patch) | |
tree | 7b13e3dae76c0864a3cb107c98b3a88f24423af3 /candle-examples | |
parent | 1c062bf06ba504a076b329c965c625be0ec67c1d (diff) | |
download | candle-166bfd5847144abec227836e497b509625470535.tar.gz candle-166bfd5847144abec227836e497b509625470535.tar.bz2 candle-166bfd5847144abec227836e497b509625470535.zip |
Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.
* Fix the cuda kernel.
* Use the recip op in sigmoid.
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 50ee48e9..90fe3f9a 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,7 +1,8 @@ -use candle::{Result, Tensor}; +use candle::{Device, Result, Tensor}; -pub fn sigmoid(_: &Tensor) -> Result<Tensor> { - todo!() +pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { + // TODO: Add sigmoid as binary ops. + (xs.neg()?.exp()? - 1.0)?.recip() } pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> { @@ -16,6 +17,13 @@ pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { todo!() } -pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> { - todo!() +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) } |