Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/Cargo.toml                  |   4
-rw-r--r--  candle-core/src/backprop.rs             |  26
-rw-r--r--  candle-core/src/indexer.rs              |   2
-rw-r--r--  candle-core/src/op.rs                   |   6
-rw-r--r--  candle-core/src/quantized/avx.rs        |   4
-rw-r--r--  candle-core/src/quantized/gguf_file.rs  |   2
-rw-r--r--  candle-core/src/tensor.rs               |  27
-rw-r--r--  candle-core/tests/grad_tests.rs         | 160
-rw-r--r--  candle-core/tests/tensor_tests.rs       |  34
9 files changed, 250 insertions(+), 15 deletions(-)
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 0f8c1a9f..52e79a5a 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.3.2", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.2", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
- | Op::UpsampleNearest2D(node)
+ | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
- Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
- op: "upsample-nearest2d",
- })?,
+ Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ } => {
+ let (_n, c, h, w) = arg.dims4()?;
+ if target_h % h != 0 || target_w % w != 0 {
+ crate::bail!("backward not supported for non integer upscaling factors")
+ }
+ let scale_h = target_h / h;
+ let scale_w = target_w / w;
+
+ if scale_h != scale_w {
+ crate::bail!("backward not supported for non uniform upscaling factors")
+ };
+ let kernel =
+ Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+ let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = conv_sum;
+ }
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
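The new backward rule above relies on the fact that, for an integer and uniform scale factor, each input pixel of a nearest-neighbor upsample is copied into a scale_h x scale_w block of the output, so its gradient is the sum of the incoming gradient over that block; the grouped conv2d with a ones kernel and stride equal to the scale performs exactly that per-channel block sum. The following is a minimal standalone sketch of the equivalence, not part of the diff (shapes and the helper name are illustrative); it checks the conv-based reduction against an avg_pool2d scaled by the block area.

use candle_core::{Device, Result, Tensor};

fn block_sum_equivalence() -> Result<()> {
    let dev = Device::Cpu;
    let (c, h, w, scale) = (2usize, 3usize, 3usize, 2usize);
    // Incoming gradient w.r.t. the upsampled output, shape (1, c, h * scale, w * scale).
    let grad = Tensor::arange(0f32, (c * h * scale * w * scale) as f32, &dev)?
        .reshape((1, c, h * scale, w * scale))?;

    // Backward as written in the diff: ones kernel, stride = scale, groups = c.
    let kernel = Tensor::ones((c, 1, scale, scale), grad.dtype(), &dev)?;
    let via_conv = grad.conv2d(&kernel, 0, scale, 1, c)?;

    // The same per-block sum expressed as an average pool scaled by the block area.
    let via_pool = grad.avg_pool2d(scale)?.affine((scale * scale) as f64, 0.)?;

    assert_eq!(
        via_conv.flatten_all()?.to_vec1::<f32>()?,
        via_pool.flatten_all()?.to_vec1::<f32>()?
    );
    Ok(())
}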
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index df106b73..e3ed41e5 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -64,7 +64,7 @@ impl Tensor {
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
- /// This selects the elemnts for which an index has some specific value.
+ /// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbb20f6c..868673e7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,11 @@ pub enum Op {
},
UpsampleNearest1D(Tensor),
- UpsampleNearest2D(Tensor),
+ UpsampleNearest2D {
+ arg: Tensor,
+ target_h: usize,
+ target_w: usize,
+ },
Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 5c3ac822..664f7653 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
q3 = q3.add(32);
// Prepare low and high bits
- // We hardcode the shifts here to avoid loading them into a seperate register
+ // We hardcode the shifts here to avoid loading them into a separate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
- //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
+ //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 620bc037..1e9dc517 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -463,7 +463,7 @@ impl Content {
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
- None => crate::bail!("cannot find tensor-infor for {name}"),
+ None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e478869a..f15f8c1c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,4 @@
-//! Tensors are N-dimenional matrixes of elements using a single data type.
+//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
@@ -361,6 +361,16 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
+ /// Returns a new tensor with all the elements having the same specified value. Note that
+ /// the returned tensor is not contiguous, so call `.contiguous()` on it if needed.
+ pub fn full<D: crate::WithDType, S: Into<Shape>>(
+ value: D,
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+ }
+
/// Creates a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
@@ -669,7 +679,7 @@ impl Tensor {
}
/// Split a tensor into the specified number of chunks, this may return less chunks than
- /// specificed.
+ /// specified.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?;
@@ -994,7 +1004,11 @@ impl Tensor {
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
- let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+ let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ });
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
@@ -2558,6 +2572,13 @@ impl Tensor {
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
+
+ /// Returns log(sum(exp(tensor), dim)).
+ pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ let exp = self.exp()?;
+ let sum = exp.sum(sum_dims)?;
+ sum.log()
+ }
}
macro_rules! bin_trait {
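Because the new `Tensor::full` above broadcasts a single element rather than allocating `shape` values, the result has stride-0 dimensions. A small usage sketch, not part of the diff and with an illustrative function name, showing the contiguity caveat from its doc comment:

use candle_core::{DType, Device, Result, Tensor};

fn full_usage() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::full(3.5f32, (2, 4), &dev)?;
    assert_eq!(t.dims(), &[2, 4]);
    assert_eq!(t.dtype(), DType::F32);
    // The broadcast result is not contiguous; materialize it before ops that
    // require a densely laid out buffer.
    let dense = t.contiguous()?;
    assert_eq!(dense.to_vec2::<f32>()?, vec![vec![3.5f32; 4]; 2]);
    Ok(())
}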
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 791532f2..16e7a82f 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000]
);
+ // manually checked: see comments
+ let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
+ let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+ #[rustfmt::skip]
+ let z = Tensor::new(
+ &[
+ 1_f32, 02., 03., 04., 05., 06.,
+ 07., 08., 09., 10., 11., 12.,
+ 13., 14., 15., 16., 17., 18.,
+ 19., 20., 21., 22., 23., 24.,
+ 25., 26., 27., 28., 29., 30.,
+ 31., 32., 33., 34., 35., 36.,
+ ],
+ device,
+ )?;
+ // gradient should be
+ // row 1
+ // 1+2+7+8 = 18
+ // 3+4+9+10 = 26
+ // 5+6+11+12 = 34
+ // row 2
+ // 13+14+19+20 = 66
+ // 15+16+21+22 = 74
+ // 17+18+23+24 = 82
+ // row 3
+ // 25+26+31+32 = 114
+ // 27+28+33+34 = 122
+ // 29+30+35+36 = 130
+ let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+ let grads = loss.backward()?;
+
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ assert_eq!(
+ test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+ [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
+ );
+
+ // manually checked: see comments
+ let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
+ let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+ #[rustfmt::skip]
+ let z = Tensor::new(
+ &[
+ 1_f32, 02., 03., 04., 05., 06.,
+ 07., 08., 09., 10., 11., 12.,
+ 13., 14., 15., 16., 17., 18.,
+ 19., 20., 21., 22., 23., 24.,
+ 25., 26., 27., 28., 29., 30.,
+ 31., 32., 33., 34., 35., 36.,
+ ],
+ device,
+ )?;
+ // gradient should be
+ // row 1
+ // 1+2+3+7+8+9+13+14+15 = 72
+ // 4+5+6+10+11+12+16+17+18 = 99
+ // row 2
+ // 19+20+21+25+26+27+31+32+33 = 234
+ // 22+23+24+28+29+30+34+35+36 = 261
+ let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+ let grads = loss.backward()?;
+
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ assert_eq!(
+ test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+ [[72_f32, 99.], [234., 261.]]
+ );
+
+ // manually checked: see comments
+ let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
+
+ let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+ #[rustfmt::skip]
+ let z = Tensor::new(
+ &[
+ 1_f32, 02., 03., 04.,
+ 05., 06., 07., 08.,
+ 09., 10., 11., 12.,
+ 13., 14., 15., 16.,
+ 17., 18., 19., 20.,
+ 21., 22., 23., 24.,
+ 25., 26., 27., 28.,
+ 29., 30., 31., 32.
+ ],
+ device,
+ )?;
+ // gradient should be
+ // m1r1
+ // 1+2+5+6=14
+ // 3+4+7+8=22
+ // m1r2
+ // 9+10+13+14=46
+ // 11+12+15+16=54
+ // m2r1
+ // 17+18+21+22=78
+ // 19+20+23+24=86
+ // m2r2
+ // 25+26+29+30=110
+ // 27+28+31+32=118
+ let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+ let grads = loss.backward()?;
+
+ let grad_x = grads.get(&x).context("no grad for x")?;
+
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+ [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+ );
+
+ // manually checked: see comments
+ let x = Var::new(
+ &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
+ device,
+ )?;
+
+ let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+ #[rustfmt::skip]
+ let z = Tensor::new(
+ &[
+ 1_f32, 02., 03., 04.,
+ 05., 06., 07., 08.,
+ 09., 10., 11., 12.,
+ 13., 14., 15., 16.,
+ 17., 18., 19., 20.,
+ 21., 22., 23., 24.,
+ 25., 26., 27., 28.,
+ 29., 30., 31., 32.
+ ],
+ device,
+ )?;
+ // gradient should be
+ // m1r1
+ // 1+2+5+6=14
+ // 3+4+7+8=22
+ // m1r2
+ // 9+10+13+14=46
+ // 11+12+15+16=54
+ // m2r1
+ // 17+18+21+22=78
+ // 19+20+23+24=86
+ // m2r2
+ // 25+26+29+30=110
+ // 27+28+31+32=118
+ let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+ let grads = loss.backward()?;
+
+ let grad_x = grads.get(&x).context("no grad for x")?;
+
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+ [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+ );
Ok(())
}
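The hand-checked gradients above can also be cross-checked numerically. Here is a hedged sketch, not part of the PR (the helper name, input, and tolerances are illustrative), comparing the backward result of `interpolate2d` against central finite differences on a tiny input:

use candle_core::{Device, Result, Tensor, Var};

fn interpolate2d_finite_diff() -> Result<()> {
    let dev = Device::Cpu;
    let xs = vec![1f32, 2., 3., 4.];
    // Scalar loss: sum of the nearest-neighbor upsampled output.
    let loss_of = |vals: &[f32]| -> Result<f32> {
        let t = Tensor::from_vec(vals.to_vec(), (1, 1, 2, 2), &dev)?;
        t.interpolate2d(4, 4)?.sum_all()?.to_scalar::<f32>()
    };

    // Analytic gradient via backprop.
    let x = Var::from_tensor(&Tensor::from_vec(xs.clone(), (1, 1, 2, 2), &dev)?)?;
    let grads = x.interpolate2d(4, 4)?.sum_all()?.backward()?;
    let analytic = grads
        .get(&x)
        .expect("no grad for x")
        .flatten_all()?
        .to_vec1::<f32>()?;

    // Central finite differences, one coordinate at a time.
    let eps = 1e-2f32;
    for (i, g) in analytic.iter().enumerate() {
        let mut plus = xs.clone();
        plus[i] += eps;
        let mut minus = xs.clone();
        minus[i] -= eps;
        let fd = (loss_of(&plus)? - loss_of(&minus)?) / (2. * eps);
        assert!((g - fd).abs() < 1e-2, "grad mismatch at {i}: {g} vs {fd}");
    }
    Ok(())
}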
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c871dc96..e83fb55b 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> {
Ok(())
}
+fn full(device: &Device) -> Result<()> {
+ assert_eq!(
+ Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
+ [[42, 42, 42], [42, 42, 42]],
+ );
+ Ok(())
+}
+
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
@@ -1072,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@@ -1221,3 +1230,26 @@ fn cumsum() -> Result<()> {
);
Ok(())
}
+
+/// A helper function for floating point comparisons. Both `a` and `b` must be 1D tensors containing the same number of elements.
+/// The assertion passes if the absolute difference of every pair of elements from `a` and `b` is smaller than `epsilon`.
+fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
+ let a_vec: Vec<f64> = a.to_vec1()?;
+ let b_vec: Vec<f64> = b.to_vec1()?;
+
+ assert_eq!(a_vec.len(), b_vec.len());
+ for (a, b) in a_vec.iter().zip(b_vec.iter()) {
+ assert!((a - b).abs() < epsilon);
+ }
+ Ok(())
+}
+
+#[test]
+fn logsumexp() -> Result<()> {
+ let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+ let output = input.logsumexp(D::Minus1)?;
+ // The expected values were obtained from PyTorch.
+ let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
+ assert_close(&output, &expected, 0.00001)?;
+ Ok(())
+}
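One caveat with the exp-then-log formulation tested above: for large inputs `exp` overflows before the sum is taken. A common, numerically stable variant subtracts the per-dimension maximum first, using the identity logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x). The sketch below is only an illustration of that standard trick with broadcast ops, not the implementation added in this diff, and the helper names are hypothetical.

use candle_core::{Device, Result, Tensor};

// Numerically stable log-sum-exp along a single dimension: shift by the max so
// the largest exponent is exp(0) = 1.
fn stable_logsumexp(t: &Tensor, dim: usize) -> Result<Tensor> {
    // m keeps the reduced dimension (size 1) so it broadcasts against t.
    let m = t.max_keepdim(dim)?;
    let shifted = t.broadcast_sub(&m)?;
    // log(sum(exp(t - m))) + m, with m squeezed back to the reduced shape.
    shifted.exp()?.sum(dim)?.log()?.broadcast_add(&m.squeeze(dim)?)
}

fn stable_logsumexp_check() -> Result<()> {
    // Same input as the test above; the reference values were [3.4076, 6.4076].
    let t = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    println!("{:?}", stable_logsumexp(&t, 1)?.to_vec1::<f64>()?);
    // Large inputs no longer overflow: logsumexp([1000, 1000]) = 1000 + ln(2).
    let big = Tensor::new(&[[1000f64, 1000.]], &Device::Cpu)?;
    println!("{:?}", stable_logsumexp(&big, 1)?.to_vec1::<f64>()?);
    Ok(())
}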