diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-20 13:28:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-20 12:28:45 +0100 |
commit | 2a8f28d687b5a33c6c91f40100a42baf6e2fc10a (patch) | |
tree | ee76b9df39f52212666ad6f2096be541973cdafd /candle-core/src/op.rs | |
parent | e9c052bf94521b418852a1c5231c12ddce99a78f (diff) | |
download | candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.gz candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.bz2 candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.zip |
Op refactor (#208)
* Add the binary and unary op enums to factorize some code.
* Bugfix.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 54 |
1 files changed, 33 insertions, 21 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index ece6969c..3dd3de68 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -19,12 +19,34 @@ pub enum ReduceOp { Max, } +// These ops return the same type as their input type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryOp { + Add, + Mul, + Sub, + Div, +} + +// Unary ops with no argument +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOp { + Exp, + Log, + Sin, + Cos, + Abs, + Neg, + Sqr, + Sqrt, + Gelu, + Relu, +} + #[derive(Clone)] pub(crate) enum Op { - Add(Tensor, Tensor), - Mul(Tensor, Tensor), - Sub(Tensor, Tensor), - Div(Tensor, Tensor), + Binary(Tensor, Tensor, BinaryOp), + Unary(Tensor, UnaryOp), Cmp(Tensor, CmpOp), Reduce(Tensor, ReduceOp, Vec<usize>), Matmul(Tensor, Tensor), @@ -49,26 +71,16 @@ pub(crate) enum Op { }, ToDType(Tensor), Broadcast(Tensor), - Exp(Tensor), - Log(Tensor), - Sin(Tensor), - Cos(Tensor), - Abs(Tensor), Narrow(Tensor, usize, usize, usize), - Neg(Tensor), Reshape(Tensor), Softmax(Tensor, usize), - Sqr(Tensor), - Sqrt(Tensor), ToDevice(Tensor), Transpose(Tensor, usize, usize), - Gelu(Tensor), - Relu(Tensor), Elu(Tensor, f64), // TODO: Support for custom ops. } -pub(crate) trait UnaryOp { +pub(crate) trait UnaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; @@ -91,7 +103,7 @@ pub(crate) trait UnaryOp { fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } -pub(crate) trait BinaryOp { +pub(crate) trait BinaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; @@ -133,7 +145,7 @@ pub(crate) struct Relu; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { - impl BinaryOp for $op { + impl BinaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("b", $name); const V: Self = $op; @@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div); macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { - impl UnaryOp for $op { + impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op; @@ -219,7 +231,7 @@ macro_rules! unary_op { }; ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => { - impl UnaryOp for $op { + impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op; @@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); /// `gelu` operation /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> -impl UnaryOp for Gelu { +impl UnaryOpT for Gelu { const NAME: &'static str = "gelu"; const V: Self = Gelu; #[inline(always)] @@ -343,7 +355,7 @@ impl UnaryOp for Gelu { } } -impl UnaryOp for Relu { +impl UnaryOpT for Relu { const NAME: &'static str = "relu"; const KERNEL: &'static str = "urelu"; const V: Self = Relu; |