author    Laurent Mazare <laurent.mazare@gmail.com>  2024-02-22 09:35:28 +0100
committer GitHub <noreply@github.com>                2024-02-22 09:35:28 +0100
commit    c753f72c8552ba3e108bd3f1a04971e8abbf3012 (patch)
tree      dbd3f076b9c01811dd58ce6e30122d594b617b6f /candle-nn
parent    8013b50829c4256d2a04b7b1acd3de90d9a95650 (diff)
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.
* Fix the cuda tests.
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/lib.rs     |  2
-rw-r--r--  candle-nn/src/linear.rs  | 23
2 files changed, 19 insertions, 6 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 3d0e6939..1bcb78d9 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
-pub use linear::{linear, linear_no_bias, Linear};
+pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 59a4db8a..96409042 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -57,21 +57,34 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
-pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+ let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
up: bound,
};
- let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
+ let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}

/// Create or initialize a new linear layer without biases.
-pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
+ let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}
+
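+/// Create or initialize a new linear layer, with or without a bias depending on the `bias` flag.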
+pub fn linear_b(
+ in_dim: usize,
+ out_dim: usize,
+ bias: bool,
+ vb: crate::VarBuilder,
+) -> Result<Linear> {
+ if bias {
+ linear(in_dim, out_dim, vb)
+ } else {
+ linear_no_bias(in_dim, out_dim, vb)
+ }
+}
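
For context, a minimal sketch of how the new `linear_b` helper can be used to make a projection's bias configuration-dependent, as in the gemma attention layers this commit targets. The `Config` and `Attention` types below are illustrative only, not part of this diff:

```rust
use candle::Result;
use candle_nn::{linear_b, Linear, VarBuilder};

// Hypothetical config mirroring a gemma-style `attention_bias` flag.
struct Config {
    hidden_size: usize,
    attention_bias: bool,
}

struct Attention {
    q_proj: Linear,
}

impl Attention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // One call covers both the biased and the bias-free case,
        // instead of branching between `linear` and `linear_no_bias`.
        let q_proj = linear_b(
            cfg.hidden_size,
            cfg.hidden_size,
            cfg.attention_bias,
            vb.pp("q_proj"),
        )?;
        Ok(Self { q_proj })
    }
}
```

Since `linear_b` simply dispatches to `linear` or `linear_no_bias`, the variable names (`"weight"`, `"bias"`) and initializers stay identical to the existing helpers.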