diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-21 16:18:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-21 15:18:05 +0100 |
commit | a6bcdfb2697f68acff7009a22876fb5d7e5734eb (patch) | |
tree | 66ff966da439c34535d63cca579d25810ed519d2 /candle-core/tests/custom_op_tests.rs | |
parent | b02229ce9279079a4f38510a4e735db92a4649b2 (diff) | |
download | candle-a6bcdfb2697f68acff7009a22876fb5d7e5734eb.tar.gz candle-a6bcdfb2697f68acff7009a22876fb5d7e5734eb.tar.bz2 candle-a6bcdfb2697f68acff7009a22876fb5d7e5734eb.zip |
Custom ops with a single argument (#214)
* Add the CustomOp1 trait.
* Add an example of custom op.
* Polish the custom op example.
* Add some backward pass test for custom ops.
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r-- | candle-core/tests/custom_op_tests.rs | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs new file mode 100644 index 00000000..3e1e0c19 --- /dev/null +++ b/candle-core/tests/custom_op_tests.rs @@ -0,0 +1,157 @@ +use candle::backend::BackendStorage; +use candle::cpu_backend; +use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor}; +use half::{bf16, f16}; + +mod test_utils; +use test_utils::to_vec1_round; + +fn fwd<T: num_traits::Float>(v: T, alpha: T) -> T { + if v.is_sign_positive() { + v + } else { + (v.exp() - T::one()) * alpha + } +} + +struct Elu { + alpha: f64, +} + +impl CustomOp1 for Elu { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + use CpuStorage::*; + + // In this example, we pattern match over the different dtypes. Some helper functions and + // traits from the `cpu_backend` module can be used to avoid this in some common cases, see + // e.g. `Map1`. + let storage = match s { + BF16(s) => { + let alpha = bf16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + BF16(data) + } + F16(s) => { + let alpha = f16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + F16(data) + } + F32(s) => { + let alpha = self.alpha as f32; + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + F32(data) + } + F64(s) => { + let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)); + F64(data) + } + _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, + }; + Ok((storage, l.shape().clone())) + } +} + +#[test] +fn custom_op1_no_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; + let t = (t - 5.)?; + let elu_t = t.custom_op1(Elu { alpha: 1. })?; + assert_eq!( + to_vec1_round(&elu_t, 4)?, + &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + Ok(()) +} + +// Define a similar struct as Elu but with backward support. +fn bwd<T: num_traits::Float>(v: T, alpha: T) -> T { + if v.is_sign_positive() { + T::one() + } else { + v.exp() * alpha + } +} + +struct EluBackward { + alpha: f64, +} + +impl CustomOp1 for EluBackward { + fn name(&self) -> &'static str { + "elu-bwd" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + use CpuStorage::*; + + // In this example, we pattern match over the different dtypes. Some helper functions and + // traits from the `cpu_backend` module can be used to avoid this in some common cases, see + // e.g. `Map1`. + let storage = match s { + BF16(s) => { + let alpha = bf16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + BF16(data) + } + F16(s) => { + let alpha = f16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + F16(data) + } + F32(s) => { + let alpha = self.alpha as f32; + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + F32(data) + } + F64(s) => { + let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)); + F64(data) + } + _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, + }; + Ok((storage, l.shape().clone())) + } +} + +struct EluWithBackward(Elu); + +impl EluWithBackward { + fn new(alpha: f64) -> Self { + Self(Elu { alpha }) + } +} + +impl CustomOp1 for EluWithBackward { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + self.0.cpu_fwd(s, l) + } + + fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> { + let alpha = self.0.alpha; + let bwd = arg.custom_op1(EluBackward { alpha })?; + grad_res.mul(&bwd) + } +} + +#[test] +fn custom_op1_with_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = candle::Var::new(&[-2f32, 0f32, 2f32], cpu)?; + let elu_t = t.custom_op1(EluWithBackward::new(2.))?; + assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]); + + let grads = elu_t.backward()?; + let grad_x = grads.get(&t).unwrap(); + assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]); + + Ok(()) +} |