 candle-core/src/cuda_backend.rs   | 12 ++++++++----
 candle-core/src/tensor.rs         | 10 ++++++++++
 candle-core/tests/tensor_tests.rs | 25 +++++++++++++++++++++++++
 candle-metal-kernels/src/lib.rs   | 12 ++++++++----
 4 files changed, 51 insertions(+), 8 deletions(-)
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f0f03053..97dc346e 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1651,9 +1651,11 @@ fn gemm_config<T>(
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
- let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+ // We also allow for the case where the stride on the minor dimension is not as expected but
+ // there is a single element.
+ let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
- } else if rhs_m1 == k && rhs_m2 == 1 {
+ } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
@@ -1663,9 +1665,11 @@ fn gemm_config<T>(
})?
};
// The b tensor has dims batching, m, k (lhs)
- let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+ // We also allow for the case where the stride on the minor dimension is not as expected but
+ // there is a single element.
+ let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
- } else if lhs_m1 == m && lhs_m2 == 1 {
+ } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
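
Editorial note: the relaxed condition above can be summarized as a standalone predicate. The sketch below is illustrative only (the helper name and Option return are not candle's API); it mirrors the rhs branch of the patch. The key idea is that a dimension holding a single element never constrains its stride, hence the `b * k == 1` and `b * n == 1` escapes.

    // Hypothetical sketch, not candle's API: how cuBLAS may read an rhs
    // operand of logical shape (b, k, n) given its two minor strides.
    fn rhs_gemm_layout(stride: &[usize], b: usize, k: usize, n: usize) -> Option<(i32, bool)> {
        let m1 = stride[stride.len() - 1]; // fastest-moving stride
        let m2 = stride[stride.len() - 2]; // second-to-minor stride
        if m1 == 1 && (m2 == n || b * k == 1) {
            Some((n as i32, false)) // lda = n, CUBLAS_OP_N
        } else if (m1 == k || b * n == 1) && m2 == 1 {
            Some((k as i32, true)) // lda = k, CUBLAS_OP_T
        } else {
            None // genuinely non-contiguous; the caller must copy first
        }
    }
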
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 92c931eb..b53b0419 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2007,6 +2007,16 @@ impl Tensor {
}
}
+ /// Returns a tensor that is in row major order. This always makes a copy.
+ pub fn force_contiguous(&self) -> Result<Tensor> {
+ let shape = self.shape();
+ let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+ self.storage()
+ .copy_strided_src(&mut storage, 0, self.layout())?;
+ let op = BackpropOp::new1(self, Op::Copy);
+ Ok(from_storage(storage, shape.clone(), op, false))
+ }
+
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
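
Editorial note: this method exists because `contiguous()` is a no-op whenever `is_contiguous()` reports true, and a size-1 dimension can count as contiguous whatever its stride, which is exactly the situation the gemm fix above handles. A minimal usage sketch (shapes and device chosen for illustration):

    use candle_core::{DType, Device, IndexOp, Result, Tensor};

    fn demo() -> Result<()> {
        let a = Tensor::zeros((1, 8, 16), DType::F32, &Device::Cpu)?;
        let x = a.i((.., 7, ..))?; // shape (1, 16), stride (128, 1)
        // x.is_contiguous() is true (the size-1 dim may carry any stride),
        // so x.contiguous() would return the same storage unchanged;
        // force_contiguous always copies and yields row-major strides.
        let y = x.force_contiguous()?;
        assert_eq!(y.stride(), &[16, 1]);
        Ok(())
    }
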
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index b2475adc..af28c1c1 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1135,6 +1135,30 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
+// https://github.com/huggingface/candle/issues/1948
+fn squeeze_mm(device: &Device) -> Result<()> {
+ let seq_len = 8_usize;
+ let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
+ let x = a.i((.., seq_len - 1, ..))?;
+ println!(
+ "x shape:{:?}, stride:{:?}, is_contiguous:{}",
+ x.shape(),
+ x.stride(),
+ x.is_contiguous()
+ );
+
+ let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
+ println!(
+ "w shape:{:?}, stride:{:?}, is_contiguous:{}",
+ w.shape(),
+ w.stride(),
+ w.is_contiguous()
+ );
+ let x = x.matmul(&w)?;
+ assert_eq!(x.dims(), &[1, 32]);
+ Ok(())
+}
+
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1190,6 +1214,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
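
Editorial note: tracing the test's numbers through the relaxed check explains the original failure in issue #1948. A worked sketch of the lhs side, with values derived from the test above:

    // For x.matmul(&w): b = 1, m = 1, k = 16, n = 32.
    // x = a.i((.., 7, ..)) has shape (1, 16) and stride (128, 1), so
    // lhs_m1 = 1 but lhs_m2 = 128 != k, even though x.is_contiguous()
    // reports true; the old check raised MatMulNonContiguous here.
    fn lhs_trace() {
        let (lhs_m1, lhs_m2) = (1usize, 128usize);
        let (b, m, k) = (1usize, 1usize, 16usize);
        assert!(lhs_m1 == 1 && lhs_m2 != k); // old first branch rejected x
        assert!(b * m == 1); // the new escape accepts it without a copy
    }
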
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 449bef8f..3f452331 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1451,9 +1451,12 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
- let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+ // lhs has shape b, m, k
+ // We also allow for the case where the stride on the minor dimension is not as expected but
+ // there is a single element.
+ let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
false
- } else if lhs_m1 == m && lhs_m2 == 1 {
+ } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@@ -1462,9 +1465,10 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
- let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+ // rhs has shape b, k, n
+ let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
false
- } else if rhs_m1 == k && rhs_m2 == 1 {
+ } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
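
Editorial note: the weight from the same test exercises the transposed branch here. A worked sketch of the rhs side, with values derived from `Tensor::zeros((32, 16))?.t()`:

    // w has shape (16, 32) and stride (1, 16) after the transpose, so
    // rhs_m1 = 16 and rhs_m2 = 1; b = 1, k = 16, n = 32.
    fn rhs_trace() {
        let (rhs_m1, rhs_m2) = (16usize, 1usize);
        let (b, k, n) = (1usize, 16usize, 32usize);
        assert!(rhs_m1 != 1); // the non-transposed branch does not apply
        // rhs_m1 == k, so b_trans = true: the kernel consumes w's storage
        // with a transpose flag instead of requiring a contiguous copy.
        assert!((rhs_m1 == k || b * n == 1) && rhs_m2 == 1);
    }
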