summary | refs | log | tree | commit | diff
path: root/candle-nn/src/activation.rs
blob: 9554e68a020cfa18a020e866b52ae6c780271084 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
use candle::Tensor;

/// Selects which activation function [`Activation::forward`] applies to a tensor.
///
/// The type is `Copy`, so it can be stored in configs and passed around by value.
#[derive(Clone, Copy)]
#[derive(Debug, PartialEq)]
pub enum Activation {
    /// Gaussian Error Linear Unit, computed via `Tensor::gelu`.
    Gelu,
    /// Rectified Linear Unit, computed via `Tensor::relu`.
    Relu,
    /// Exponential Linear Unit, computed via `Tensor::elu`; the payload is the
    /// `alpha` argument handed to that method.
    Elu(f64),
}

impl Activation {
    /// Applies this activation function to `xs` and returns the resulting tensor.
    ///
    /// Each variant dispatches to the matching [`Tensor`] method:
    /// - `Gelu` calls [`Tensor::gelu`]
    /// - `Relu` calls [`Tensor::relu`]
    /// - `Elu(alpha)` calls [`Tensor::elu`], forwarding `alpha` unchanged
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operation.
    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        // `Activation` is `Copy`, so matching on `*self` lets every arm bind by
        // value. The original mixed default binding modes with an explicit
        // `&Self::Elu(..)` reference pattern in a single match.
        match *self {
            Self::Gelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Elu(alpha) => xs.elu(alpha),
        }
    }
}