Diffstat (limited to 'candle-nn/src')
 candle-nn/src/init.rs | 40 ++++++++++++++++++++++++++++++++++++----
 candle-nn/src/lib.rs  |  1 +
 2 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
index 762f0ef1..25702d52 100644
--- a/candle-nn/src/init.rs
+++ b/candle-nn/src/init.rs
@@ -1,7 +1,7 @@
 //! Variable initialization.
 // This is based on:
 // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 
 /// Number of features as input or output of a layer.
 /// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
         fan: FanInOut,
         non_linearity: NonLinearity,
     },
-
-    /// Orthogonal initialization
-    Orthogonal { gain: f64 },
 }
 
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
 pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
     dist: NormalOrUniform::Uniform,
     fan: FanInOut::FanIn,
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
     fan: FanInOut::FanIn,
     non_linearity: NonLinearity::ReLU,
 };
+
+impl Init {
+    /// Creates a new variable with the specified shape, dtype, and device,
+    /// filled according to this initialization scheme.
+    pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
+        match self {
+            Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
+            Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
+            Self::Const(cst) => {
+                Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
+            }
+            Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
+            Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
+            Self::Kaiming {
+                dist,
+                fan,
+                non_linearity,
+            } => {
+                let s = s.into();
+                let fan = fan.for_shape(&s);
+                let gain = non_linearity.gain();
+                let std = gain / (fan as f64).sqrt();
+                match dist {
+                    NormalOrUniform::Uniform => {
+                        let bound = 3f64.sqrt() * std;
+                        Var::rand_f64(-bound, bound, s, dtype, device)
+                    }
+                    NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
+                }
+            }
+        }
+    }
+}
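
For reference, a minimal usage sketch of the new `Init::var` helper (not part of this diff; it assumes the `init` module is public in `candle-nn`, as the re-export below suggests). For a `(64, 128)` weight under the default Kaiming-uniform scheme, fan-in is 128 and the ReLU gain is sqrt(2), so std = sqrt(2/128) = 0.125 and samples are drawn uniformly from roughly [-0.217, 0.217] (bound = sqrt(3) * std):

```rust
use candle::{DType, Device, Result};
use candle_nn::init::{Init, DEFAULT_KAIMING_UNIFORM};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Kaiming-uniform: fan_in = 128, gain = sqrt(2) for ReLU,
    // std = gain / sqrt(fan_in) = 0.125, bound = sqrt(3) * std ~= 0.217.
    let w = DEFAULT_KAIMING_UNIFORM.var((64, 128), DType::F32, &device)?;
    // Const(0.) and Const(1.) hit the dedicated zeros/ones fast paths.
    let b = Init::Const(0.).var(64, DType::F32, &device)?;
    println!("w: {:?}, b: {:?}", w.shape(), b.shape());
    Ok(())
}
```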
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index db01b067..d0b62dbb 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,6 +15,7 @@ pub mod vision;
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
+pub use init::Init;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
 pub use optim::SGD;
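
With the `init::Init` re-export in place, downstream crates can name the enum from the crate root. A small illustrative sketch (the `describe` helper is hypothetical, not part of candle-nn); it matches over the variants that remain after this change:

```rust
use candle_nn::Init; // previously spelled candle_nn::init::Init

// Hypothetical helper mapping each remaining Init variant to a label.
fn describe(init: &Init) -> &'static str {
    match init {
        Init::Const(_) => "constant",
        Init::Uniform { .. } => "uniform",
        Init::Randn { .. } => "gaussian",
        Init::Kaiming { .. } => "kaiming",
    }
}

fn main() {
    println!("{}", describe(&candle_nn::init::ZERO)); // prints "constant"
}
```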