summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-17 11:12:05 +0100
committerGitHub <noreply@github.com>2023-08-17 11:12:05 +0100
commit03be33eea482accbcf4c547728c2db7e24b7ebb0 (patch)
treeda5680d6d705d9346edbac9f2ce4a05779b86343 /candle-core/src/op.rs
parentd32e8199cd6c8381aa309528675d6d6a88c0f850 (diff)
downloadcandle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.gz
candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.bz2
candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.zip
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp. * Simplify the custom-ops when no backward is required.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs21
1 files changed, 15 insertions, 6 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 2b57f7f7..cf99f86e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -118,13 +118,22 @@ pub enum Op {
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
- CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
- CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
- CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
+ CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+ CustomOp2(
+ Tensor,
+ Tensor,
+ std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+ ),
+ CustomOp3(
+ Tensor,
+ Tensor,
+ Tensor,
+ std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+ ),
}
/// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
@@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync {
}
}
-pub trait CustomOp2: Send + Sync {
+pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync {
}
}
-pub trait CustomOp3: Send + Sync {
+pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,