diff options
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r-- | candle-core/tests/custom_op_tests.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index 55b5e894..7ec04c6a 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -39,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> { let cpu = &Device::Cpu; let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; let t = (t - 5.)?; - let elu_t = t.custom_op1(Elu { alpha: 1. })?; + let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?; assert_eq!( to_vec1_round(&elu_t, 4)?, &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] @@ -96,7 +96,7 @@ impl CustomOp1 for EluWithBackward { fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> { let alpha = self.0.alpha; - let bwd = arg.custom_op1(EluBackward { alpha })?; + let bwd = arg.apply_op1(EluBackward { alpha })?; Ok(Some(grad_res.mul(&bwd)?)) } } @@ -105,7 +105,7 @@ impl CustomOp1 for EluWithBackward { fn custom_op1_with_backward() -> Result<()> { let cpu = &Device::Cpu; let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?; - let elu_t = t.custom_op1(EluWithBackward::new(2.))?; + let elu_t = t.apply_op1(EluWithBackward::new(2.))?; assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]); let grads = elu_t.backward()?; |