diff options
-rw-r--r-- | candle-core/src/cpu_backend/mod.rs | 8
-rw-r--r-- | candle-core/tests/matmul_tests.rs | 13
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 13
3 files changed, 8 insertions, 26 deletions
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 09226b58..6f8250f0 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1330,7 +1330,7 @@ impl Map2 for MatMul { let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, b'N') } else if rhs_m1 == k && rhs_m2 == 1 { (k as i32, b'T') @@ -1338,7 +1338,7 @@ impl Map2 for MatMul { Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? }; // The b tensor has dims batching, m, k (lhs) - let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, b'N') } else if lhs_m1 == m && lhs_m2 == 1 { (m as i32, b'T') @@ -1421,7 +1421,7 @@ impl Map2 for MatMul { let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, b'N') } else if rhs_m1 == k && rhs_m2 == 1 { (k as i32, b'T') @@ -1429,7 +1429,7 @@ impl Map2 for MatMul { Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? 
}; // The b tensor has dims batching, m, k (lhs) - let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, b'N') } else if lhs_m1 == m && lhs_m2 == 1 { (m as i32, b'T') diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index 834da29a..e3e18107 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -73,20 +73,7 @@ fn squeeze_mm(device: &Device) -> Result<()> { let seq_len = 8_usize; let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?; let x = a.i((.., seq_len - 1, ..))?; - println!( - "x shape:{:?}, stride:{:?}, is_contiguous:{}", - x.shape(), - x.stride(), - x.is_contiguous() - ); - let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?; - println!( - "w shape:{:?}, stride:{:?}, is_contiguous:{}", - w.shape(), - w.stride(), - w.is_contiguous() - ); let x = x.matmul(&w)?; assert_eq!(x.dims(), &[1, 32]); Ok(()) diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 1e2c1c77..b3275804 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -107,13 +107,8 @@ fn unary_op(device: &Device) -> Result<()> { ] ); let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?; - assert_eq!( - test_utils::to_vec2_round(&t_f16, 2)?, - [ - [-0.0, 0.84, 4.0, -0.05, 0.35], - [2.69, -0.07, -0.11, 1.73, 2.79] - ], - ); + let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?; + assert!(max_diff.to_vec0::<f32>()? < 5e-3); assert_eq!( test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, [ @@ -1255,8 +1250,8 @@ fn pow() -> Result<()> { let rhs = (&lhs - 2.)?; let res = lhs.pow(&rhs)?; assert_eq!( - test_utils::to_vec2_round(&res, 4)?, - [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + test_utils::to_vec2_round(&res, 3)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]] ); Ok(()) } |