summaryrefslogtreecommitdiff
path: root/candle-core/tests/custom_op_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r--candle-core/tests/custom_op_tests.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index 55b5e894..7ec04c6a 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -39,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
- let elu_t = t.custom_op1(Elu { alpha: 1. })?;
+ let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&elu_t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@@ -96,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
- let bwd = arg.custom_op1(EluBackward { alpha })?;
+ let bwd = arg.apply_op1(EluBackward { alpha })?;
Ok(Some(grad_res.mul(&bwd)?))
}
}
@@ -105,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
- let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
+ let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
let grads = elu_t.backward()?;