summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/layer_norm.rs  9
-rw-r--r--  candle-nn/src/lib.rs         4
2 files changed, 12 insertions, 1 deletion
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index b7dd61cb..468fe24d 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -155,6 +155,15 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
})
}
+pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+ let config = LayerNormConfig {
+ eps,
+ remove_mean: true,
+ affine: false,
+ };
+ layer_norm(size, config, vb)
+}
+
/// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index eb3cde4a..2113566d 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -46,7 +46,9 @@ pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
-pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
+pub use layer_norm::{
+ layer_norm, layer_norm_no_bias, rms_norm, LayerNorm, LayerNormConfig, RmsNorm,
+};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};