summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-19 16:37:52 +0200
committerGitHub <noreply@github.com>2023-07-19 15:37:52 +0100
commitcb687b4897052fbade21df0c589474d8bb94ab0b (patch)
treebe2e6a7aa5c063803eff036df2588f41aaf46709
parent67e20c37920d9f7677e3a4274eef8c73455274c8 (diff)
downloadcandle-cb687b4897052fbade21df0c589474d8bb94ab0b.tar.gz
candle-cb687b4897052fbade21df0c589474d8bb94ab0b.tar.bz2
candle-cb687b4897052fbade21df0c589474d8bb94ab0b.zip
Add some more developed training examples. (#199)
* Use contiguous tensors for variables. * Sketch the mnist example. * Start adding the reduce ops. * Renaming. * Refactor the reduce operations. * Bugfix for the broadcasting vectorization.
-rw-r--r--.gitignore1
-rw-r--r--candle-core/src/backend.rs2
-rw-r--r--candle-core/src/backprop.rs8
-rw-r--r--candle-core/src/cpu_backend.rs73
-rw-r--r--candle-core/src/cuda_backend.rs19
-rw-r--r--candle-core/src/dummy_cuda_backend.rs2
-rw-r--r--candle-core/src/op.rs9
-rw-r--r--candle-core/src/storage.rs11
-rw-r--r--candle-core/src/tensor.rs128
-rw-r--r--candle-examples/examples/simple-training/main.rs44
10 files changed, 232 insertions, 65 deletions
diff --git a/.gitignore b/.gitignore
index df9a6132..9ff37524 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/
+data/
dist/
target/
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index c897510e..018279b3 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -16,7 +16,7 @@ pub(crate) trait BackendStorage: Sized {
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
- fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
+ fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index c72f603f..3de11d35 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -67,6 +67,8 @@ impl Tensor {
Op::Reshape(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
+ | Op::Max(node, _)
+ | Op::Min(node, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -203,6 +205,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&grad)?
}
+ Op::Max(_args, _sum_dims) => {
+ return Err(Error::BackwardNotSupported { op: "max" })
+ }
+ Op::Min(_args, _sum_dims) => {
+ return Err(Error::BackwardNotSupported { op: "min" })
+ }
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 6458b452..f1118ee7 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -93,47 +93,52 @@ impl<'a> Map2 for WCond<'a> {
}
}
-struct Sum<'a> {
+struct Reduce<'a> {
dst_shape: &'a Shape,
- sum_dims: &'a [usize],
- sum_dims_and_stride: Vec<(usize, usize)>,
+ reduce_dims: &'a [usize],
+ reduce_dims_and_stride: Vec<(usize, usize)>,
+ op: crate::op::ReduceOp,
}
-impl<'a> Map1 for Sum<'a> {
+impl<'a> Map1 for Reduce<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ match self.op {
+ crate::op::ReduceOp::Min | crate::op::ReduceOp::Max => todo!(),
+ crate::op::ReduceOp::Sum => (),
+ }
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
- // Handle the case where we sum over the last dimensions separately as it is
+ // Handle the case where we reduce over the last dimensions separately as it is
// fairly common and easy to optimize. This rely on the layout being contiguous!
- // sum_dims is sorted, check if it is ranging from a to n-1.
- let sum_over_last_dims = self
- .sum_dims
+ // reduce_dims is sorted, check if it is ranging from a to n-1.
+ let reduce_over_last_dims = self
+ .reduce_dims
.iter()
.rev()
.enumerate()
.all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
- if sum_over_last_dims {
- let sum_sz = self
- .sum_dims_and_stride
+ if reduce_over_last_dims {
+ let reduce_sz = self
+ .reduce_dims_and_stride
.iter()
.map(|(u, _)| u)
.product::<usize>();
let mut src_i = 0;
for dst_v in dst.iter_mut() {
- for &s in src[src_i..src_i + sum_sz].iter() {
+ for &s in src[src_i..src_i + reduce_sz].iter() {
*dst_v += s
}
- src_i += sum_sz
+ src_i += reduce_sz
}
return Ok(dst);
};
for (unstr_index, &src) in src.iter().enumerate() {
let mut dst_index = unstr_index;
- // Set the sum_dims indexes to 0.
- for &(dim, stride) in self.sum_dims_and_stride.iter() {
+ // Set the reduce_dims indexes to 0.
+ for &(dim, stride) in self.reduce_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
@@ -144,8 +149,8 @@ impl<'a> Map1 for Sum<'a> {
None => {
for (unstr_index, src_index) in src_l.strided_index().enumerate() {
let mut dst_index = unstr_index;
- // Set the sum_dims indexes to 0.
- for &(dim, stride) in self.sum_dims_and_stride.iter() {
+ // Set the reduce_dims indexes to 0.
+ for &(dim, stride) in self.reduce_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
@@ -340,7 +345,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
- let rhs = &rhs[ob.start..];
+ let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
@@ -358,7 +363,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
ys
}
Some(ob) => {
- let rhs = &rhs[ob.start..];
+ let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
@@ -379,7 +384,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
- let lhs = &lhs[ob.start..];
+ let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
@@ -397,7 +402,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
ys
}
Some(ob) => {
- let lhs = &lhs[ob.start..];
+ let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
@@ -1010,25 +1015,31 @@ impl BackendStorage for CpuStorage {
}
}
- fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+ fn reduce_op(
+ &self,
+ op: crate::op::ReduceOp,
+ layout: &Layout,
+ reduce_dims: &[usize],
+ ) -> Result<Self> {
let src_dims = layout.dims();
let mut dst_dims = src_dims.to_vec();
- for &sum_dim in sum_dims.iter() {
- dst_dims[sum_dim] = 1;
+ for &dim in reduce_dims.iter() {
+ dst_dims[dim] = 1;
}
let dst_shape = Shape::from(dst_dims);
- let mut sum_dims = sum_dims.to_vec();
- // Sort the sum_dims as they have to be processed from left to right when converting the
+ let mut reduce_dims = reduce_dims.to_vec();
+ // Sort the reduce_dims as they have to be processed from left to right when converting the
// indexes.
- sum_dims.sort();
- let sum_dims_and_stride: Vec<_> = sum_dims
+ reduce_dims.sort();
+ let reduce_dims_and_stride: Vec<_> = reduce_dims
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
- Sum {
+ Reduce {
dst_shape: &dst_shape,
- sum_dims: &sum_dims,
- sum_dims_and_stride,
+ reduce_dims: &reduce_dims,
+ reduce_dims_and_stride,
+ op,
}
.map(self, layout)
}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 74a3cf30..07d354b6 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -955,10 +955,21 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
- let device = self.device().clone();
- let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
- Ok(Self { slice, device })
+ fn reduce_op(
+ &self,
+ op: crate::op::ReduceOp,
+ layout: &Layout,
+ sum_dims: &[usize],
+ ) -> Result<Self> {
+ match op {
+ crate::op::ReduceOp::Sum => {
+ let device = self.device().clone();
+ let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
+ Ok(Self { slice, device })
+ }
+ crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
+ crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
+ }
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index a9c11bf6..f7cf8ab8 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -40,7 +40,7 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
- fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
+ fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 07ee7670..c5ff8179 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -29,6 +29,8 @@ pub(crate) enum Op {
add: f64,
},
Sum(Tensor, Vec<usize>),
+ Max(Tensor, Vec<usize>),
+ Min(Tensor, Vec<usize>),
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
@@ -354,3 +356,10 @@ impl UnaryOp for Relu {
v
}
}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceOp {
+ Sum,
+ Min,
+ Max,
+}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 1531b212..e689905e 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -80,14 +80,19 @@ impl Storage {
}
}
- pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
+ pub(crate) fn reduce_op(
+ &self,
+ op: crate::op::ReduceOp,
+ layout: &Layout,
+ s: &[usize],
+ ) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
- let storage = storage.sum(layout, s)?;
+ let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
- let storage = storage.sum(layout, s)?;
+ let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a93514fc..32c8acd6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -154,8 +154,14 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let storage = device.ones(&crate::shape::SCALAR, dtype)?;
- from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+ if is_variable {
+ let shape = shape.into();
+ let storage = device.ones(&shape, dtype)?;
+ Ok(from_storage(storage, shape, None, is_variable))
+ } else {
+ let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+ from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+ }
}
/// Creates a new tensor filled with ones.
@@ -192,8 +198,14 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
- from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+ if is_variable {
+ let shape = shape.into();
+ let storage = device.zeros(&shape, dtype)?;
+ Ok(from_storage(storage, shape, None, is_variable))
+ } else {
+ let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+ from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+ }
}
/// Creates a new tensor filled with zeros.
@@ -593,9 +605,77 @@ impl Tensor {
}
}
- pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
+ fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
+ match dims {
+ [] => Ok(self),
+ [i] => self.squeeze(*i),
+ dims => {
+ let dims = self
+ .dims()
+ .iter()
+ .enumerate()
+ .filter_map(|(dim_idx, &v)| {
+ if dims.contains(&dim_idx) {
+ None
+ } else {
+ Some(v)
+ }
+ })
+ .collect::<Vec<_>>();
+ self.reshape(dims)
+ }
+ }
+ }
+
+ fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
+ let max_dims = max_dims.to_indexes(self.shape(), "max")?;
+ let storage =
+ self.storage()
+ .reduce_op(crate::op::ReduceOp::Max, self.layout(), &max_dims)?;
+ let op = if self.track_op() {
+ Some(Op::Max(self.clone(), max_dims.to_vec()))
+ } else {
+ None
+ };
+ let mut dims = self.dims().to_vec();
+ for &max_dim in max_dims.iter() {
+ dims[max_dim] = 1
+ }
+ let max = from_storage(storage, dims, op, false);
+ if keepdim {
+ Ok(max)
+ } else {
+ max.squeeze_dims(&max_dims)
+ }
+ }
+
+ fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
+ let min_dims = min_dims.to_indexes(self.shape(), "min")?;
+ let storage =
+ self.storage()
+ .reduce_op(crate::op::ReduceOp::Min, self.layout(), &min_dims)?;
+ let op = if self.track_op() {
+ Some(Op::Min(self.clone(), min_dims.to_vec()))
+ } else {
+ None
+ };
+ let mut dims = self.dims().to_vec();
+ for &min_dim in min_dims.iter() {
+ dims[min_dim] = 1
+ }
+ let min = from_storage(storage, dims, op, false);
+ if keepdim {
+ Ok(min)
+ } else {
+ min.squeeze_dims(&min_dims)
+ }
+ }
+
+ fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
- let storage = self.storage().sum(self.layout(), &sum_dims)?;
+ let storage =
+ self.storage()
+ .reduce_op(crate::op::ReduceOp::Sum, self.layout(), &sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {
@@ -609,25 +689,7 @@ impl Tensor {
if keepdim {
Ok(sum)
} else {
- match sum_dims.as_slice() {
- [] => Ok(sum),
- [i] => sum.squeeze(*i),
- sum_dims => {
- let dims = sum
- .dims()
- .iter()
- .enumerate()
- .filter_map(|(dim_idx, &v)| {
- if sum_dims.contains(&dim_idx) {
- None
- } else {
- Some(v)
- }
- })
- .collect::<Vec<_>>();
- sum.reshape(dims)
- }
- }
+ sum.squeeze_dims(&sum_dims)
}
}
@@ -659,6 +721,22 @@ impl Tensor {
self.sum_impl(sum_dims, false)
}
+ pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
+ self.max_impl(max_dims, true)
+ }
+
+ pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
+ self.max_impl(max_dims, false)
+ }
+
+ pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
+ self.min_impl(min_dims, true)
+ }
+
+ pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
+ self.min_impl(min_dims, false)
+ }
+
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
new file mode 100644
index 00000000..df67f741
--- /dev/null
+++ b/candle-examples/examples/simple-training/main.rs
@@ -0,0 +1,44 @@
+// This should rearch 91.5% accuracy.
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle::{DType, Var, D};
+
+const IMAGE_DIM: usize = 784;
+const LABELS: usize = 10;
+
+pub fn main() -> Result<()> {
+ let dev = candle::Device::cuda_if_available(0)?;
+ let m = candle_nn::vision::mnist::load_dir("data")?;
+ println!("train-images: {:?}", m.train_images.shape());
+ println!("train-labels: {:?}", m.train_labels.shape());
+ println!("test-images: {:?}", m.test_images.shape());
+ println!("test-labels: {:?}", m.test_labels.shape());
+ let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
+ let bs = Var::zeros(LABELS, DType::F32, &dev)?;
+ let sgd = candle_nn::SGD::new(&[&ws, &bs], 0.1);
+ for epoch in 1..200 {
+ let logits = m.train_images.matmul(&ws)?.broadcast_add(&bs)?;
+ let loss = logits.softmax(D::Minus1)?;
+ // TODO: log_softmax + let loss = loss.nll_loss(&m.train_labels);
+ sgd.backward_step(&loss)?;
+
+ let _test_logits = m.test_images.matmul(&ws)?.broadcast_add(&bs)?;
+ /* TODO
+ let test_accuracy = test_logits
+ .argmax(Some(-1), false)
+ .eq_tensor(&m.test_labels)
+ .to_kind(Kind::Float)
+ .mean(Kind::Float)
+ .double_value(&[]);
+ */
+ let test_accuracy = 0.;
+ println!(
+ "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
+ loss.to_scalar::<f32>()?,
+ 100. * test_accuracy
+ )
+ }
+ Ok(())
+}