summaryrefslogtreecommitdiff
path: root/candle-nn/src/ops.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-14 23:24:56 +0200
committerGitHub <noreply@github.com>2023-09-14 22:24:56 +0100
commit130fe5a087715fc4d7bf9b581ca7c11378736ac5 (patch)
treeeffd5d92b1dddace769b8e1944eab97a8364c84f /candle-nn/src/ops.rs
parent91ec546febee4c6333cd65d95e8fd09e94499024 (diff)
downloadcandle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.gz
candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.bz2
candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.zip
Add the upblocks. (#853)
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r--candle-nn/src/ops.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index adf1451c..c4055792 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -44,6 +44,10 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
(xs.neg()?.exp()? + 1.0)?.recip()
}
+pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
+ xs.relu()?.minimum(&(xs * negative_slope)?)
+}
+
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
// This implementation is inefficient as it stores the full mask for the backward pass.
// Instead we could just store the seed and have a specialized kernel that would both