summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/lib.rs2
-rw-r--r--candle-nn/src/linear.rs23
2 files changed, 19 insertions, 6 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 3d0e6939..1bcb78d9 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
-pub use linear::{linear, linear_no_bias, Linear};
+pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 59a4db8a..96409042 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -57,21 +57,34 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
-pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+ let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
up: bound,
};
- let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
+ let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
/// Create or initialize a new linear layer without biases.
-pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+ let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}
+
+pub fn linear_b(
+ in_dim: usize,
+ out_dim: usize,
+ bias: bool,
+ vb: crate::VarBuilder,
+) -> Result<Linear> {
+ if bias {
+ linear(in_dim, out_dim, vb)
+ } else {
+ linear_no_bias(in_dim, out_dim, vb)
+ }
+}