summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-20 13:28:45 +0200
committerGitHub <noreply@github.com>2023-07-20 12:28:45 +0100
commit2a8f28d687b5a33c6c91f40100a42baf6e2fc10a (patch)
treeee76b9df39f52212666ad6f2096be541973cdafd /candle-core/src/op.rs
parente9c052bf94521b418852a1c5231c12ddce99a78f (diff)
downloadcandle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.gz
candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.bz2
candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.zip
Op refactor (#208)
* Add the binary and unary op enums to factorize some code. * Bugfix.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs54
1 files changed, 33 insertions, 21 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ece6969c..3dd3de68 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -19,12 +19,34 @@ pub enum ReduceOp {
Max,
}
+// These ops return the same type as their input type.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BinaryOp {
+ Add,
+ Mul,
+ Sub,
+ Div,
+}
+
+// Unary ops with no argument
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum UnaryOp {
+ Exp,
+ Log,
+ Sin,
+ Cos,
+ Abs,
+ Neg,
+ Sqr,
+ Sqrt,
+ Gelu,
+ Relu,
+}
+
#[derive(Clone)]
pub(crate) enum Op {
- Add(Tensor, Tensor),
- Mul(Tensor, Tensor),
- Sub(Tensor, Tensor),
- Div(Tensor, Tensor),
+ Binary(Tensor, Tensor, BinaryOp),
+ Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
@@ -49,26 +71,16 @@ pub(crate) enum Op {
},
ToDType(Tensor),
Broadcast(Tensor),
- Exp(Tensor),
- Log(Tensor),
- Sin(Tensor),
- Cos(Tensor),
- Abs(Tensor),
Narrow(Tensor, usize, usize, usize),
- Neg(Tensor),
Reshape(Tensor),
Softmax(Tensor, usize),
- Sqr(Tensor),
- Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
- Gelu(Tensor),
- Relu(Tensor),
Elu(Tensor, f64),
// TODO: Support for custom ops.
}
-pub(crate) trait UnaryOp {
+pub(crate) trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -91,7 +103,7 @@ pub(crate) trait UnaryOp {
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
-pub(crate) trait BinaryOp {
+pub(crate) trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -133,7 +145,7 @@ pub(crate) struct Relu;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
- impl BinaryOp for $op {
+ impl BinaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
@@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
- impl UnaryOp for $op {
+ impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -219,7 +231,7 @@ macro_rules! unary_op {
};
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
- impl UnaryOp for $op {
+ impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
-impl UnaryOp for Gelu {
+impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
@@ -343,7 +355,7 @@ impl UnaryOp for Gelu {
}
}
-impl UnaryOp for Relu {
+impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;