summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs18
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ba8d2fb4..aea8b733 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -51,6 +51,7 @@ pub enum UnaryOp {
Cos,
Abs,
Neg,
+ Recip,
Sqr,
Sqrt,
Gelu,
@@ -79,6 +80,21 @@ pub enum Op {
stride: usize,
},
+ #[allow(dead_code)]
+ Conv2D {
+ arg: Tensor,
+ kernel: Tensor,
+ padding: usize,
+ stride: usize,
+ },
+
+ AvgPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+ UpsampleNearest2D(Tensor),
+
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
@@ -264,6 +280,7 @@ pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
+pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
@@ -410,6 +427,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
+unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);