Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/conv.rs          |  89
-rw-r--r--  candle-nn/src/dataset.rs       | 171
-rw-r--r--  candle-nn/src/group_norm.rs    |  83
-rw-r--r--  candle-nn/src/lib.rs           |   6
-rw-r--r--  candle-nn/src/ops.rs           |  10
-rw-r--r--  candle-nn/src/vision/cifar.rs  |  62
-rw-r--r--  candle-nn/src/vision/mnist.rs  |  65
-rw-r--r--  candle-nn/src/vision/mod.rs    |  12
8 files changed, 185 insertions, 313 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 8fbe7659..67a80417 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -48,3 +48,92 @@ impl Conv1d {
}
}
}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct Conv2dConfig {
+ pub padding: usize,
+ pub stride: usize,
+}
+
+impl Default for Conv2dConfig {
+ fn default() -> Self {
+ Self {
+ padding: 0,
+ stride: 1,
+ }
+ }
+}
+
+#[allow(dead_code)]
+#[derive(Debug)]
+pub struct Conv2d {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ config: Conv2dConfig,
+}
+
+impl Conv2d {
+ pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
+ Self {
+ weight,
+ bias,
+ config,
+ }
+ }
+
+ pub fn config(&self) -> &Conv2dConfig {
+ &self.config
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => {
+ let b = bias.dims1()?;
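+                // Reshape the bias to (1, c_out, 1, 1) so that it broadcasts
+                // over the batch and spatial dimensions of the conv output.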
+ let bias = bias.reshape((1, b, 1, 1))?;
+ Ok(x.broadcast_add(&bias)?)
+ }
+ }
+ }
+}
+
+pub fn conv1d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv1dConfig,
+ vs: crate::VarBuilder,
+) -> Result<Conv1d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?;
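+    // The bias is drawn uniformly from [-1/sqrt(in_channels), 1/sqrt(in_channels)].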
+ let bound = 1. / (in_channels as f64).sqrt();
+ let init_bs = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
+ Ok(Conv1d::new(ws, Some(bs), cfg))
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv2dConfig,
+ vs: crate::VarBuilder,
+) -> Result<Conv2d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vs.get_or_init(
+ (out_channels, in_channels, kernel_size, kernel_size),
+ "weight",
+ init_ws,
+ )?;
+ let bound = 1. / (in_channels as f64).sqrt();
+ let init_bs = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
+ Ok(Conv2d::new(ws, Some(bs), cfg))
+}
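+
+// A minimal usage sketch (illustrative only, assuming a `VarBuilder` named
+// `vs` and `candle::{Device, DType, Tensor}` in scope): a 3x3 convolution
+// over an RGB image with `padding: 1, stride: 1` keeps the spatial
+// dimensions unchanged.
+//
+//     let cfg = Conv2dConfig { padding: 1, stride: 1 };
+//     let layer = conv2d(3, 16, 3, cfg, vs)?;
+//     let xs = Tensor::zeros((1, 3, 32, 32), DType::F32, &Device::Cpu)?;
+//     let ys = layer.forward(&xs)?; // shape: (1, 16, 32, 32)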
diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs
deleted file mode 100644
index b74f1417..00000000
--- a/candle-nn/src/dataset.rs
+++ /dev/null
@@ -1,171 +0,0 @@
-use candle::{Result, Tensor};
-
-pub struct Batcher<I> {
- inner: I,
- batch_size: usize,
- return_last_incomplete_batch: bool,
-}
-
-impl<I> Batcher<I> {
- fn new(inner: I) -> Self {
- Self {
- inner,
- batch_size: 16,
- return_last_incomplete_batch: false,
- }
- }
-
- pub fn batch_size(mut self, batch_size: usize) -> Self {
- self.batch_size = batch_size;
- self
- }
-
- pub fn return_last_incomplete_batch(mut self, r: bool) -> Self {
- self.return_last_incomplete_batch = r;
- self
- }
-}
-
-pub struct Iter1<I: Iterator<Item = Tensor>> {
- inner: I,
-}
-
-pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
- inner: I,
-}
-
-impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> {
- pub fn new1(inner: I) -> Self {
- Self::new(Iter1 { inner })
- }
-}
-
-impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
- pub fn new2(inner: I) -> Self {
- Self::new(Iter2 { inner })
- }
-}
-
-pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
- inner: I,
-}
-
-pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
- inner: I,
-}
-
-impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> {
- pub fn new_r1(inner: I) -> Self {
- Self::new(IterResult1 { inner })
- }
-}
-
-impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
- pub fn new_r2(inner: I) -> Self {
- Self::new(IterResult2 { inner })
- }
-}
-
-impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
- type Item = Result<Tensor>;
-
- fn next(&mut self) -> Option<Self::Item> {
- let mut items = Vec::with_capacity(self.batch_size);
- for _i in 0..self.batch_size {
- // We have two levels of inner here so that we can have two implementations of the
- // Iterator trait that are different for Iter1 and Iter2. If rust gets better
- // specialization at some point we can get rid of this.
- match self.inner.inner.next() {
- Some(item) => items.push(item),
- None => {
- if self.return_last_incomplete_batch {
- break;
- }
- return None;
- }
- }
- }
- Some(Tensor::stack(&items, 0))
- }
-}
-
-impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
- type Item = Result<(Tensor, Tensor)>;
-
- fn next(&mut self) -> Option<Self::Item> {
- let mut xs = Vec::with_capacity(self.batch_size);
- let mut ys = Vec::with_capacity(self.batch_size);
- for _i in 0..self.batch_size {
- match self.inner.inner.next() {
- Some((x, y)) => {
- xs.push(x);
- ys.push(y)
- }
- None => {
- if self.return_last_incomplete_batch {
- break;
- }
- return None;
- }
- }
- }
- let xs = Tensor::stack(&xs, 0);
- let ys = Tensor::stack(&ys, 0);
- Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
- }
-}
-
-impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
- type Item = Result<Tensor>;
-
- fn next(&mut self) -> Option<Self::Item> {
- let mut items = Vec::with_capacity(self.batch_size);
- for _i in 0..self.batch_size {
- // We have two levels of inner here so that we can have two implementations of the
- // Iterator trait that are different for Iter1 and Iter2. If rust gets better
- // specialization at some point we can get rid of this.
- match self.inner.inner.next() {
- Some(item) => items.push(item),
- None => {
- if self.return_last_incomplete_batch {
- break;
- }
- return None;
- }
- }
- }
- let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
- Some(items.and_then(|items| Tensor::stack(&items, 0)))
- }
-}
-
-impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
- type Item = Result<(Tensor, Tensor)>;
-
- fn next(&mut self) -> Option<Self::Item> {
- let mut xs = Vec::with_capacity(self.batch_size);
- let mut ys = Vec::with_capacity(self.batch_size);
- let mut errs = vec![];
- for _i in 0..self.batch_size {
- match self.inner.inner.next() {
- Some(Ok((x, y))) => {
- xs.push(x);
- ys.push(y)
- }
- Some(Err(err)) => errs.push(err),
- None => {
- if self.return_last_incomplete_batch {
- break;
- }
- return None;
- }
- }
- }
- if !errs.is_empty() {
- return Some(Err(errs.swap_remove(0)));
- }
- let xs = Tensor::stack(&xs, 0);
- let ys = Tensor::stack(&ys, 0);
- Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
- }
-}
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
new file mode 100644
index 00000000..ac77db4b
--- /dev/null
+++ b/candle-nn/src/group_norm.rs
@@ -0,0 +1,83 @@
+//! Group Normalization.
+//!
+//! This layer applies Group Normalization over a mini-batch of inputs.
+use candle::{DType, Result, Tensor};
+
+// This GroupNorm implementation handles both a weight and a bias, and
+// centers the input by removing the mean.
+#[derive(Debug)]
+pub struct GroupNorm {
+ weight: Tensor,
+ bias: Tensor,
+ eps: f64,
+ num_channels: usize,
+ num_groups: usize,
+}
+
+impl GroupNorm {
+ pub fn new(
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ num_groups: usize,
+ eps: f64,
+ ) -> Result<Self> {
+ if num_channels % num_groups != 0 {
+ candle::bail!(
+ "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+ )
+ }
+ Ok(Self {
+ weight,
+ bias,
+ eps,
+ num_channels,
+ num_groups,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x_shape = x.dims();
+ if x_shape.len() <= 2 {
+ candle::bail!("input rank for GroupNorm should be at least 3");
+ }
+ let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
+ let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+ if n_channels != self.num_channels {
+ candle::bail!(
+ "unexpected num-channels in GroupNorm ({n_channels} <> {}",
+ self.num_channels
+ )
+ }
+ let x_dtype = x.dtype();
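+        // Accumulate the statistics in f32 when the input is f16/bf16 to
+        // avoid precision loss in the mean/variance computation.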
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+ let x = x.to_dtype(internal_dtype)?;
+ let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ let x = x.broadcast_sub(&mean_x)?;
+ let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+ let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let mut w_dims = vec![1; x_shape.len()];
+ w_dims[1] = n_channels;
+ let weight = self.weight.reshape(w_dims.clone())?;
+ let bias = self.bias.reshape(w_dims)?;
+ x_normed
+ .to_dtype(x_dtype)?
+ .reshape(x_shape)?
+ .broadcast_mul(&weight)?
+ .broadcast_add(&bias)
+ }
+}
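+
+// For reference, with the channels of each sample split into `num_groups`
+// groups, the forward pass above computes, for every group g:
+//
+//     y = (x - mean_g(x)) / sqrt(var_g(x) + eps) * weight + bias
+//
+// where the mean and variance are taken over all elements of the group.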
+
+pub fn group_norm(
+ num_groups: usize,
+ num_channels: usize,
+ eps: f64,
+ vb: crate::VarBuilder,
+) -> Result<GroupNorm> {
+ let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
+ let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
+ GroupNorm::new(weight, bias, num_channels, num_groups, eps)
+}
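+
+// A usage sketch (illustrative only, assuming a `VarBuilder` named `vb`):
+// normalizing a (batch, channels, h, w) tensor with 32 channels split into
+// 8 groups.
+//
+//     let gn = group_norm(8, 32, 1e-5, vb)?;
+//     let xs = Tensor::zeros((2, 32, 16, 16), DType::F32, &Device::Cpu)?;
+//     let ys = gn.forward(&xs)?; // same shape as `xs`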
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 46a83800..ae955f56 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,8 +2,8 @@
// error type if needed or add some specialized cases on the candle-core side.
pub mod activation;
pub mod conv;
-pub mod dataset;
pub mod embedding;
+pub mod group_norm;
pub mod init;
pub mod layer_norm;
pub mod linear;
@@ -11,11 +11,11 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod var_builder;
-pub mod vision;
pub use activation::Activation;
-pub use conv::{Conv1d, Conv1dConfig};
+pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
pub use embedding::{embedding, Embedding};
+pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, LayerNorm};
pub use linear::{linear, linear_no_bias, Linear};
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 611c66d8..397674f3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -32,3 +32,13 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
Ok(log_sm)
}
+
+pub fn silu(xs: &Tensor) -> Result<Tensor> {
+ // TODO: Should we have a specialized op for this?
+ xs / (xs.neg()?.exp()? + 1.0)?
+}
+
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+ // TODO: Should we have a specialized op for this?
+ (xs.neg()?.exp()? + 1.0)?.recip()
+}
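+
+// Illustrative check (not part of this patch): silu(x) = x * sigmoid(x), so
+// the two helpers above should agree up to floating point rounding:
+//
+//     let xs = Tensor::new(&[-1f32, 0., 1.], &Device::Cpu)?;
+//     let lhs = silu(&xs)?;
+//     let rhs = xs.mul(&sigmoid(&xs)?)?;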
diff --git a/candle-nn/src/vision/cifar.rs b/candle-nn/src/vision/cifar.rs
deleted file mode 100644
index 0683c4d2..00000000
--- a/candle-nn/src/vision/cifar.rs
+++ /dev/null
@@ -1,62 +0,0 @@
-//! The CIFAR-10 dataset.
-//!
-//! The files can be downloaded from the following page:
-//! <https://www.cs.toronto.edu/~kriz/cifar.html>
-//! The binary version of the dataset is used.
-use crate::vision::Dataset;
-use candle::{DType, Device, Result, Tensor};
-use std::fs::File;
-use std::io::{BufReader, Read};
-
-const W: usize = 32;
-const H: usize = 32;
-const C: usize = 3;
-const BYTES_PER_IMAGE: usize = W * H * C + 1;
-const SAMPLES_PER_FILE: usize = 10000;
-
-fn read_file(filename: &std::path::Path) -> Result<(Tensor, Tensor)> {
- let mut buf_reader = BufReader::new(File::open(filename)?);
- let mut data = vec![0u8; SAMPLES_PER_FILE * BYTES_PER_IMAGE];
- buf_reader.read_exact(&mut data)?;
- let mut images = vec![];
- let mut labels = vec![];
- for index in 0..SAMPLES_PER_FILE {
- let content_offset = BYTES_PER_IMAGE * index;
- labels.push(data[content_offset]);
- images.push(&data[1 + content_offset..content_offset + BYTES_PER_IMAGE]);
- }
- let images: Vec<u8> = images
- .iter()
- .copied()
- .flatten()
- .copied()
- .collect::<Vec<_>>();
- let labels = Tensor::from_vec(labels, SAMPLES_PER_FILE, &Device::Cpu)?;
- let images = Tensor::from_vec(images, (SAMPLES_PER_FILE, C, H, W), &Device::Cpu)?;
- let images = (images.to_dtype(DType::F32)? / 255.)?;
- Ok((images, labels))
-}
-
-pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
- let dir = dir.as_ref();
- let (test_images, test_labels) = read_file(&dir.join("test_batch.bin"))?;
- let train_images_and_labels = [
- "data_batch_1.bin",
- "data_batch_2.bin",
- "data_batch_3.bin",
- "data_batch_4.bin",
- "data_batch_5.bin",
- ]
- .iter()
- .map(|x| read_file(&dir.join(x)))
- .collect::<Result<Vec<_>>>()?;
- let (train_images, train_labels): (Vec<_>, Vec<_>) =
- train_images_and_labels.into_iter().unzip();
- Ok(Dataset {
- train_images: Tensor::cat(&train_images, 0)?,
- train_labels: Tensor::cat(&train_labels, 0)?,
- test_images,
- test_labels,
- labels: 10,
- })
-}
diff --git a/candle-nn/src/vision/mnist.rs b/candle-nn/src/vision/mnist.rs
deleted file mode 100644
index 2267f9a0..00000000
--- a/candle-nn/src/vision/mnist.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-//! The MNIST hand-written digit dataset.
-//!
-//! The files can be obtained from the following link:
-//! <http://yann.lecun.com/exdb/mnist/>
-use candle::{DType, Device, Result, Tensor};
-use std::fs::File;
-use std::io::{self, BufReader, Read};
-
-fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
- let mut b = vec![0u8; 4];
- reader.read_exact(&mut b)?;
- let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
- (s + basis * u64::from(x), basis * 256)
- });
- Ok(result as u32)
-}
-
-fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
- let magic_number = read_u32(reader)?;
- if magic_number != expected {
- Err(io::Error::new(
- io::ErrorKind::Other,
- format!("incorrect magic number {magic_number} != {expected}"),
- ))?;
- }
- Ok(())
-}
-
-fn read_labels(filename: &std::path::Path) -> Result<Tensor> {
- let mut buf_reader = BufReader::new(File::open(filename)?);
- check_magic_number(&mut buf_reader, 2049)?;
- let samples = read_u32(&mut buf_reader)?;
- let mut data = vec![0u8; samples as usize];
- buf_reader.read_exact(&mut data)?;
- let samples = data.len();
- Tensor::from_vec(data, samples, &Device::Cpu)
-}
-
-fn read_images(filename: &std::path::Path) -> Result<Tensor> {
- let mut buf_reader = BufReader::new(File::open(filename)?);
- check_magic_number(&mut buf_reader, 2051)?;
- let samples = read_u32(&mut buf_reader)? as usize;
- let rows = read_u32(&mut buf_reader)? as usize;
- let cols = read_u32(&mut buf_reader)? as usize;
- let data_len = samples * rows * cols;
- let mut data = vec![0u8; data_len];
- buf_reader.read_exact(&mut data)?;
- let tensor = Tensor::from_vec(data, (samples, rows * cols), &Device::Cpu)?;
- tensor.to_dtype(DType::F32)? / 255.
-}
-
-pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset> {
- let dir = dir.as_ref();
- let train_images = read_images(&dir.join("train-images-idx3-ubyte"))?;
- let train_labels = read_labels(&dir.join("train-labels-idx1-ubyte"))?;
- let test_images = read_images(&dir.join("t10k-images-idx3-ubyte"))?;
- let test_labels = read_labels(&dir.join("t10k-labels-idx1-ubyte"))?;
- Ok(crate::vision::Dataset {
- train_images,
- train_labels,
- test_images,
- test_labels,
- labels: 10,
- })
-}
diff --git a/candle-nn/src/vision/mod.rs b/candle-nn/src/vision/mod.rs
deleted file mode 100644
index 6ce743eb..00000000
--- a/candle-nn/src/vision/mod.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-use candle::Tensor;
-
-pub struct Dataset {
- pub train_images: Tensor,
- pub train_labels: Tensor,
- pub test_images: Tensor,
- pub test_labels: Tensor,
- pub labels: usize,
-}
-
-pub mod cifar;
-pub mod mnist;