Diffstat (limited to 'candle-core')
-rw-r--r--   candle-core/Cargo.toml                    4
-rw-r--r--   candle-core/src/backprop.rs              26
-rw-r--r--   candle-core/src/indexer.rs                2
-rw-r--r--   candle-core/src/op.rs                     6
-rw-r--r--   candle-core/src/quantized/avx.rs          4
-rw-r--r--   candle-core/src/quantized/gguf_file.rs    2
-rw-r--r--   candle-core/src/tensor.rs                27
-rw-r--r--   candle-core/tests/grad_tests.rs         160
-rw-r--r--   candle-core/tests/tensor_tests.rs        34
9 files changed, 250 insertions, 15 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 0f8c1a9f..52e79a5a 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,8 +12,8 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.3.2", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.2", optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
                 Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest1d",
                 })?,
-                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                    op: "upsample-nearest2d",
-                })?,
+                Op::UpsampleNearest2D {
+                    arg,
+                    target_h,
+                    target_w,
+                } => {
+                    let (_n, c, h, w) = arg.dims4()?;
+                    if target_h % h != 0 || target_w % w != 0 {
+                        crate::bail!("backward not supported for non integer upscaling factors")
+                    }
+                    let scale_h = target_h / h;
+                    let scale_w = target_w / w;
+
+                    if scale_h != scale_w {
+                        crate::bail!("backward not supported for non uniform upscaling factors")
+                    };
+                    let kernel =
+                        Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                    let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = conv_sum;
+                }
                 Op::SliceScatter0(lhs, rhs, start_rhs) => {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
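The new UpsampleNearest2D backward branch above relies on the fact that nearest-neighbour upsampling copies each input cell into a scale_h x scale_w block of the output, so the input gradient is simply the per-block sum of the output gradient; the grouped conv2d with a ones kernel and stride scale_h computes exactly that per-channel sum. A minimal sketch of the resulting behaviour through the public API (the helper name check_upsample_grad is made up for illustration, the rest mirrors the tests added further down):

    use candle_core::{Device, Result, Tensor, Var};

    // With dL/dy = 1 for every output cell, each input cell should receive a gradient
    // equal to the size of the block it was copied into (here 2 * 2 = 4).
    fn check_upsample_grad() -> Result<()> {
        let dev = Device::Cpu;
        let x = Var::new(&[[[[1f32, 2.], [3., 4.]]]], &dev)?; // shape (1, 1, 2, 2)
        let y = x.interpolate2d(4, 4)?; // shape (1, 1, 4, 4), i.e. a uniform scale of 2
        let grads = y.sum_all()?.backward()?;
        let gx = grads.get(&x).expect("no grad for x");
        assert_eq!(gx.flatten(0, 2)?.to_vec2::<f32>()?, [[4f32, 4.], [4., 4.]]);
        Ok(())
    }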
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index df106b73..e3ed41e5 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -64,7 +64,7 @@ impl Tensor {
 #[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
-    /// This selects the elemnts for which an index has some specific value.
+    /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbb20f6c..868673e7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,11 @@ pub enum Op {
     },

     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },

     Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 5c3ac822..664f7653 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             q3 = q3.add(32);

             // Prepare low and high bits
-            // We hardcode the shifts here to avoid loading them into a seperate register
+            // We hardcode the shifts here to avoid loading them into a separate register
             let q3l_0 = _mm256_and_si256(q3bits, m3);
             let q3h_0 = if j == 0 {
                 _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
             let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
             q5 = q5.add(32);

-            //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
+            //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
             let q5l_0 = _mm256_and_si256(q5bits, m4);
             let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
             let q5l_0_right_shift = match j {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 620bc037..1e9dc517 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -463,7 +463,7 @@ impl Content {
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
-            None => crate::bail!("cannot find tensor-infor for {name}"),
+            None => crate::bail!("cannot find tensor info for {name}"),
         };
         tensor_info.read(reader, self.tensor_data_offset)
     }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e478869a..f15f8c1c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,4 @@
-//! Tensors are N-dimenional matrixes of elements using a single data type.
+//! Tensors are N-dimensional matrixes of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{
@@ -361,6 +361,16 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }

+    /// Returns a new tensor with all the elements having the same specified value. Note that
+    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
+    pub fn full<D: crate::WithDType, S: Into<Shape>>(
+        value: D,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+    }
+
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
@@ -669,7 +679,7 @@ impl Tensor {
     }

     /// Split a tensor into the specified number of chunks, this may return less chunks than
-    /// specificed.
+    /// specified.
     pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
         let dim = dim.to_index(self.shape(), "chunk")?;
         let size = self.dim(dim)?;
@@ -994,7 +1004,11 @@ impl Tensor {
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
     pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+            arg,
+            target_h,
+            target_w,
+        });
         let storage = self
             .storage()
             .upsample_nearest2d(self.layout(), target_h, target_w)?;
@@ -2558,6 +2572,13 @@ impl Tensor {
         }
         mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
     }
+
+    /// Returns log(sum(exp(tensor), dim)).
+    pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        let exp = self.exp()?;
+        let sum = exp.sum(sum_dims)?;
+        sum.log()
+    }
 }

 macro_rules! bin_trait {
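As a quick reference for the two helpers added to tensor.rs above, here is a minimal usage sketch; the demo function name is made up, while Tensor::full and Tensor::logsumexp are the APIs introduced by this diff:

    use candle_core::{D, Device, Result, Tensor};

    fn demo() -> Result<()> {
        let dev = Device::Cpu;
        // `full` broadcasts a single value, so the result is not contiguous;
        // call `.contiguous()` afterwards if a packed layout is needed.
        let t = Tensor::full(3.5f32, (2, 4), &dev)?.contiguous()?;
        assert_eq!(t.to_vec2::<f32>()?[0], [3.5f32, 3.5, 3.5, 3.5]);

        // `logsumexp` computes log(sum(exp(x), dims)); for [0, ln 2] this is ln(1 + 2) = ln 3.
        let x = Tensor::new(&[0f32, 2f32.ln()], &dev)?;
        let lse = x.logsumexp(D::Minus1)?;
        assert!((lse.to_scalar::<f32>()? - 3f32.ln()).abs() < 1e-6);
        Ok(())
    }

Note that the direct exp-sum-log form used here can overflow for large inputs; the usual max-shift identity logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x) avoids that, but the diff keeps the simpler form.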
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 791532f2..16e7a82f 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> {
         [0.7358, 2.0000, 0.2707, 1.0000]
     );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+7+8 = 18
+    // 3+4+9+10 = 26
+    // 5+6+11+12 = 34
+    // row 2
+    // 13+14+19+20 = 66
+    // 15+16+21+22 = 74
+    // 17+18+23+24 = 82
+    // row 3
+    // 25+26+31+32 = 114
+    // 27+28+33+34 = 122
+    // 29+30+35+36 = 130
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+3+7+8+9+13+14+15 = 72
+    // 4+5+6+10+11+12+16+17+18 = 99
+    // row 2
+    // 19+20+21+25+26+27+31+32+33 = 234
+    // 22+23+24+28+29+30+34+35+36 = 261
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[72_f32, 99.], [234., 261.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(
+        &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
+        device,
+    )?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
     Ok(())
 }
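All of the hand-checked gradients in the tests above follow one rule: with loss = y . z and y = x.interpolate2d(...) at an integer scale s, each input cell x[n, c, i, j] is copied into the s x s output block starting at (s*i, s*j), so d(loss)/dx[n, c, i, j] is the sum of z over that block. In the first 3x3 -> 6x6 case (s = 2), the top-left cell collects z[0][0] + z[0][1] + z[1][0] + z[1][1] = 1 + 2 + 7 + 8 = 18, which is the first expected value in the assertion.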
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c871dc96..e83fb55b 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};

 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> {
     Ok(())
 }

+fn full(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
+        [[42, 42, 42], [42, 42, 42]],
+    );
+    Ok(())
+}
+
 fn arange(device: &Device) -> Result<()> {
     assert_eq!(
         Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
@@ -1072,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> {

 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(full, full_cpu, full_gpu, full_metal);
 test_device!(arange, arange_cpu, arange_gpu, arange_metal);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@@ -1221,3 +1230,26 @@ fn cumsum() -> Result<()> {
     );
     Ok(())
 }
+
+/// A helper function for floating point comparison. Both a and b must be 1D Tensors and contain the same amount of data.
+/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
+fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
+    let a_vec: Vec<f64> = a.to_vec1()?;
+    let b_vec: Vec<f64> = b.to_vec1()?;
+
+    assert_eq!(a_vec.len(), b_vec.len());
+    for (a, b) in a_vec.iter().zip(b_vec.iter()) {
+        assert!((a - b).abs() < epsilon);
+    }
+    Ok(())
+}
+
+#[test]
+fn logsumexp() -> Result<()> {
+    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let output = input.logsumexp(D::Minus1)?;
+    // The expectations obtained from pytorch.
+    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
+    assert_close(&output, &expected, 0.00001)?;
+    Ok(())
+}
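As a sanity check on the hard-coded expectations: logsumexp([1, 2, 3]) = ln(e^1 + e^2 + e^3) = ln(2.7183 + 7.3891 + 20.0855) ≈ 3.4076, and since adding a constant c to every input shifts the result by exactly c, logsumexp([4, 5, 6]) = 3 + logsumexp([1, 2, 3]) ≈ 6.4076. The [3.4076, 6.4076] tensor therefore matches a direct evaluation and not just the PyTorch output, and the 0.00001 tolerance passed to assert_close comfortably covers the rounding to four decimal places.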