summaryrefslogtreecommitdiff
path: root/candle-core/tests/custom_op_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r--candle-core/tests/custom_op_tests.rs30
1 files changed, 30 insertions, 0 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index be59e0c0..f2c01aca 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}
+
+#[cfg(feature = "cuda")]
+#[allow(clippy::approx_constant)]
+#[test]
+fn ug_op() -> Result<()> {
+ let kernel = {
+ use ug::lang::op;
+
+ let layout = ug::Layout::from_shape(&[12]);
+ let ptr = op::Arg::ptr(ug::DType::F32);
+ let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
+ let src = op::unary(op::UnaryOp::Exp, src)?;
+ let st = op::store(ptr.id(), layout, src)?;
+ let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
+ let opts: ug::lower_op::Opts = Default::default();
+ kernel.lower(&opts.with_global(0, 12))?
+ };
+ let device = Device::new_cuda(0)?;
+ let op = candle_core::UgIOp1::new("test", kernel, &device)?;
+ let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
+ t.inplace_op1(&op)?;
+ assert_eq!(
+ to_vec1_round(&t, 4)?,
+ &[
+ 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
+ 8103.0806, 22026.469, 59874.133
+ ]
+ );
+ Ok(())
+}