blob: 9554e68a020cfa18a020e866b52ae6c780271084 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
use candle::Tensor;
/// Selects which activation function [`Activation::forward`] applies to a tensor.
///
/// The type is `Copy`, so it can be stored in model configs and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Activation {
    /// Gaussian error linear unit, dispatched to `Tensor::gelu`.
    Gelu,
    /// Rectified linear unit, dispatched to `Tensor::relu`.
    Relu,
    /// Exponential linear unit; the payload is the `alpha` value forwarded to `Tensor::elu`.
    Elu(f64),
}
impl Activation {
pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
&Self::Elu(alpha) => xs.elu(alpha),
}
}
}
|