summaryrefslogtreecommitdiff
path: root/candle-nn/src/layer_norm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/layer_norm.rs')
-rw-r--r--candle-nn/src/layer_norm.rs6
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 8f8544bb..668f9a4b 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -62,3 +62,9 @@ impl LayerNorm {
Ok(x)
}
}
+
+pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+ let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
+ let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
+ Ok(LayerNorm::new(weight, bias, eps))
+}