diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-17 11:12:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-17 11:12:05 +0100 |
commit | 03be33eea482accbcf4c547728c2db7e24b7ebb0 (patch) | |
tree | da5680d6d705d9346edbac9f2ce4a05779b86343 /candle-core/src/op.rs | |
parent | d32e8199cd6c8381aa309528675d6d6a88c0f850 (diff) | |
download | candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.gz candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.bz2 candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.zip |
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp.
* Simplify the custom-ops when no backward is required.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 2b57f7f7..cf99f86e 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -118,13 +118,22 @@ pub enum Op { ToDevice(Tensor), Transpose(Tensor, usize, usize), Elu(Tensor, f64), - CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>), - CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>), - CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>), + CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>), + CustomOp2( + Tensor, + Tensor, + std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>, + ), + CustomOp3( + Tensor, + Tensor, + Tensor, + std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>, + ), } /// Unary ops that can be defined in user-land. -pub trait CustomOp1: Send + Sync { +pub trait CustomOp1 { // Box<dyn> does not support const yet, so use a function to get the name. fn name(&self) -> &'static str; @@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync { } } -pub trait CustomOp2: Send + Sync { +pub trait CustomOp2 { fn name(&self) -> &'static str; /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, @@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync { } } -pub trait CustomOp3: Send + Sync { +pub trait CustomOp3 { fn name(&self) -> &'static str; /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, |