diff options
Diffstat (limited to 'candle-nn/src/init.rs')
-rw-r--r-- | candle-nn/src/init.rs | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
index 762f0ef1..25702d52 100644
--- a/candle-nn/src/init.rs
+++ b/candle-nn/src/init.rs
@@ -1,7 +1,7 @@
 //! Variable initialization.
 // This is based on:
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
         fan: FanInOut,
         non_linearity: NonLinearity,
     },
-
-    /// Orthogonal initialization
-    Orthogonal { gain: f64 },
 }
 
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
 pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
     dist: NormalOrUniform::Uniform,
     fan: FanInOut::FanIn,
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
     fan: FanInOut::FanIn,
     non_linearity: NonLinearity::ReLU,
 };
+
+impl Init {
+    /// Creates a new tensor with the specified shape, device, and initialization.
+    pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
+        match self {
+            Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
+            Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
+            Self::Const(cst) => {
+                Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
+            }
+            Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
+            Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
+            Self::Kaiming {
+                dist,
+                fan,
+                non_linearity,
+            } => {
+                let s = s.into();
+                let fan = fan.for_shape(&s);
+                let gain = non_linearity.gain();
+                let std = gain / (fan as f64).sqrt();
+                match dist {
+                    NormalOrUniform::Uniform => {
+                        let bound = 3f64.sqrt() * std;
+                        Var::rand_f64(-bound, bound, s, dtype, device)
+                    }
+                    NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
+                }
+            }
+        }
+    }
+}