Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/activation.rs  | 61
-rw-r--r--  candle-nn/src/lib.rs         |  2
-rw-r--r--  candle-nn/src/linear.rs      |  3
-rw-r--r--  candle-nn/src/optim.rs       |  8
-rw-r--r--  candle-nn/src/var_builder.rs | 10
5 files changed, 78 insertions, 6 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index a2650634..80b750ed 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -1,4 +1,4 @@
-use candle::Tensor;
+use candle::{Result, Tensor};
 use serde::Deserialize;
 
 #[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
@@ -21,7 +21,7 @@ pub enum Activation {
 }
 
 impl super::Module for Activation {
-    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu_erf(),
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
@@ -40,3 +40,60 @@ impl super::Module for Activation {
         }
     }
 }
+
+#[derive(Clone, Debug)]
+pub struct PReLU {
+    weight: Tensor,
+    is_scalar: bool,
+}
+
+impl PReLU {
+    pub fn new(weight: Tensor, is_scalar: bool) -> Self {
+        Self { weight, is_scalar }
+    }
+
+    pub fn weight(&self) -> &Tensor {
+        &self.weight
+    }
+
+    pub fn is_scalar(&self) -> bool {
+        self.is_scalar
+    }
+}
+
+impl candle::Module for PReLU {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let weight = if self.is_scalar {
+            self.weight.reshape(())?
+        } else if xs.rank() >= 2 {
+            let num_channels = xs.dim(1)?;
+            let num_weights = self.weight.elem_count();
+            if num_weights != num_channels {
+                candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
+            }
+            let mut s = vec![1; xs.rank()];
+            s[1] = self.weight.elem_count();
+            self.weight.reshape(s)?
+        } else {
+            self.weight.clone()
+        };
+        let zeros = xs.zeros_like()?;
+        xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)?
+    }
+}
+
+/// Create or initialize a new PReLU layer.
+///
+/// This uses some default name for weights, namely `"weight"`.
+/// # Arguments
+///
+/// * `num_channels` - The number of channels. Use `None` to have a single trainable value and
+/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
+/// function, the input tensor shape `s` should either be one dimension with this number of
+/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
+pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
+    let init_ws = crate::init::Init::Const(0.25);
+    // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
+    let ws = vs.get_with_hints((num_channels.unwrap_or(1),), "weight", init_ws)?;
+    Ok(PReLU::new(ws, num_channels.is_none()))
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 52d8f0c5..8f00e54c 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,7 +15,7 @@ pub mod sequential;
 pub mod var_builder;
 pub mod var_map;
 
-pub use activation::Activation;
+pub use activation::{prelu, Activation, PReLU};
 pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
 pub use conv::{
     conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 94632296..59a4db8a 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -56,7 +56,7 @@ impl super::Module for Linear {
 
 /// Create or initialize a new linear layer.
 ///
-/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
+/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
 pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
@@ -69,6 +69,7 @@ pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Li
     Ok(Linear::new(ws, Some(bs)))
 }
 
+/// Create or initialize a new linear layer without biases.
 pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 7704bb48..2c671fc5 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -190,4 +190,12 @@ impl AdamW {
         };
         Self::new(vars, params)
     }
+
+    pub fn params(&self) -> &ParamsAdamW {
+        &self.params
+    }
+
+    pub fn set_params(&mut self, params: ParamsAdamW) {
+        self.params = params;
+    }
 }
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..83c86a6f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -40,7 +40,7 @@ struct TensorData<B: Backend> {
 /// A trait that defines how tensor data is retrieved.
 ///
 /// Typically this would use disk storage in some specific format, or random initialization.
-/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
+/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
 /// of the time. The main restriction is that it doesn't allow for specific args (besides
 /// initialization hints).
 pub trait Backend: Send + Sync {
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
 
     fn get(
         &self,
-        _target_shape: Shape, // The size is not checked for ShardedTensors
+        target_shape: Shape, // The size is only checked when the world size is 1.
         path: &str,
         h: Self::Hints,
         dtype: DType,
         dev: &Device,
     ) -> Result<Tensor> {
+        if h.world_size == 1 {
+            // There is no sharding to be applied here so we use the default backend to speed
+            // things up.
+            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+        }
+
        let Shard {
            dim,
            rank,
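Usage note: the patch above adds a `PReLU` layer and a `prelu` constructor to candle-nn. The snippet below is a rough usage sketch, not code from the patch; the variable names and the `VarMap`-backed `VarBuilder` setup are illustrative and assume the usual `candle`/`candle_nn` crate aliases from the candle workspace.

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{prelu, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Trainable variables live in a VarMap; `prelu` registers its weight there,
    // initialized to the constant 0.25 as in the patch.
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // Per-channel variant: one learnable slope per channel, matched against dim 1 of the input.
    let layer = prelu(Some(3), vs.pp("prelu"))?;
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 4), &dev)?;
    // forward computes max(x, 0) + weight * min(x, 0), broadcasting the weight over channels.
    let _ys = layer.forward(&xs)?;

    // Scalar variant: a single learnable slope shared across all elements.
    let scalar = prelu(None, vs.pp("prelu_scalar"))?;
    let _ys = scalar.forward(&xs)?;
    Ok(())
}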
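Similarly, the new `params`/`set_params` accessors on `AdamW` allow adjusting hyper-parameters between optimization steps, for instance a hand-rolled learning-rate decay. Sketch only; `decay_lr` and `factor` are illustrative names, not part of the patch.

use candle_nn::{AdamW, ParamsAdamW};

// Scale the learning rate of an existing optimizer in place, e.g. between epochs.
fn decay_lr(opt: &mut AdamW, factor: f64) {
    let mut params: ParamsAdamW = opt.params().clone();
    params.lr *= factor;
    opt.set_params(params);
}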