-rw-r--r--  candle-core/src/shape.rs           | 63
-rw-r--r--  candle-core/src/tensor.rs          | 68
-rw-r--r--  candle-core/tests/tensor_tests.rs  | 20
3 files changed, 108 insertions(+), 43 deletions(-)
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index d8f8f756..49fbf022 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,69 @@ impl Shape {
        self.0.extend(additional_dims);
        self
    }
+
+    /// Check whether the two shapes are compatible for broadcasting, and if so return the
+    /// broadcasted shape. This is to be used for binary pointwise ops.
+    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+        let lhs = self;
+        let lhs_dims = lhs.dims();
+        let rhs_dims = rhs.dims();
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
+            } else {
+                Err(Error::ShapeMismatchBinaryOp {
+                    lhs: lhs.clone(),
+                    rhs: rhs.clone(),
+                    op,
+                }
+                .bt())?
+            }
+        }
+        Ok(Shape::from(bcast_dims))
+    }
+
+    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
+        let lhs = self;
+        let lhs_dims = lhs.dims();
+        let rhs_dims = rhs.dims();
+        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
+            crate::bail!("broadcast_matmul requires at least 2d matrices {lhs:?} {rhs:?}")
+        }
+        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
+        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
+        if lhs_k != rhs_k {
+            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
+        }
+
+        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
+        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
+        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
+        let bcast_dims = bcast.dims();
+
+        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
+        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
+        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
+    }
}
pub trait Dim {
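// A minimal standalone sketch of the right-aligned broadcasting rule that
// `broadcast_shape_binary_op` implements: dimensions are compared from the
// right, a missing or size-1 dimension stretches to match the other side, and
// any other mismatch is an error. `broadcast_dims` is a hypothetical free
// function for illustration, not part of this patch.
fn broadcast_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let ndims = usize::max(lhs.len(), rhs.len());
    let mut out = Vec::with_capacity(ndims);
    for idx in 0..ndims {
        let rev_idx = ndims - idx;
        // A side that is too short behaves as if padded with leading 1s.
        let l = if lhs.len() < rev_idx { 1 } else { lhs[lhs.len() - rev_idx] };
        let r = if rhs.len() < rev_idx { 1 } else { rhs[rhs.len() - rev_idx] };
        match (l, r) {
            (l, r) if l == r => out.push(l),
            (1, r) => out.push(r),
            (l, 1) => out.push(l),
            _ => return None, // e.g. 3 vs 4: incompatible
        }
    }
    Some(out)
}

fn main() {
    // (3, 1, 5) broadcast against (4, 5) gives (3, 4, 5).
    assert_eq!(broadcast_dims(&[3, 1, 5], &[4, 5]), Some(vec![3, 4, 5]));
    // (3, 2) against (4, 2) fails: 3 vs 4 cannot be reconciled.
    assert_eq!(broadcast_dims(&[3, 2], &[4, 2]), None);
}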
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 45aa07bc..4ea66186 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -106,7 +106,9 @@ macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let lhs = self;
-            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let shape = lhs
+                .shape()
+                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let l_broadcast = shape != *lhs.shape();
            let r_broadcast = shape != *rhs.shape();
            match (l_broadcast, r_broadcast) {
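// For context, `broadcast_binary_op!` generates the user-facing broadcasting
// methods such as `broadcast_add`. A hedged usage sketch, assuming the
// `candle_core` API exercised in the tests below:
use candle_core::{Device, Result, Tensor};

fn broadcast_add_demo() -> Result<()> {
    // (4, 1) + (3,) broadcasts to (4, 3): the size-1 column dimension of the
    // lhs and the missing leading dimension of the rhs both stretch.
    let lhs = Tensor::new(&[[0f32], [1.], [2.], [3.]], &Device::Cpu)?;
    let rhs = Tensor::new(&[10f32, 20., 30.], &Device::Cpu)?;
    let out = lhs.broadcast_add(&rhs)?;
    assert_eq!(out.dims(), &[4, 3]);
    Ok(())
}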
@@ -415,48 +417,6 @@ impl Tensor {
        Self::new_impl(array, shape.into(), device, false)
    }
-    pub(crate) fn broadcast_shape_binary_op<'a>(
-        &'a self,
-        rhs: &'a Self,
-        op: &'static str,
-    ) -> Result<Shape> {
-        let lhs = self;
-        let lhs_dims = lhs.shape().dims();
-        let rhs_dims = rhs.shape().dims();
-        let lhs_ndims = lhs_dims.len();
-        let rhs_ndims = rhs_dims.len();
-        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
-        let mut bcast_dims = vec![0; bcast_ndims];
-        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
-            let rev_idx = bcast_ndims - idx;
-            let l_value = if lhs_ndims < rev_idx {
-                1
-            } else {
-                lhs_dims[lhs_ndims - rev_idx]
-            };
-            let r_value = if rhs_ndims < rev_idx {
-                1
-            } else {
-                rhs_dims[rhs_ndims - rev_idx]
-            };
-            *bcast_value = if l_value == r_value {
-                l_value
-            } else if l_value == 1 {
-                r_value
-            } else if r_value == 1 {
-                l_value
-            } else {
-                Err(Error::ShapeMismatchBinaryOp {
-                    lhs: self.shape().clone(),
-                    rhs: rhs.shape().clone(),
-                    op,
-                }
-                .bt())?
-            }
-        }
-        Ok(Shape::from(bcast_dims))
-    }
-
    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
        let lhs = self.shape();
        let rhs = rhs.shape();
@@ -961,6 +921,28 @@ impl Tensor {
        Ok(from_storage(storage, c_shape, op, false))
    }
+    /// Matrix multiplication with broadcasting support.
+    ///
+    /// Compared to `matmul`, the two matrices are allowed to have different dimensions as long as
+    /// they are compatible for broadcasting. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
+    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
+    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
+        let lhs = self;
+        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
+        let l_broadcast = l_shape != *lhs.shape();
+        let r_broadcast = r_shape != *rhs.shape();
+        // TODO: Avoid concretising the broadcasted matrices via contiguous.
+        match (l_broadcast, r_broadcast) {
+            (true, true) => lhs
+                .broadcast_as(&l_shape)?
+                .contiguous()?
+                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
+            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
+            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
+            (false, false) => lhs.matmul(rhs),
+        }
+    }
+
    /// Returns a tensor with the same shape as the input tensor, the values are taken from
    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
    /// input tensor is equal to zero.
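// A hedged usage sketch of the new `broadcast_matmul`, mirroring the
// `(j, 1, n, k) x (l, k, m) -> (j, l, n, m)` example from the doc comment
// with j = 2, l = 3, n = 4, k = 5, m = 6; every call used here appears
// elsewhere in this patch:
use candle_core::{Device, Result, Tensor};

fn broadcast_matmul_demo() -> Result<()> {
    let dev = Device::Cpu;
    let lhs = Tensor::randn(0f32, 1f32, (2, 1, 4, 5), &dev)?;
    let rhs = Tensor::randn(0f32, 1f32, (3, 5, 6), &dev)?;
    // Batch dims (2, 1) and (3,) broadcast to (2, 3); the trailing matrix
    // dims contract as (4, 5) x (5, 6) -> (4, 6).
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[2, 3, 4, 6]);
    Ok(())
}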
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 0b77f1a5..907c876e 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -747,6 +747,25 @@ fn matmul(device: &Device) -> Result<()> {
    Ok(())
}
+fn broadcast_matmul(device: &Device) -> Result<()> {
+    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
+    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
+    let out = lhs.broadcast_matmul(&rhs)?;
+    assert_eq!(out.dims(), &[3, 6, 4, 2]);
+    for idx1 in 0..3 {
+        for idx2 in 0..6 {
+            let out = out.i((idx1, idx2))?;
+            let lhs = lhs.i((idx1, 0))?;
+            let rhs = rhs.i(idx2)?;
+            let out2 = lhs.matmul(&rhs);
+            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
+            // With cuda, we see errors of up to ~1e-12.
+            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
+        }
+    }
+    Ok(())
+}
+
fn broadcasting(device: &Device) -> Result<()> {
    let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
    let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -864,6 +883,7 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
+test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
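// For reference, `test_device!` registers one test per backend for each of the
// functions above. A hedged sketch of roughly what such a macro might expand
// to; the actual definition lives in the crate's shared test utilities and may
// differ:
macro_rules! test_device {
    ($fn_name:ident, $test_cpu:ident, $test_gpu:ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }
        #[cfg(feature = "cuda")]
        #[test]
        fn $test_gpu() -> Result<()> {
            let device = Device::new_cuda(0)?;
            $fn_name(&device)
        }
    };
}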