Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/batch_norm.rs  | 215
-rw-r--r--  candle-nn/src/encoding.rs    | 150
-rw-r--r--  candle-nn/src/lib.rs         |   1
3 files changed, 311 insertions(+), 55 deletions(-)
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 8cfc6740..856c2c7a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,15 +7,21 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, Var};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
pub eps: f64,
pub remove_mean: bool,
+
/// The meaning of affine here is different from LayerNorm: when false there are no learnable
/// parameters at all, 1 is used for gamma and 0 for beta.
pub affine: bool,
+
+ /// Controls the exponential moving average of the running stats. Defaults to 0.1.
+ ///
+ /// `running_stat * (1.0 - momentum) + stat * momentum`.
+ pub momentum: f64,
}
impl Default for BatchNormConfig {
@@ -24,6 +30,7 @@ impl Default for BatchNormConfig {
eps: 1e-5,
remove_mean: true,
affine: true,
+ momentum: 0.1,
}
}
}
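The momentum update applied to the running statistics follows the formula in the doc comment above. A minimal standalone sketch, with plain `f64` scalars standing in for the per-channel tensors (the `ema_update` helper is hypothetical, for illustration only):

```rust
// Hypothetical helper illustrating the running-stat update from the doc comment;
// in the actual module the operands are per-channel tensors, not scalars.
fn ema_update(running_stat: f64, batch_stat: f64, momentum: f64) -> f64 {
    running_stat * (1.0 - momentum) + batch_stat * momentum
}

fn main() {
    // With the default momentum of 0.1, a running mean of 0.0 and a batch mean
    // of 1.0 yield an updated running mean of 0.1.
    assert_eq!(ema_update(0.0, 1.0, 0.1), 0.1);
}
```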
@@ -32,23 +39,61 @@ impl From<f64> for BatchNormConfig {
fn from(eps: f64) -> Self {
Self {
eps,
- remove_mean: true,
- affine: true,
+ ..Default::default()
}
}
}
#[derive(Clone, Debug)]
pub struct BatchNorm {
- running_mean: Tensor,
- running_var: Tensor,
+ running_mean: Var,
+ running_var: Var,
weight_and_bias: Option<(Tensor, Tensor)>,
remove_mean: bool,
eps: f64,
- num_features: usize,
+ momentum: f64,
}
impl BatchNorm {
+ fn check_validity(&self, num_features: usize) -> Result<()> {
+ if self.eps < 0. {
+ candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+ }
+ if !(0.0..=1.0).contains(&self.momentum) {
+ candle::bail!(
+ "batch-norm momentum must be between 0 and 1, is {}",
+ self.momentum
+ )
+ }
+ if self.running_mean.dims() != [num_features] {
+ candle::bail!(
+ "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+ self.running_mean.shape(),
+ )
+ }
+ if self.running_var.dims() != [num_features] {
+ candle::bail!(
+ "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+ self.running_var.shape(),
+ )
+ }
+ if let Some((weight, bias)) = self.weight_and_bias.as_ref() {
+ if weight.dims() != [num_features] {
+ candle::bail!(
+ "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+ weight.shape(),
+ )
+ }
+ if bias.dims() != [num_features] {
+ candle::bail!(
+ "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+ bias.shape(),
+ )
+ }
+ }
+ Ok(())
+ }
+
pub fn new(
num_features: usize,
running_mean: Tensor,
@@ -57,29 +102,16 @@ impl BatchNorm {
bias: Tensor,
eps: f64,
) -> Result<Self> {
- if eps < 0. {
- candle::bail!("batch-norm eps cannot be negative {eps}")
- }
- if weight.dims() != [num_features] {
- candle::bail!(
- "batch-norm unexpected weight shape {:?} {num_features}",
- weight.shape()
- )
- }
- if bias.dims() != [num_features] {
- candle::bail!(
- "batch-norm unexpected bias shape {:?} {num_features}",
- bias.shape()
- )
- }
- Ok(Self {
- running_mean,
- running_var,
+ let out = Self {
+ running_mean: Var::from_tensor(&running_mean)?,
+ running_var: Var::from_tensor(&running_var)?,
weight_and_bias: Some((weight, bias)),
remove_mean: true,
eps,
- num_features,
- })
+ momentum: 0.1,
+ };
+ out.check_validity(num_features)?;
+ Ok(out)
}
pub fn new_no_bias(
@@ -88,25 +120,64 @@ impl BatchNorm {
running_var: Tensor,
eps: f64,
) -> Result<Self> {
- if eps < 0. {
- candle::bail!("batch-norm eps cannot be negative {eps}")
- }
- Ok(Self {
- running_mean,
- running_var,
+ let out = Self {
+ running_mean: Var::from_tensor(&running_mean)?,
+ running_var: Var::from_tensor(&running_var)?,
+ weight_and_bias: None,
+ remove_mean: true,
+ eps,
+ momentum: 0.1,
+ };
+ out.check_validity(num_features)?;
+ Ok(out)
+ }
+
+ pub fn new_with_momentum(
+ num_features: usize,
+ running_mean: Tensor,
+ running_var: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ eps: f64,
+ momentum: f64,
+ ) -> Result<Self> {
+ let out = Self {
+ running_mean: Var::from_tensor(&running_mean)?,
+ running_var: Var::from_tensor(&running_var)?,
+ weight_and_bias: Some((weight, bias)),
+ remove_mean: true,
+ eps,
+ momentum,
+ };
+ out.check_validity(num_features)?;
+ Ok(out)
+ }
+
+ pub fn new_no_bias_with_momentum(
+ num_features: usize,
+ running_mean: Tensor,
+ running_var: Tensor,
+ eps: f64,
+ momentum: f64,
+ ) -> Result<Self> {
+ let out = Self {
+ running_mean: Var::from_tensor(&running_mean)?,
+ running_var: Var::from_tensor(&running_var)?,
weight_and_bias: None,
remove_mean: true,
eps,
- num_features,
- })
+ momentum,
+ };
+ out.check_validity(num_features)?;
+ Ok(out)
}
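As a usage sketch for the new momentum-aware constructor (shapes and values are illustrative; this assumes the usual CPU device and that the stats tensors match `num_features`):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::batch_norm::BatchNorm;

fn build() -> Result<BatchNorm> {
    let dev = Device::Cpu;
    let c = 4; // number of features/channels
    let running_mean = Tensor::zeros(c, DType::F32, &dev)?;
    let running_var = Tensor::ones(c, DType::F32, &dev)?;
    let weight = Tensor::ones(c, DType::F32, &dev)?;
    let bias = Tensor::zeros(c, DType::F32, &dev)?;
    // A momentum of 0.01 makes the running stats move more slowly than the 0.1 default.
    BatchNorm::new_with_momentum(c, running_mean, running_var, weight, bias, 1e-5, 0.01)
}
```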
pub fn running_mean(&self) -> &Tensor {
- &self.running_mean
+ self.running_mean.as_tensor()
}
pub fn running_var(&self) -> &Tensor {
- &self.running_var
+ self.running_var.as_tensor()
}
pub fn eps(&self) -> f64 {
@@ -117,7 +188,12 @@ impl BatchNorm {
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
}
- pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+ pub fn momentum(&self) -> f64 {
+ self.momentum
+ }
+
+ pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
+ let num_features = self.running_mean.as_tensor().dim(0)?;
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
@@ -129,40 +205,54 @@ impl BatchNorm {
x.shape()
)
}
- if x.dim(1)? != self.num_features {
+ if x.dim(1)? != num_features {
candle::bail!(
"batch-norm input doesn't have the expected number of features ({:?} <> {})",
x.shape(),
- self.num_features
+ num_features
)
}
let x = x.to_dtype(internal_dtype)?;
let x = x.transpose(0, 1)?;
let x_dims_post_transpose = x.dims();
+ // Flatten all the dimensions except the channel one, as this performs a Spatial Batch
+ // Normalization.
let x = x.flatten_from(1)?.contiguous()?;
let x = if self.remove_mean {
+ // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let mean_x = x.mean_keepdim(1)?;
+ let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+ + (mean_x.flatten_all()? * self.momentum)?)?;
+ self.running_mean.set(&updated_running_mean)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
+ // The squares are averaged over dim 1 (the batch dim after the transpose(0, 1) above), giving a biased variance estimate.
let norm_x = x.sqr()?.mean_keepdim(1)?;
- let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
- let x = x_normed.to_dtype(x_dtype)?;
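+ // The running variance is updated with an unbiased estimate: norm_x above is the
+ // biased batch variance (normalized by n), so it is rescaled by n / (n - 1)
+ // (Bessel's correction) before being folded into the exponential moving average.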
+ let updated_running_var = {
+ let batch_size = x.dim(1)? as f64;
+ let running_var_weight = 1.0 - self.momentum;
+ let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+ ((self.running_var.as_tensor() * running_var_weight)?
+ + (&norm_x.flatten_all()? * norm_x_weight)?)?
+ };
+ self.running_var.set(&updated_running_var)?;
+ let x = x
+ .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)?;
let x = match &self.weight_and_bias {
None => x,
Some((weight, bias)) => {
- let weight = weight.reshape((self.num_features, 1))?;
- let bias = bias.reshape((self.num_features, 1))?;
+ let weight = weight.reshape(((), 1))?;
+ let bias = bias.reshape(((), 1))?;
x.broadcast_mul(&weight)?.broadcast_add(&bias)?
}
};
x.reshape(x_dims_post_transpose)?.transpose(0, 1)
}
-}
-impl crate::Module for BatchNorm {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
let target_shape: Vec<usize> = x
.dims()
.iter()
@@ -170,9 +260,13 @@ impl crate::Module for BatchNorm {
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
+
let x = x
- .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
- .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+ .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+ .broadcast_div(
+ &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+ )?;
+
match &self.weight_and_bias {
None => Ok(x),
Some((weight, bias)) => {
@@ -184,30 +278,41 @@ impl crate::Module for BatchNorm {
}
}
+impl crate::ModuleT for BatchNorm {
+ fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+ if train {
+ self.forward_train(x)
+ } else {
+ self.forward_eval(x)
+ }
+ }
+}
+
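A short usage sketch of the train/eval dispatch; this assumes `BatchNorm` and `ModuleT` are re-exported at the crate root, as the `crate::ModuleT` impl above suggests:

```rust
use candle::{Result, Tensor};
use candle_nn::{BatchNorm, ModuleT};

fn run(bn: &BatchNorm, xs: &Tensor) -> Result<(Tensor, Tensor)> {
    // Training pass: normalizes with batch statistics and updates the running stats.
    let train_out = bn.forward_t(xs, true)?;
    // Evaluation pass: normalizes with the frozen running statistics only.
    let eval_out = bn.forward_t(xs, false)?;
    Ok((train_out, eval_out))
}
```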
pub fn batch_norm<C: Into<BatchNormConfig>>(
num_features: usize,
config: C,
vb: crate::VarBuilder,
) -> Result<BatchNorm> {
+ use crate::Init;
let config = config.into();
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
}
- let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
- let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+ let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
+ let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
let weight_and_bias = if config.affine {
- let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
- let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+ let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
+ let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
Some((weight, bias))
} else {
None
};
Ok(BatchNorm {
- running_mean,
- running_var,
+ running_mean: Var::from_tensor(&running_mean)?,
+ running_var: Var::from_tensor(&running_var)?,
weight_and_bias,
remove_mean: config.remove_mean,
eps: config.eps,
- num_features,
+ momentum: config.momentum,
})
}
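A sketch of driving the builder from a fresh `VarMap` (the "bn" prefix and feature count are illustrative):

```rust
use candle::{DType, Device, Result};
use candle_nn::batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
use candle_nn::{VarBuilder, VarMap};

fn build_from_varbuilder() -> Result<BatchNorm> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let cfg = BatchNormConfig {
        momentum: 0.01,
        ..Default::default()
    };
    // Initializes running_mean, running_var, weight, and bias under the "bn" prefix.
    batch_norm(8, cfg, vb.pp("bn"))
}
```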
diff --git a/candle-nn/src/encoding.rs b/candle-nn/src/encoding.rs
new file mode 100644
index 00000000..38e2cc3b
--- /dev/null
+++ b/candle-nn/src/encoding.rs
@@ -0,0 +1,150 @@
+//! Encoding utilities (e.g., one-hot/one-cold encoding).
+
+use candle::{bail, DType, Result, Tensor, WithDType};
+
+/// One-hot/cold encoding.
+///
+/// Given an input tensor of indices, this function returns a tensor of the same shape as the input
+/// tensor with an additional dimension of the given depth size. The values in the returned tensor are
+/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`.
+///
+/// This method returns a tensor with a rank one larger than that of the input tensor.
+///
+/// As an example, the following tensor of indices:
+///
+/// `[[0i64, 2], [1, -1]]`
+///
+/// encoded with a depth of 4 becomes:
+///
+/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]`
+///
+/// When the input tensor index has a value of -1, the corresponding one-hot vector will be ignored,
+/// resulting in a vector of values set to the `off_value`.
+///
+/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`.
+/// By default `on_value` is `1` and `off_value` is `0`.
+///
+/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values.
+///
+/// # Examples
+///
+/// ## One-hot encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+///
+/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
+/// let depth = 4;
+/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();
+///
+/// let expected_matrix = [
+/// [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
+/// [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
+/// ];
+///
+/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_hot.to_vec3::<f32>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+/// ## One-cold Encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+/// let depth = 4;
+/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
+/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();
+///
+/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];
+///
+/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_cold.to_vec3::<u8>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+/// # Bails
+///
+/// This method bails if:
+/// - One of the index values is less than -1.
+/// - One of the index values is greater than or equal to the depth value.
+/// - The input data type is not `U8`, `U32`, or `I64`.
+///
+/// # API Design
+///
+/// The API design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method.
+pub fn one_hot<D: WithDType>(
+ indices: Tensor,
+ depth: usize,
+ on_value: D,
+ off_value: D,
+) -> Result<Tensor> {
+ let mut target_shape = indices.dims().to_vec();
+ target_shape.push(depth);
+ let indices = indices.flatten_all()?;
+ let mut out = vec![off_value; depth * indices.elem_count()];
+ match indices.dtype() {
+ DType::U8 => {
+ let indices = indices.to_vec1::<u8>()?;
+ for (i, &index) in indices.iter().enumerate() {
+ set_at_index(index, i * depth, depth, &mut out, on_value)?;
+ }
+ }
+ DType::U32 => {
+ let indices = indices.to_vec1::<u32>()?;
+ for (i, &index) in indices.iter().enumerate() {
+ set_at_index(index, i * depth, depth, &mut out, on_value)?;
+ }
+ }
+ DType::I64 => {
+ let indices = indices.to_vec1::<i64>()?;
+ for (i, &index) in indices.iter().enumerate() {
+ set_at_index(index, i * depth, depth, &mut out, on_value)?;
+ }
+ }
+ dtype => {
+ bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
+ }
+ };
+ Tensor::from_vec(out, target_shape, indices.device())
+}
+
+fn set_at_index<D: WithDType, I: Into<i64>>(
+ value: I,
+ offset: usize,
+ depth: usize,
+ v: &mut Vec<D>,
+ on_value: D,
+) -> Result<()> {
+ let value = value.into();
+ // A value of -1 encodes an entire row of off_values, so there is nothing to set.
+ if value == -1 {
+ return Ok(());
+ }
+ if value < -1 {
+ bail!(
+ "one_hot: invalid negative index value {value}, expected a positive index value or -1"
+ );
+ }
+ let value = value as usize;
+ if value >= depth {
+ bail!("one_hot: index value {value} exceeds depth {depth}")
+ }
+ let idx = offset + value;
+ if idx >= v.len() {
+ bail!("one_hot: index out of bounds {idx}, len {}", v.len());
+ }
+ v[idx] = on_value;
+ Ok(())
+}
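Beyond the doc examples above, a small end-to-end sketch of preparing float targets from integer labels (the labels and depth are illustrative):

```rust
use candle::{Device, Result, Tensor};
use candle_nn::encoding::one_hot;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical class labels for a 3-class problem; a -1 entry would stay
    // all off_value (e.g. for padding).
    let labels = Tensor::new(vec![0i64, 2, 1], &dev)?;
    // Float targets suitable for a soft-target style loss.
    let targets = one_hot(labels, 3, 1f32, 0f32)?;
    assert_eq!(targets.dims(), &[3, 3]);
    Ok(())
}
```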
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 8f00e54c..6306c55a 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,6 +2,7 @@ pub mod activation;
pub mod batch_norm;
pub mod conv;
pub mod embedding;
+pub mod encoding;
pub mod func;
pub mod group_norm;
pub mod init;