Diffstat (limited to 'candle-nn/src/activation.rs')
-rw-r--r--  candle-nn/src/activation.rs  50
1 file changed, 48 insertions(+), 2 deletions(-)
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index a2650634..8b9a8785 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -1,4 +1,4 @@
-use candle::Tensor;
+use candle::{Result, Tensor};
use serde::Deserialize;
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
@@ -21,7 +21,7 @@ pub enum Activation {
}
impl super::Module for Activation {
- fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu => xs.gelu_erf(),
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
@@ -40,3 +40,49 @@ impl super::Module for Activation {
}
}
}
+
+#[derive(Clone, Debug)]
+pub struct PReLU {
+ weight: Tensor,
+ is_scalar: bool,
+}
+
+impl PReLU {
+ pub fn new(weight: Tensor, is_scalar: bool) -> Self {
+ Self { weight, is_scalar }
+ }
+
+ pub fn weight(&self) -> &Tensor {
+ &self.weight
+ }
+
+ pub fn is_scalar(&self) -> bool {
+ self.is_scalar
+ }
+}
+
+impl candle::Module for PReLU {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let weight = if self.is_scalar {
+ self.weight.reshape(())?
+ } else {
+ self.weight.clone()
+ };
+ let zeros = xs.zeros_like()?;
+ xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)?
+ }
+}
+
+/// Create or initialize a new PReLU layer.
+///
+/// This uses the default name `"weight"` for the weight tensor.
+/// # Arguments
+///
+/// * `num_parameters` - The number of parameters. Use `None` for a single trainable value,
+/// and `Some` for a 1D vector with the appropriate number of features.
+pub fn prelu(num_parameters: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
+ let init_ws = crate::init::Init::Const(0.25);
+ // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
+ let ws = vs.get_with_hints((num_parameters.unwrap_or(1),), "weight", init_ws)?;
+ Ok(PReLU::new(ws, num_parameters.is_none()))
+}