Diffstat (limited to 'candle-nn/src/linear.rs')
 candle-nn/src/linear.rs | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 943011c9..a0bd925a 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -17,7 +17,7 @@
 //! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
 //! # Ok(()) }
 //! ```
-use candle::Tensor;
+use candle::{Result, Tensor};
 
 #[derive(Debug)]
 pub struct Linear {
@@ -42,3 +42,24 @@ impl Linear {
         }
     }
 }
+
+/// Create or initialize a new linear layer.
+///
+/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
+pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    Ok(Linear::new(ws, None))
+}