author    Laurent Mazare <laurent.mazare@gmail.com>  2023-07-21 16:18:05 +0200
committer GitHub <noreply@github.com>                2023-07-21 15:18:05 +0100
commit    a6bcdfb2697f68acff7009a22876fb5d7e5734eb (patch)
tree      66ff966da439c34535d63cca579d25810ed519d2 /candle-core/tests/custom_op_tests.rs
parent    b02229ce9279079a4f38510a4e735db92a4649b2 (diff)
Custom ops with a single argument (#214)
* Add the CustomOp1 trait.
* Add an example of a custom op.
* Polish the custom op example.
* Add some backward pass tests for custom ops.
Diffstat (limited to 'candle-core/tests/custom_op_tests.rs')
-rw-r--r--  candle-core/tests/custom_op_tests.rs  157
1 file changed, 157 insertions(+), 0 deletions(-)
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
new file mode 100644
index 00000000..3e1e0c19
--- /dev/null
+++ b/candle-core/tests/custom_op_tests.rs
@@ -0,0 +1,157 @@
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+use half::{bf16, f16};
+
+mod test_utils;
+use test_utils::to_vec1_round;
+
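+// ELU forward: v when the sign bit is positive, alpha * (exp(v) - 1) otherwise.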
+fn fwd<T: num_traits::Float>(v: T, alpha: T) -> T {
+    if v.is_sign_positive() {
+        v
+    } else {
+        (v.exp() - T::one()) * alpha
+    }
+}
+
+struct Elu {
+    alpha: f64,
+}
+
+impl CustomOp1 for Elu {
+    fn name(&self) -> &'static str {
+        "elu"
+    }
+
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        use CpuStorage::*;
+
+        // In this example, we pattern match over the different dtypes. Some helper functions and
+        // traits from the `cpu_backend` module can be used to avoid this in some common cases, see
+        // e.g. `Map1`.
+        let storage = match s {
+            BF16(s) => {
+                let alpha = bf16::from_f64(self.alpha);
+                let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
+                BF16(data)
+            }
+            F16(s) => {
+                let alpha = f16::from_f64(self.alpha);
+                let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
+                F16(data)
+            }
+            F32(s) => {
+                let alpha = self.alpha as f32;
+                let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
+                F32(data)
+            }
+            F64(s) => {
+                let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha));
+                F64(data)
+            }
+            _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
+        };
+        Ok((storage, l.shape().clone()))
+    }
+}
+
+#[test]
+fn custom_op1_no_backward() -> Result<()> {
+    let cpu = &Device::Cpu;
+    let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
+    let t = (t - 5.)?;
+    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
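+    // Inputs span -5..=6; with alpha = 1 the negative values map to exp(v) - 1, the rest pass through.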
+    assert_eq!(
+        to_vec1_round(&elu_t, 4)?,
+        &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    );
+    Ok(())
+}
+
+// Define a struct similar to Elu but with backward support.
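+// The ELU derivative: 1 for non-negative v, alpha * exp(v) otherwise.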
+fn bwd<T: num_traits::Float>(v: T, alpha: T) -> T {
+    if v.is_sign_positive() {
+        T::one()
+    } else {
+        v.exp() * alpha
+    }
+}
+
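+// The derivative is itself a CustomOp1 so it can run through the same dispatch machinery as the forward pass.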
+struct EluBackward {
+    alpha: f64,
+}
+
+impl CustomOp1 for EluBackward {
+    fn name(&self) -> &'static str {
+        "elu-bwd"
+    }
+
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        use CpuStorage::*;
+
+        // In this example, we pattern match over the different dtypes. Some helper functions and
+        // traits from the `cpu_backend` module can be used to avoid this in some common cases, see
+        // e.g. `Map1`.
+        let storage = match s {
+            BF16(s) => {
+                let alpha = bf16::from_f64(self.alpha);
+                let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
+                BF16(data)
+            }
+            F16(s) => {
+                let alpha = f16::from_f64(self.alpha);
+                let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
+                F16(data)
+            }
+            F32(s) => {
+                let alpha = self.alpha as f32;
+                let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
+                F32(data)
+            }
+            F64(s) => {
+                let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha));
+                F64(data)
+            }
+            _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu-bwd").bt())?,
+        };
+        Ok((storage, l.shape().clone()))
+    }
+}
+
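+// Wrap Elu in a newtype so its forward pass can be reused while adding a backward implementation.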
+struct EluWithBackward(Elu);
+
+impl EluWithBackward {
+    fn new(alpha: f64) -> Self {
+        Self(Elu { alpha })
+    }
+}
+
+impl CustomOp1 for EluWithBackward {
+    fn name(&self) -> &'static str {
+        "elu"
+    }
+
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        self.0.cpu_fwd(s, l)
+    }
+
+    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> {
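+        // Chain rule: the gradient wrt the argument is grad_res * elu'(arg).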
+        let alpha = self.0.alpha;
+        let bwd = arg.custom_op1(EluBackward { alpha })?;
+        grad_res.mul(&bwd)
+    }
+}
+
+#[test]
+fn custom_op1_with_backward() -> Result<()> {
+    let cpu = &Device::Cpu;
+    let t = candle::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
+    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
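+    // With alpha = 2: elu(-2) = 2 * (exp(-2) - 1) ≈ -1.7293; 0 and 2 pass through unchanged.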
+    assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
+
+    let grads = elu_t.backward()?;
+    let grad_x = grads.get(&t).unwrap();
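+    // elu'(-2) = 2 * exp(-2) ≈ 0.2707; the derivative is 1 for non-negative inputs.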
+    assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]);
+
+    Ok(())
+}
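
For context, a minimal sketch of how a downstream caller might apply the op defined above. The `apply_elu` helper is hypothetical (not part of this patch); it only uses the `Tensor::custom_op1` entry point exercised by the tests:

use candle::{Device, Result, Tensor};

// Hypothetical helper: apply the ELU custom op to a tensor.
fn apply_elu(xs: &Tensor, alpha: f64) -> Result<Tensor> {
    // custom_op1 accepts any CustomOp1 implementation and dispatches to its
    // backend-specific forward (cpu_fwd on the CPU device).
    xs.custom_op1(EluWithBackward::new(alpha))
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1f32, 0.5, 2.0], &Device::Cpu)?;
    let ys = apply_elu(&xs, 1.0)?;
    println!("{ys}");
    Ok(())
}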