diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/utils.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 50ee48e9..90fe3f9a 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,7 +1,8 @@ -use candle::{Result, Tensor}; +use candle::{Device, Result, Tensor}; -pub fn sigmoid(_: &Tensor) -> Result<Tensor> { - todo!() +pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { + // TODO: Add sigmoid as binary ops. + (xs.neg()?.exp()? - 1.0)?.recip() } pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> { @@ -16,6 +17,13 @@ pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { todo!() } -pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> { - todo!() +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) } |