summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/op.rs21
-rw-r--r--candle-core/src/quantized/mod.rs6
-rw-r--r--candle-core/src/storage.rs6
-rw-r--r--candle-core/src/tensor.rs65
-rw-r--r--candle-core/tests/custom_op_tests.rs6
-rw-r--r--candle-examples/examples/custom-ops/main.rs2
-rw-r--r--candle-examples/examples/llama_multiprocess/model.rs2
-rw-r--r--candle-flash-attn/src/lib.rs4
8 files changed, 81 insertions, 31 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 2b57f7f7..cf99f86e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -118,13 +118,22 @@ pub enum Op {
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
- CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
- CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
- CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
+ CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+ CustomOp2(
+ Tensor,
+ Tensor,
+ std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+ ),
+ CustomOp3(
+ Tensor,
+ Tensor,
+ Tensor,
+ std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+ ),
}
/// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
@@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync {
}
}
-pub trait CustomOp2: Send + Sync {
+pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync {
}
}
-pub trait CustomOp3: Send + Sync {
+pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index a0ed5b4d..a334b2c1 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -147,11 +147,11 @@ impl QTensor {
}
}
-pub struct QMatMul(std::sync::Arc<Box<dyn crate::CustomOp1>>);
+pub struct QMatMul(QTensor);
impl QMatMul {
pub fn from_qtensor(qtensor: QTensor) -> Self {
- Self(std::sync::Arc::new(Box::new(qtensor)))
+ Self(qtensor)
}
}
@@ -196,6 +196,6 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.custom_op1_arc(self.0.clone())
+ xs.apply_op1_no_bwd(&self.0)
}
}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 791b65dd..4a6cdc34 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -138,7 +138,7 @@ impl Storage {
}
}
- pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
+ pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
@@ -151,7 +151,7 @@ impl Storage {
}
}
- pub(crate) fn custom_op2(
+ pub(crate) fn apply_op2(
&self,
l1: &Layout,
t2: &Self,
@@ -172,7 +172,7 @@ impl Storage {
}
}
- pub(crate) fn custom_op3(
+ pub(crate) fn apply_op3(
&self,
l1: &Layout,
t2: &Self,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c14a4e39..c71ea5ec 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1870,22 +1870,53 @@ impl Tensor {
std::ptr::eq(lhs, rhs)
}
+ /// Applies a unary custom op without backward support
+ pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
+ let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
+ Ok(from_storage(storage, shape, BackpropOp::none(), false))
+ }
+
+ /// Applies a binary custom op without backward support
+ pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
+ let (storage, shape) =
+ self.storage()
+ .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
+ Ok(from_storage(storage, shape, BackpropOp::none(), false))
+ }
+
+ /// Applies a ternary custom op without backward support
+ pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
+ let (storage, shape) = self.storage().apply_op3(
+ self.layout(),
+ &t2.storage(),
+ t2.layout(),
+ &t3.storage(),
+ t3.layout(),
+ c,
+ )?;
+ Ok(from_storage(storage, shape, BackpropOp::none(), false))
+ }
+
/// Applies a unary custom op.
- pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
+ pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
- .custom_op1(self.layout(), c.as_ref().as_ref())?;
+ .apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
- pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
- self.custom_op1_arc(Arc::new(Box::new(c)))
+ pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
+ self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
- pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
- let (storage, shape) = self.storage().custom_op2(
+ pub fn apply_op2_arc(
+ &self,
+ rhs: &Self,
+ c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
+ ) -> Result<Self> {
+ let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
@@ -1895,13 +1926,18 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
- pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
- self.custom_op2_arc(r, Arc::new(Box::new(c)))
+ pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
+ self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
- pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
- let (storage, shape) = self.storage().custom_op3(
+ pub fn apply_op3_arc(
+ &self,
+ t2: &Self,
+ t3: &Self,
+ c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
+ ) -> Result<Self> {
+ let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
@@ -1915,8 +1951,13 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
- pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
- self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
+ pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
+ &self,
+ t2: &Self,
+ t3: &Self,
+ c: C,
+ ) -> Result<Self> {
+ self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index 55b5e894..7ec04c6a 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -39,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
- let elu_t = t.custom_op1(Elu { alpha: 1. })?;
+ let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&elu_t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@@ -96,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
- let bwd = arg.custom_op1(EluBackward { alpha })?;
+ let bwd = arg.apply_op1(EluBackward { alpha })?;
Ok(Some(grad_res.mul(&bwd)?))
}
}
@@ -105,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
- let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
+ let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
let grads = elu_t.backward()?;
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index 63bcd83a..7f7a3f26 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
println!("{t}");
- let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
+ let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
println!("{t}");
Ok(())
}
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index ab4e382c..ad5e4cd2 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -68,7 +68,7 @@ impl CustomOp1 for AllReduce {
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
- x.custom_op1(AllReduce { comm: comm.clone() })
+ x.apply_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 092743f1..3c5fd455 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -178,7 +178,7 @@ pub fn flash_attn(
softmax_scale,
causal,
};
- q.custom_op3(k, v, op)
+ q.apply_op3(k, v, op)
}
struct FlashAttnVarLen {
@@ -402,5 +402,5 @@ pub fn flash_attn_varlen(
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
};
- q.custom_op3(k, v, op)
+ q.apply_op3(k, v, op)
}