diff options
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index ba8d2fb4..aea8b733 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -51,6 +51,7 @@ pub enum UnaryOp { Cos, Abs, Neg, + Recip, Sqr, Sqrt, Gelu, @@ -79,6 +80,21 @@ pub enum Op { stride: usize, }, + #[allow(dead_code)] + Conv2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + }, + + AvgPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), + Cat(Vec<Tensor>, usize), #[allow(dead_code)] // add is currently unused. @@ -264,6 +280,7 @@ pub(crate) struct Sin; pub(crate) struct Cos; pub(crate) struct Abs; pub(crate) struct Neg; +pub(crate) struct Recip; pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; @@ -410,6 +427,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); +unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); |