summaryrefslogtreecommitdiff
path: root/candle-core/tests/custom_op_tests.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-27 15:20:37 +0100
committerGitHub <noreply@github.com>2024-10-27 15:20:37 +0100
commit0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch)
treedaa4ae5e627a95c6c54cd8eabec32e15d71309be /candle-core/tests/custom_op_tests.rs
parent594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff)
downloadcandle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip
UG metal integration. (#2580)
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r--candle-core/tests/custom_op_tests.rs16
1 files changed, 11 insertions, 5 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index f2c01aca..3572a4c9 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> {
Ok(())
}
-#[cfg(feature = "cuda")]
+#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
@@ -160,15 +160,21 @@ fn ug_op() -> Result<()> {
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
- let device = Device::new_cuda(0)?;
+ let device = if candle_core::utils::cuda_is_available() {
+ Device::new_cuda(0)?
+ } else if candle_core::utils::metal_is_available() {
+ Device::new_metal(0)?
+ } else {
+ candle_core::bail!("metal/cuda is mandatory for this test")
+ };
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
- to_vec1_round(&t, 4)?,
+ to_vec1_round(&t, 2)?,
&[
- 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
- 8103.0806, 22026.469, 59874.133
+ 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
+ 59874.13
]
);
Ok(())