From c753f72c8552ba3e108bd3f1a04971e8abbf3012 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 22 Feb 2024 09:35:28 +0100
Subject: Support for attention bias in gemma + refactor things a bit. (#1744)

* Support for attention bias in gemma + refactor things a bit.

* Fix the cuda tests.
---
 candle-nn/src/lib.rs    |  2 +-
 candle-nn/src/linear.rs | 23 ++++++++++++++++++-----
 2 files changed, 19 insertions(+), 6 deletions(-)

(limited to 'candle-nn')

diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 3d0e6939..1bcb78d9 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
-pub use linear::{linear, linear_no_bias, Linear};
+pub use linear::{linear, linear_b, linear_no_bias, Linear};
 pub use ops::Dropout;
 pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 59a4db8a..96409042 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -57,21 +57,34 @@ impl super::Module for Linear {
 /// Create or initialize a new linear layer.
 ///
 /// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
-pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     let bound = 1. / (in_dim as f64).sqrt();
     let init_bs = crate::Init::Uniform {
         lo: -bound,
         up: bound,
     };
-    let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
+    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 
 /// Create or initialize a new linear layer without biases.
-pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     Ok(Linear::new(ws, None))
 }
+
+pub fn linear_b(
+    in_dim: usize,
+    out_dim: usize,
+    bias: bool,
+    vb: crate::VarBuilder,
+) -> Result<Linear> {
+    if bias {
+        linear(in_dim, out_dim, vb)
+    } else {
+        linear_no_bias(in_dim, out_dim, vb)
+    }
+}
-- 
cgit v1.2.3