author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-22 09:35:28 +0100
committer | GitHub <noreply@github.com>               | 2024-02-22 09:35:28 +0100
commit    | c753f72c8552ba3e108bd3f1a04971e8abbf3012 (patch)
tree      | dbd3f076b9c01811dd58ce6e30122d594b617b6f /candle-nn
parent    | 8013b50829c4256d2a04b7b1acd3de90d9a95650 (diff)
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.
* Fix the cuda tests.
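The gemma-side change is outside the `candle-nn` diffstat shown below, but the pattern this commit enables is easy to sketch. The following is a minimal, hypothetical example of how a Gemma-style attention block could route a single bias flag through the new `linear_b` helper; `AttentionConfig`, the `attention_bias` field, and the projection names are illustrative assumptions, not code from this commit.

```rust
use candle_core::Result;
use candle_nn::{linear_b, Linear, VarBuilder};

/// Hypothetical config; only `attention_bias` matters for this sketch.
struct AttentionConfig {
    hidden_size: usize,
    num_heads: usize,
    head_dim: usize,
    attention_bias: bool,
}

struct AttentionProjections {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
}

fn attention_projections(cfg: &AttentionConfig, vb: VarBuilder) -> Result<AttentionProjections> {
    let dim = cfg.num_heads * cfg.head_dim;
    // `linear_b` dispatches to `linear` or `linear_no_bias`, so one config flag
    // controls the bias of every projection without repeated if/else blocks.
    let q_proj = linear_b(cfg.hidden_size, dim, cfg.attention_bias, vb.pp("q_proj"))?;
    let k_proj = linear_b(cfg.hidden_size, dim, cfg.attention_bias, vb.pp("k_proj"))?;
    let v_proj = linear_b(cfg.hidden_size, dim, cfg.attention_bias, vb.pp("v_proj"))?;
    let o_proj = linear_b(dim, cfg.hidden_size, cfg.attention_bias, vb.pp("o_proj"))?;
    Ok(AttentionProjections {
        q_proj,
        k_proj,
        v_proj,
        o_proj,
    })
}
```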
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/lib.rs    |  2
-rw-r--r-- | candle-nn/src/linear.rs | 23
2 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 3d0e6939..1bcb78d9 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
-pub use linear::{linear, linear_no_bias, Linear};
+pub use linear::{linear, linear_b, linear_no_bias, Linear};
 pub use ops::Dropout;
 pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 59a4db8a..96409042 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -57,21 +57,34 @@ impl super::Module for Linear {
 /// Create or initialize a new linear layer.
 ///
 /// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
-pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     let bound = 1. / (in_dim as f64).sqrt();
     let init_bs = crate::Init::Uniform {
         lo: -bound,
         up: bound,
     };
-    let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
+    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
 /// Create or initialize a new linear layer without biases.
-pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     Ok(Linear::new(ws, None))
 }
+
+pub fn linear_b(
+    in_dim: usize,
+    out_dim: usize,
+    bias: bool,
+    vb: crate::VarBuilder,
+) -> Result<Linear> {
+    if bias {
+        linear(in_dim, out_dim, vb)
+    } else {
+        linear_no_bias(in_dim, out_dim, vb)
+    }
+}
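For quick reference, here is a small, self-contained usage sketch of the newly re-exported `linear_b` (not part of this commit). The `VarMap`-backed builder and the variable names are only for illustration; a real model would load its weights from a checkpoint.

```rust
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{linear_b, VarBuilder, VarMap};

fn main() -> Result<()> {
    // Freshly initialized variables on CPU, purely for demonstration.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

    // `linear_b(in_dim, out_dim, bias, vb)` forwards to `linear` when
    // `bias` is true and to `linear_no_bias` otherwise.
    let with_bias = linear_b(4, 2, true, vb.pp("with_bias"))?;
    let no_bias = linear_b(4, 2, false, vb.pp("no_bias"))?;

    let xs = Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?;
    println!("{:?}", with_bias.forward(&xs)?.shape()); // [1, 2]
    println!("{:?}", no_bias.forward(&xs)?.shape()); // [1, 2]
    Ok(())
}
```

Callers that previously branched on a config flag to pick `linear` or `linear_no_bias` can now pass the flag straight through, which is what the gemma refactor in this commit takes advantage of.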