summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/test_utils.rs 8
-rw-r--r-- candle-core/tests/conv_tests.rs 35
-rw-r--r-- candle-core/tests/grad_tests.rs 32
-rw-r--r-- candle-core/tests/layout_tests.rs 2
-rw-r--r-- candle-core/tests/pool_tests.rs 10
-rw-r--r-- candle-core/tests/tensor_tests.rs 82
6 files changed, 121 insertions, 48 deletions
diff --git a/candle-core/src/test_utils.rs b/candle-core/src/test_utils.rs
index 8ff73fc0..3b8fb904 100644
--- a/candle-core/src/test_utils.rs
+++ b/candle-core/src/test_utils.rs
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
- ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+ ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@@ -15,6 +15,12 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
+
+ #[cfg(feature = "metal")]
+ #[test]
+ fn $test_metal() -> Result<()> {
+ $fn_name(&Device::new_metal(0)?)
+ }
};
}
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index a5375c11..39c6cec0 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
Ok(())
}
-test_device!(conv1d, conv1d_cpu, conv1d_gpu);
-test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
-test_device!(conv2d, conv2d_cpu, conv2d_gpu);
+test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
+test_device!(
+ conv1d_small,
+ conv1d_small_cpu,
+ conv1d_small_gpu,
+ conv1d_small_metal
+);
+test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
- conv2d_non_square_gpu
+ conv2d_non_square_gpu,
+ conv2d_non_square_metal
+);
+test_device!(
+ conv2d_small,
+ conv2d_small_cpu,
+ conv2d_small_gpu,
+ conv2d_small_metal
+);
+test_device!(
+ conv2d_smaller,
+ conv2d_smaller_cpu,
+ conv2d_smaller_gpu,
+ conv2d_smaller_metal
+);
+test_device!(
+ conv2d_grad,
+ conv2d_grad_cpu,
+ conv2d_grad_gpu,
+ conv2d_grad_metal
);
-test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
-test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
-test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 6413ea2e..791532f2 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
-test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
-test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
-test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
-test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
+test_device!(
+ simple_grad,
+ simple_grad_cpu,
+ simple_grad_gpu,
+ simple_grad_metal
+);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
+test_device!(
+ matmul_grad,
+ matmul_grad_cpu,
+ matmul_grad_gpu,
+ matmul_grad_metal
+);
+test_device!(
+ grad_descent,
+ grad_descent_cpu,
+ grad_descent_gpu,
+ grad_descent_metal
+);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
+test_device!(
+ binary_grad,
+ binary_grad_cpu,
+ binary_grad_gpu,
+ binary_grad_metal
+);
diff --git a/candle-core/tests/layout_tests.rs b/candle-core/tests/layout_tests.rs
index 1b29476f..e0618850 100644
--- a/candle-core/tests/layout_tests.rs
+++ b/candle-core/tests/layout_tests.rs
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
-test_device!(contiguous, contiguous_cpu, contiguous_gpu);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
#[test]
fn strided_blocks() -> Result<()> {
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index c6db194d..a3708ec4 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
- avg_pool2d_pytorch_gpu
+ avg_pool2d_pytorch_gpu,
+ avg_pool2d_pytorch_metal
);
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
- upsample_nearest2d_gpu
+ upsample_nearest2d_gpu,
+ upsample_nearest2d_metal
);
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index f565972a..eb684909 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1070,35 +1070,59 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
-test_device!(zeros, zeros_cpu, zeros_gpu);
-test_device!(ones, ones_cpu, ones_gpu);
-test_device!(arange, arange_cpu, arange_gpu);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
-test_device!(narrow, narrow_cpu, narrow_gpu);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu);
-test_device!(cat, cat_cpu, cat_gpu);
-test_device!(sum, sum_cpu, sum_gpu);
-test_device!(min, min_cpu, min_gpu);
-test_device!(max, max_cpu, max_gpu);
-test_device!(argmax, argmax_cpu, argmax_gpu);
-test_device!(argmin, argmin_cpu, argmin_gpu);
-test_device!(transpose, transpose_cpu, transpose_gpu);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu);
-test_device!(cmp, cmp_cpu, cmp_gpu);
-test_device!(matmul, matmul_cpu, matmul_gpu);
-test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
-test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
-test_device!(index_select, index_select_cpu, index_select_gpu);
-test_device!(index_add, index_add_cpu, index_add_gpu);
-test_device!(gather, gather_cpu, gather_gpu);
-test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
-test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
-test_device!(randn, randn_cpu, randn_gpu);
-test_device!(clamp, clamp_cpu, clamp_gpu);
-test_device!(var, var_cpu, var_gpu);
+test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
+test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(arange, arange_cpu, arange_gpu, arange_metal);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
+test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
+test_device!(cat, cat_cpu, cat_gpu, cat_metal);
+test_device!(sum, sum_cpu, sum_gpu, sum_metal);
+test_device!(min, min_cpu, min_gpu, min_metal);
+test_device!(max, max_cpu, max_gpu, max_metal);
+test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
+test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
+test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
+test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
+test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
+test_device!(
+ broadcast_matmul,
+ broadcast_matmul_cpu,
+ broadcast_matmul_gpu,
+ broadcast_matmul_metal
+);
+test_device!(
+ broadcasting,
+ broadcasting_cpu,
+ broadcasting_gpu,
+ broadcasting_metal
+);
+test_device!(
+ index_select,
+ index_select_cpu,
+ index_select_gpu,
+ index_select_metal
+);
+test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
+test_device!(gather, gather_cpu, gather_gpu, gather_metal);
+test_device!(
+ scatter_add,
+ scatter_add_cpu,
+ scatter_add_gpu,
+ scatter_add_metal
+);
+test_device!(
+ slice_scatter,
+ slice_scatter_cpu,
+ slice_scatter_gpu,
+ slice_scatter_metal
+);
+test_device!(randn, randn_cpu, randn_gpu, randn_metal);
+test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381