diff options
Diffstat (limited to 'candle-nn/src/linear.rs')
-rw-r--r-- | candle-nn/src/linear.rs | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 943011c9..a0bd925a 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -17,7 +17,7 @@
 //! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
 //! # Ok(()) }
 //! ```
-use candle::Tensor;
+use candle::{Result, Tensor};
 
 #[derive(Debug)]
 pub struct Linear {
@@ -42,3 +42,24 @@ impl Linear {
         }
     }
 }
+
+/// Create or initialize a new linear layer.
+///
+/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
+pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    Ok(Linear::new(ws, None))
+}