diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/op.rs | 21 | ||||
-rw-r--r-- | candle-core/src/quantized/mod.rs | 6 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 6 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 65 |
4 files changed, 74 insertions, 24 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 2b57f7f7..cf99f86e 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -118,13 +118,22 @@ pub enum Op { ToDevice(Tensor), Transpose(Tensor, usize, usize), Elu(Tensor, f64), - CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>), - CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>), - CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>), + CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>), + CustomOp2( + Tensor, + Tensor, + std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>, + ), + CustomOp3( + Tensor, + Tensor, + Tensor, + std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>, + ), } /// Unary ops that can be defined in user-land. -pub trait CustomOp1: Send + Sync { +pub trait CustomOp1 { // Box<dyn> does not support const yet, so use a function to get the name. fn name(&self) -> &'static str; @@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync { } } -pub trait CustomOp2: Send + Sync { +pub trait CustomOp2 { fn name(&self) -> &'static str; /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, @@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync { } } -pub trait CustomOp3: Send + Sync { +pub trait CustomOp3 { fn name(&self) -> &'static str; /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index a0ed5b4d..a334b2c1 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -147,11 +147,11 @@ impl QTensor { } } -pub struct QMatMul(std::sync::Arc<Box<dyn crate::CustomOp1>>); +pub struct QMatMul(QTensor); impl QMatMul { pub fn from_qtensor(qtensor: QTensor) -> Self { - Self(std::sync::Arc::new(Box::new(qtensor))) + Self(qtensor) } } @@ -196,6 +196,6 @@ impl crate::CustomOp1 for QTensor { impl QMatMul { pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs.custom_op1_arc(self.0.clone()) + xs.apply_op1_no_bwd(&self.0) } } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 791b65dd..4a6cdc34 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -138,7 +138,7 @@ impl Storage { } } - pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> { + pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> { match self { Self::Cpu(storage) => { let (storage, shape) = c.cpu_fwd(storage, l)?; @@ -151,7 +151,7 @@ impl Storage { } } - pub(crate) fn custom_op2( + pub(crate) fn apply_op2( &self, l1: &Layout, t2: &Self, @@ -172,7 +172,7 @@ impl Storage { } } - pub(crate) fn custom_op3( + pub(crate) fn apply_op3( &self, l1: &Layout, t2: &Self, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c14a4e39..c71ea5ec 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1870,22 +1870,53 @@ impl Tensor { std::ptr::eq(lhs, rhs) } + /// Applies a unary custom op without backward support + pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> { + let (storage, shape) = self.storage().apply_op1(self.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a binary custom op without backward support + pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> { + let (storage, shape) = + self.storage() + .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a ternary custom op without backward support + pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> { + let (storage, shape) = self.storage().apply_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c, + )?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + /// Applies a unary custom op. - pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> { + pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> { let (storage, shape) = self .storage() - .custom_op1(self.layout(), c.as_ref().as_ref())?; + .apply_op1(self.layout(), c.as_ref().as_ref())?; let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone())); Ok(from_storage(storage, shape, op, false)) } - pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> { - self.custom_op1_arc(Arc::new(Box::new(c))) + pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> { + self.apply_op1_arc(Arc::new(Box::new(c))) } /// Applies a binary custom op. - pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> { - let (storage, shape) = self.storage().custom_op2( + pub fn apply_op2_arc( + &self, + rhs: &Self, + c: Arc<Box<dyn CustomOp2 + Send + Sync>>, + ) -> Result<Self> { + let (storage, shape) = self.storage().apply_op2( self.layout(), &rhs.storage(), rhs.layout(), @@ -1895,13 +1926,18 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } - pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> { - self.custom_op2_arc(r, Arc::new(Box::new(c))) + pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> { + self.apply_op2_arc(r, Arc::new(Box::new(c))) } /// Applies a ternary custom op. - pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> { - let (storage, shape) = self.storage().custom_op3( + pub fn apply_op3_arc( + &self, + t2: &Self, + t3: &Self, + c: Arc<Box<dyn CustomOp3 + Send + Sync>>, + ) -> Result<Self> { + let (storage, shape) = self.storage().apply_op3( self.layout(), &t2.storage(), t2.layout(), @@ -1915,8 +1951,13 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } - pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> { - self.custom_op3_arc(t2, t3, Arc::new(Box::new(c))) + pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>( + &self, + t2: &Self, + t3: &Self, + c: C, + ) -> Result<Self> { + self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) } } |