diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-14 23:24:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 22:24:56 +0100 |
commit | 130fe5a087715fc4d7bf9b581ca7c11378736ac5 (patch) | |
tree | effd5d92b1dddace769b8e1944eab97a8364c84f /candle-nn/src/ops.rs | |
parent | 91ec546febee4c6333cd65d95e8fd09e94499024 (diff) | |
download | candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.gz candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.bz2 candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.zip |
Add the upblocks. (#853)
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r-- | candle-nn/src/ops.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index adf1451c..c4055792 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -44,6 +44,10 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { (xs.neg()?.exp()? + 1.0)?.recip() } +pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> { + xs.relu()?.minimum(&(xs * negative_slope)?) +} + pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> { // This implementation is inefficient as it stores the full mask for the backward pass. // Instead we could just store the seed and have a specialized kernel that would both |