diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-27 15:20:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-27 15:20:37 +0100 |
commit | 0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch) | |
tree | daa4ae5e627a95c6c54cd8eabec32e15d71309be /candle-core/tests/custom_op_tests.rs | |
parent | 594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff) | |
download | candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2 candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip |
UG metal integration. (#2580)
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r-- | candle-core/tests/custom_op_tests.rs | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index f2c01aca..3572a4c9 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> { Ok(()) } -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "metal"))] #[allow(clippy::approx_constant)] #[test] fn ug_op() -> Result<()> { @@ -160,15 +160,21 @@ fn ug_op() -> Result<()> { let opts: ug::lower_op::Opts = Default::default(); kernel.lower(&opts.with_global(0, 12))? }; - let device = Device::new_cuda(0)?; + let device = if candle_core::utils::cuda_is_available() { + Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + Device::new_metal(0)? + } else { + candle_core::bail!("metal/cuda is mandatory for this test") + }; let op = candle_core::UgIOp1::new("test", kernel, &device)?; let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; t.inplace_op1(&op)?; assert_eq!( - to_vec1_round(&t, 4)?, + to_vec1_round(&t, 2)?, &[ - 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578, - 8103.0806, 22026.469, 59874.133 + 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47, + 59874.13 ] ); Ok(()) |