diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/conv.rs | 89 | ||||
-rw-r--r-- | candle-nn/src/dataset.rs | 171 | ||||
-rw-r--r-- | candle-nn/src/group_norm.rs | 83 | ||||
-rw-r--r-- | candle-nn/src/lib.rs | 6 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 10 | ||||
-rw-r--r-- | candle-nn/src/vision/cifar.rs | 62 | ||||
-rw-r--r-- | candle-nn/src/vision/mnist.rs | 65 | ||||
-rw-r--r-- | candle-nn/src/vision/mod.rs | 12 |
8 files changed, 185 insertions, 313 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 8fbe7659..67a80417 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -48,3 +48,92 @@ impl Conv1d { } } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Conv2dConfig { + pub padding: usize, + pub stride: usize, +} + +impl Default for Conv2dConfig { + fn default() -> Self { + Self { + padding: 0, + stride: 1, + } + } +} + +#[allow(dead_code)] +#[derive(Debug)] +pub struct Conv2d { + weight: Tensor, + bias: Option<Tensor>, + config: Conv2dConfig, +} + +impl Conv2d { + pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &Conv2dConfig { + &self.config + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +pub fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv1dConfig, + vs: crate::VarBuilder, +) -> Result<Conv1d> { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?; + let bound = 1. / (in_channels as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vs.get_or_init(out_channels, "bias", init_bs)?; + Ok(Conv1d::new(ws, Some(bs), cfg)) +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv2dConfig, + vs: crate::VarBuilder, +) -> Result<Conv2d> { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vs.get_or_init( + (out_channels, in_channels, kernel_size, kernel_size), + "weight", + init_ws, + )?; + let bound = 1. / (in_channels as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vs.get_or_init(out_channels, "bias", init_bs)?; + Ok(Conv2d::new(ws, Some(bs), cfg)) +} diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs deleted file mode 100644 index b74f1417..00000000 --- a/candle-nn/src/dataset.rs +++ /dev/null @@ -1,171 +0,0 @@ -use candle::{Result, Tensor}; - -pub struct Batcher<I> { - inner: I, - batch_size: usize, - return_last_incomplete_batch: bool, -} - -impl<I> Batcher<I> { - fn new(inner: I) -> Self { - Self { - inner, - batch_size: 16, - return_last_incomplete_batch: false, - } - } - - pub fn batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; - self - } - - pub fn return_last_incomplete_batch(mut self, r: bool) -> Self { - self.return_last_incomplete_batch = r; - self - } -} - -pub struct Iter1<I: Iterator<Item = Tensor>> { - inner: I, -} - -pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> { - inner: I, -} - -impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> { - pub fn new1(inner: I) -> Self { - Self::new(Iter1 { inner }) - } -} - -impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> { - pub fn new2(inner: I) -> Self { - Self::new(Iter2 { inner }) - } -} - -pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> { - inner: I, -} - -pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> { - inner: I, -} - -impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> { - pub fn new_r1(inner: I) -> Self { - Self::new(IterResult1 { inner }) - } -} - -impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> { - pub fn new_r2(inner: I) -> Self { - Self::new(IterResult2 { inner }) - } -} - -impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> { - type Item = Result<Tensor>; - - fn next(&mut self) -> Option<Self::Item> { - let mut items = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - // We have two levels of inner here so that we can have two implementations of the - // Iterator trait that are different for Iter1 and Iter2. If rust gets better - // specialization at some point we can get rid of this. - match self.inner.inner.next() { - Some(item) => items.push(item), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - Some(Tensor::stack(&items, 0)) - } -} - -impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> { - type Item = Result<(Tensor, Tensor)>; - - fn next(&mut self) -> Option<Self::Item> { - let mut xs = Vec::with_capacity(self.batch_size); - let mut ys = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - match self.inner.inner.next() { - Some((x, y)) => { - xs.push(x); - ys.push(y) - } - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - let xs = Tensor::stack(&xs, 0); - let ys = Tensor::stack(&ys, 0); - Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) - } -} - -impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> { - type Item = Result<Tensor>; - - fn next(&mut self) -> Option<Self::Item> { - let mut items = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - // We have two levels of inner here so that we can have two implementations of the - // Iterator trait that are different for Iter1 and Iter2. If rust gets better - // specialization at some point we can get rid of this. - match self.inner.inner.next() { - Some(item) => items.push(item), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - let items = items.into_iter().collect::<Result<Vec<Tensor>>>(); - Some(items.and_then(|items| Tensor::stack(&items, 0))) - } -} - -impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> { - type Item = Result<(Tensor, Tensor)>; - - fn next(&mut self) -> Option<Self::Item> { - let mut xs = Vec::with_capacity(self.batch_size); - let mut ys = Vec::with_capacity(self.batch_size); - let mut errs = vec![]; - for _i in 0..self.batch_size { - match self.inner.inner.next() { - Some(Ok((x, y))) => { - xs.push(x); - ys.push(y) - } - Some(Err(err)) => errs.push(err), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - if !errs.is_empty() { - return Some(Err(errs.swap_remove(0))); - } - let xs = Tensor::stack(&xs, 0); - let ys = Tensor::stack(&ys, 0); - Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) - } -} diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs new file mode 100644 index 00000000..ac77db4b --- /dev/null +++ b/candle-nn/src/group_norm.rs @@ -0,0 +1,83 @@ +//! Group Normalization. +//! +//! This layer applies Group Normalization over a mini-batch of inputs. +use candle::{DType, Result, Tensor}; + +// This group norm version handles both weight and bias so removes the mean. +#[derive(Debug)] +pub struct GroupNorm { + weight: Tensor, + bias: Tensor, + eps: f64, + num_channels: usize, + num_groups: usize, +} + +impl GroupNorm { + pub fn new( + weight: Tensor, + bias: Tensor, + num_channels: usize, + num_groups: usize, + eps: f64, + ) -> Result<Self> { + if num_channels % num_groups != 0 { + candle::bail!( + "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})" + ) + } + Ok(Self { + weight, + bias, + eps, + num_channels, + num_groups, + }) + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x_shape = x.dims(); + if x_shape.len() <= 2 { + candle::bail!("input rank for GroupNorm should be at least 3"); + } + let (b_sz, n_channels) = (x_shape[0], x_shape[1]); + let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups; + if n_channels != self.num_channels { + candle::bail!( + "unexpected num-channels in GroupNorm ({n_channels} <> {}", + self.num_channels + ) + } + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let x = x.reshape((b_sz, self.num_groups, hidden_size))?; + let x = x.to_dtype(internal_dtype)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let mut w_dims = vec![1; x_shape.len()]; + w_dims[1] = n_channels; + let weight = self.weight.reshape(w_dims.clone())?; + let bias = self.bias.reshape(w_dims)?; + x_normed + .to_dtype(x_dtype)? + .reshape(x_shape)? + .broadcast_mul(&weight)? + .broadcast_add(&bias) + } +} + +pub fn group_norm( + num_groups: usize, + num_channels: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result<GroupNorm> { + let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?; + let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?; + GroupNorm::new(weight, bias, num_channels, num_groups, eps) +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 46a83800..ae955f56 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -2,8 +2,8 @@ // error type if needed or add some specialized cases on the candle-core side. pub mod activation; pub mod conv; -pub mod dataset; pub mod embedding; +pub mod group_norm; pub mod init; pub mod layer_norm; pub mod linear; @@ -11,11 +11,11 @@ pub mod loss; pub mod ops; pub mod optim; pub mod var_builder; -pub mod vision; pub use activation::Activation; -pub use conv::{Conv1d, Conv1dConfig}; +pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; pub use embedding::{embedding, Embedding}; +pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, LayerNorm}; pub use linear::{linear, linear_no_bias, Linear}; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 611c66d8..397674f3 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -32,3 +32,13 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> { let log_sm = diff.broadcast_sub(&sum_exp.log()?)?; Ok(log_sm) } + +pub fn silu(xs: &Tensor) -> Result<Tensor> { + // TODO: Should we have a specialized op for this? + xs / (xs.neg()?.exp()? + 1.0)? +} + +pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { + // TODO: Should we have a specialized op for this? + (xs.neg()?.exp()? + 1.0)?.recip() +} diff --git a/candle-nn/src/vision/cifar.rs b/candle-nn/src/vision/cifar.rs deleted file mode 100644 index 0683c4d2..00000000 --- a/candle-nn/src/vision/cifar.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! The CIFAR-10 dataset. -//! -//! The files can be downloaded from the following page: -//! <https://www.cs.toronto.edu/~kriz/cifar.html> -//! The binary version of the dataset is used. -use crate::vision::Dataset; -use candle::{DType, Device, Result, Tensor}; -use std::fs::File; -use std::io::{BufReader, Read}; - -const W: usize = 32; -const H: usize = 32; -const C: usize = 3; -const BYTES_PER_IMAGE: usize = W * H * C + 1; -const SAMPLES_PER_FILE: usize = 10000; - -fn read_file(filename: &std::path::Path) -> Result<(Tensor, Tensor)> { - let mut buf_reader = BufReader::new(File::open(filename)?); - let mut data = vec![0u8; SAMPLES_PER_FILE * BYTES_PER_IMAGE]; - buf_reader.read_exact(&mut data)?; - let mut images = vec![]; - let mut labels = vec![]; - for index in 0..SAMPLES_PER_FILE { - let content_offset = BYTES_PER_IMAGE * index; - labels.push(data[content_offset]); - images.push(&data[1 + content_offset..content_offset + BYTES_PER_IMAGE]); - } - let images: Vec<u8> = images - .iter() - .copied() - .flatten() - .copied() - .collect::<Vec<_>>(); - let labels = Tensor::from_vec(labels, SAMPLES_PER_FILE, &Device::Cpu)?; - let images = Tensor::from_vec(images, (SAMPLES_PER_FILE, C, H, W), &Device::Cpu)?; - let images = (images.to_dtype(DType::F32)? / 255.)?; - Ok((images, labels)) -} - -pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> { - let dir = dir.as_ref(); - let (test_images, test_labels) = read_file(&dir.join("test_batch.bin"))?; - let train_images_and_labels = [ - "data_batch_1.bin", - "data_batch_2.bin", - "data_batch_3.bin", - "data_batch_4.bin", - "data_batch_5.bin", - ] - .iter() - .map(|x| read_file(&dir.join(x))) - .collect::<Result<Vec<_>>>()?; - let (train_images, train_labels): (Vec<_>, Vec<_>) = - train_images_and_labels.into_iter().unzip(); - Ok(Dataset { - train_images: Tensor::cat(&train_images, 0)?, - train_labels: Tensor::cat(&train_labels, 0)?, - test_images, - test_labels, - labels: 10, - }) -} diff --git a/candle-nn/src/vision/mnist.rs b/candle-nn/src/vision/mnist.rs deleted file mode 100644 index 2267f9a0..00000000 --- a/candle-nn/src/vision/mnist.rs +++ /dev/null @@ -1,65 +0,0 @@ -//! The MNIST hand-written digit dataset. -//! -//! The files can be obtained from the following link: -//! <http://yann.lecun.com/exdb/mnist/> -use candle::{DType, Device, Result, Tensor}; -use std::fs::File; -use std::io::{self, BufReader, Read}; - -fn read_u32<T: Read>(reader: &mut T) -> Result<u32> { - let mut b = vec![0u8; 4]; - reader.read_exact(&mut b)?; - let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| { - (s + basis * u64::from(x), basis * 256) - }); - Ok(result as u32) -} - -fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> { - let magic_number = read_u32(reader)?; - if magic_number != expected { - Err(io::Error::new( - io::ErrorKind::Other, - format!("incorrect magic number {magic_number} != {expected}"), - ))?; - } - Ok(()) -} - -fn read_labels(filename: &std::path::Path) -> Result<Tensor> { - let mut buf_reader = BufReader::new(File::open(filename)?); - check_magic_number(&mut buf_reader, 2049)?; - let samples = read_u32(&mut buf_reader)?; - let mut data = vec![0u8; samples as usize]; - buf_reader.read_exact(&mut data)?; - let samples = data.len(); - Tensor::from_vec(data, samples, &Device::Cpu) -} - -fn read_images(filename: &std::path::Path) -> Result<Tensor> { - let mut buf_reader = BufReader::new(File::open(filename)?); - check_magic_number(&mut buf_reader, 2051)?; - let samples = read_u32(&mut buf_reader)? as usize; - let rows = read_u32(&mut buf_reader)? as usize; - let cols = read_u32(&mut buf_reader)? as usize; - let data_len = samples * rows * cols; - let mut data = vec![0u8; data_len]; - buf_reader.read_exact(&mut data)?; - let tensor = Tensor::from_vec(data, (samples, rows * cols), &Device::Cpu)?; - tensor.to_dtype(DType::F32)? / 255. -} - -pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset> { - let dir = dir.as_ref(); - let train_images = read_images(&dir.join("train-images-idx3-ubyte"))?; - let train_labels = read_labels(&dir.join("train-labels-idx1-ubyte"))?; - let test_images = read_images(&dir.join("t10k-images-idx3-ubyte"))?; - let test_labels = read_labels(&dir.join("t10k-labels-idx1-ubyte"))?; - Ok(crate::vision::Dataset { - train_images, - train_labels, - test_images, - test_labels, - labels: 10, - }) -} diff --git a/candle-nn/src/vision/mod.rs b/candle-nn/src/vision/mod.rs deleted file mode 100644 index 6ce743eb..00000000 --- a/candle-nn/src/vision/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -use candle::Tensor; - -pub struct Dataset { - pub train_images: Tensor, - pub train_labels: Tensor, - pub test_images: Tensor, - pub test_labels: Tensor, - pub labels: usize, -} - -pub mod cifar; -pub mod mnist; |