-rw-r--r--  candle-core/src/backend.rs            |  7
-rw-r--r--  candle-core/src/backprop.rs           | 54
-rw-r--r--  candle-core/src/cpu_backend.rs        | 11
-rw-r--r--  candle-core/src/cuda_backend.rs       | 10
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs | 11
-rw-r--r--  candle-core/src/op.rs                 | 54
-rw-r--r--  candle-core/src/storage.rs            |  4
-rw-r--r--  candle-core/src/tensor.rs             |  6
8 files changed, 81 insertions, 76 deletions
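Summary of the change: the per-operation variants of `Op` (`Op::Add`, `Op::Mul`, `Op::Exp`, `Op::Sqr`, ...) are collapsed into two grouped variants, `Op::Binary(Tensor, Tensor, BinaryOp)` and `Op::Unary(Tensor, UnaryOp)`, backed by new `BinaryOp`/`UnaryOp` enums in op.rs, while the existing kernel-side traits are renamed to `BinaryOpT`/`UnaryOpT` so their names no longer clash with the enums. The sketch below illustrates the resulting enum-plus-trait pattern; it is not candle code: the enum and trait names follow the diff, but the `f64`-based `Op` and the `main` driver are simplified stand-ins for candle's tensor machinery.

// Illustrative sketch only: names mirror the diff, but Tensor is replaced by
// a plain f64 so the example stays self-contained and runnable.

/// Graph-side description of an op: one enum variant per op family.
/// (The real enums in op.rs carry more variants, e.g. Gelu/Relu/Abs.)
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnaryOp {
    Neg,
    Sqr,
    Sqrt,
}

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinaryOp {
    Add,
    Mul,
}

/// Stand-in for candle's `Op`, which stores tensors instead of scalars.
enum Op {
    Unary(f64, UnaryOp),
    Binary(f64, f64, BinaryOp),
}

/// Kernel-side description: backends are generic over this trait
/// (renamed from `UnaryOp` to `UnaryOpT` in the diff to free up the enum name).
trait UnaryOpT {
    const NAME: &'static str;
    fn f64(v: f64) -> f64;
}

struct Sqr;
impl UnaryOpT for Sqr {
    const NAME: &'static str = "sqr";
    fn f64(v: f64) -> f64 {
        v * v
    }
}

fn main() {
    // The tape records enum values; consumers (e.g. backprop) just match on them.
    let nodes = [Op::Unary(3.0, UnaryOp::Sqr), Op::Binary(2.0, 4.0, BinaryOp::Add)];
    for node in nodes {
        match node {
            Op::Unary(x, UnaryOp::Sqr) => println!("{}({x}) = {}", Sqr::NAME, Sqr::f64(x)),
            Op::Unary(x, op) => println!("{op:?}({x}): not covered in this sketch"),
            Op::Binary(a, b, BinaryOp::Add) => println!("add({a}, {b}) = {}", a + b),
            Op::Binary(a, b, op) => println!("{op:?}({a}, {b}): not covered in this sketch"),
        }
    }
}

The split matters because the two sides have different jobs: the plain-data enums sit on the autograd tape where backprop.rs can match on the op kind, while the `*OpT` traits remain generic parameters (`unary_impl<B: UnaryOpT>`, the `Map1`/`Map2` impls in the CUDA backend) so the per-dtype kernels stay monomorphized.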
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 307868dd..977dba69 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -1,4 +1,4 @@ -use crate::op::{CmpOp, ReduceOp}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; pub(crate) trait BackendStorage: Sized {
@@ -25,10 +25,9 @@ pub(crate) trait BackendStorage: Sized { fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>; - fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>; + fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>; - fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>; + fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>; fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4d968e7f..aab3caf0 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,4 +1,4 @@ -use crate::op::{Op, ReduceOp}; +use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp}; use crate::{Error, Result, Tensor, TensorId}; use std::collections::HashMap;
@@ -39,10 +39,7 @@ impl Tensor { kernel: rhs, .. } - | Op::Add(lhs, rhs) - | Op::Mul(lhs, rhs) - | Op::Sub(lhs, rhs) - | Op::Div(lhs, rhs) + | Op::Binary(lhs, rhs, _) | Op::Embedding(lhs, rhs) | Op::Matmul(lhs, rhs) => { let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -74,17 +71,8 @@ impl Tensor { | Op::Transpose(node, _, _) | Op::Narrow(node, _, _, _) | Op::Softmax(node, _) - | Op::Sqr(node) - | Op::Sqrt(node) - | Op::Gelu(node) - | Op::Relu(node) - | Op::Elu(node, _) - | Op::Exp(node) - | Op::Log(node) - | Op::Sin(node) - | Op::Cos(node) - | Op::Abs(node) - | Op::Neg(node) => { + | Op::Unary(node, _) + | Op::Elu(node, _) => { let (tg, nodes) = walk(node, nodes, already_seen); track_grad |= tg; nodes
@@ -118,19 +106,19 @@ impl Tensor { // this is out of scope. if let Some(op) = node.op() { match op { - Op::Add(lhs, rhs) => { + Op::Binary(lhs, rhs, BinaryOp::Add) => { let lhs_sum_grad = grads.or_insert(lhs)?; *lhs_sum_grad = lhs_sum_grad.add(&grad)?; let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&grad)?; } - Op::Sub(lhs, rhs) => { + Op::Binary(lhs, rhs, BinaryOp::Sub) => { let lhs_sum_grad = grads.or_insert(lhs)?; *lhs_sum_grad = lhs_sum_grad.add(&grad)?; let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.sub(&grad)?; } - Op::Mul(lhs, rhs) => { + Op::Binary(lhs, rhs, BinaryOp::Mul) => { let lhs_grad = grad.mul(rhs)?; let lhs_sum_grad = grads.or_insert(lhs)?; *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -138,7 +126,7 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } - Op::Div(lhs, rhs) => { + Op::Binary(lhs, rhs, BinaryOp::Div) => { let lhs_grad = grad.div(rhs)?; let lhs_sum_grad = grads.or_insert(lhs)?; *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -221,24 +209,26 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Log(arg) => { + Op::Unary(arg, UnaryOp::Log) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&(&grad * *node)?)? } - Op::Sin(arg) => { + Op::Unary(arg, UnaryOp::Sin) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&(&grad * arg.cos())?)? } - Op::Cos(arg) => { + Op::Unary(arg, UnaryOp::Cos) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)? } - Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }), - Op::Exp(arg) => { + Op::Unary(_, UnaryOp::Abs) => { + return Err(Error::BackwardNotSupported { op: "abs" }) + } + Op::Unary(arg, UnaryOp::Exp) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&(&grad / arg)?)? } - Op::Neg(arg) => { + Op::Unary(arg, UnaryOp::Neg) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.sub(&grad)? }
@@ -276,15 +266,19 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }), - Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }), + Op::Unary(_, UnaryOp::Gelu) => { + return Err(Error::BackwardNotSupported { op: "gelu" }) + } + Op::Unary(_, UnaryOp::Relu) => { + return Err(Error::BackwardNotSupported { op: "relu" }) + } Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }), - Op::Sqr(arg) => { + Op::Unary(arg, UnaryOp::Sqr) => { let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Sqrt(arg) => { + Op::Unary(arg, UnaryOp::Sqrt) => { let arg_grad = grad.div(arg)?.affine(0.5, 0.)?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)?
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index b12e0702..1ce1d73a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, Layout, Result, Shape, WithDType}; use half::{bf16, f16};
@@ -1158,7 +1158,7 @@ impl BackendStorage for CpuStorage { } } - fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> { + fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { match self { Self::BF16(storage) => { if B::BF16_VEC {
@@ -1207,7 +1207,12 @@ impl BackendStorage for CpuStorage { } } - fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> { + fn binary_impl<B: BinaryOpT>( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result<Self> { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { let data = if B::BF16_VEC {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9e47c133..c9bb1fba 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{CmpOp, ReduceOp}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -573,7 +573,7 @@ impl<'a> Map1 for FastReduce<'a> { } } -impl<U: crate::op::UnaryOp> Map1 for U { +impl<U: UnaryOpT> Map1 for U { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src: &CudaSlice<T>,
@@ -716,7 +716,7 @@ impl<'a> Map2 for WhereCond<'a> { } } -impl<U: crate::op::BinaryOp> Map2 for U { +impl<U: crate::op::BinaryOpT> Map2 for U { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, lhs: &CudaSlice<T>,
@@ -976,13 +976,13 @@ impl BackendStorage for CudaStorage { Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into()) } - fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> { + fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { let device = self.device().clone(); let slice = U::V.map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } - fn binary_impl<B: crate::op::BinaryOp>( + fn binary_impl<B: BinaryOpT>( &self, rhs: &Self, lhs_l: &Layout,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 942e82ed..8f1a9916 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::op::{CmpOp, ReduceOp}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; #[derive(Debug, Clone)]
@@ -57,16 +57,11 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> { + fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } - fn binary_impl<B: crate::op::BinaryOp>( - &self, - _: &Self, - _: &Layout, - _: &Layout, - ) -> Result<Self> { + fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ece6969c..3dd3de68 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -19,12 +19,34 @@ pub enum ReduceOp { Max, } +// These ops return the same type as their input type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryOp { + Add, + Mul, + Sub, + Div, +} + +// Unary ops with no argument +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOp { + Exp, + Log, + Sin, + Cos, + Abs, + Neg, + Sqr, + Sqrt, + Gelu, + Relu, +} + #[derive(Clone)] pub(crate) enum Op { - Add(Tensor, Tensor), - Mul(Tensor, Tensor), - Sub(Tensor, Tensor), - Div(Tensor, Tensor), + Binary(Tensor, Tensor, BinaryOp), + Unary(Tensor, UnaryOp), Cmp(Tensor, CmpOp), Reduce(Tensor, ReduceOp, Vec<usize>), Matmul(Tensor, Tensor),
@@ -49,26 +71,16 @@ pub(crate) enum Op { }, ToDType(Tensor), Broadcast(Tensor), - Exp(Tensor), - Log(Tensor), - Sin(Tensor), - Cos(Tensor), - Abs(Tensor), Narrow(Tensor, usize, usize, usize), - Neg(Tensor), Reshape(Tensor), Softmax(Tensor, usize), - Sqr(Tensor), - Sqrt(Tensor), ToDevice(Tensor), Transpose(Tensor, usize, usize), - Gelu(Tensor), - Relu(Tensor), Elu(Tensor, f64), // TODO: Support for custom ops. } -pub(crate) trait UnaryOp { +pub(crate) trait UnaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self;
@@ -91,7 +103,7 @@ pub(crate) trait UnaryOp { fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } -pub(crate) trait BinaryOp { +pub(crate) trait BinaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self;
@@ -133,7 +145,7 @@ pub(crate) struct Relu; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { - impl BinaryOp for $op { + impl BinaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("b", $name); const V: Self = $op;
@@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div); macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { - impl UnaryOp for $op { + impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op;
@@ -219,7 +231,7 @@ macro_rules! unary_op { }; ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => { - impl UnaryOp for $op { + impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op;
@@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); /// `gelu` operation /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> -impl UnaryOp for Gelu { +impl UnaryOpT for Gelu { const NAME: &'static str = "gelu"; const V: Self = Gelu; #[inline(always)]
@@ -343,7 +355,7 @@ impl UnaryOp for Gelu { } } -impl UnaryOp for Relu { +impl UnaryOpT for Relu { const NAME: &'static str = "relu"; const KERNEL: &'static str = "urelu"; const V: Self = Relu;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index fb72322c..30232ba0 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -147,7 +147,7 @@ impl Storage { } } - pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> { + pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> { // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => {
@@ -161,7 +161,7 @@ impl Storage { } } - pub(crate) fn binary_impl<B: op::BinaryOp>( + pub(crate) fn binary_impl<B: op::BinaryOpT>( &self, rhs: &Self, lhs_layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d6c3e9cb..087a2ff5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{CmpOp, Op, ReduceOp}; +use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; use crate::shape::{Dim, Dims}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock};
@@ -80,7 +80,7 @@ macro_rules! unary_op { .storage() .unary_impl::<crate::op::$op_name>(self.layout())?; let op = if self.track_op() { - Some(Op::$op_name(self.clone())) + Some(Op::Unary(self.clone(), UnaryOp::$op_name)) } else { None };
@@ -99,7 +99,7 @@ macro_rules! binary_op { rhs.layout(), )?; let op = if self.track_op() || rhs.track_op() { - Some(Op::$op_name(self.clone(), rhs.clone())) + Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name)) } else { None };
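In tensor.rs only the recorded node changes: the generated methods still call `unary_impl::<crate::op::$op_name>` / `binary_impl`, but the op pushed onto the tape becomes `Op::Unary(self.clone(), UnaryOp::$op_name)` instead of a per-op variant. Below is a stripped-down illustration of that macro pattern; `Expr`, `unary_method!`, and the `value` evaluator are invented for the example and are not part of candle.

// Stripped-down illustration (not candle code) of generating one method per op
// with a declarative macro while recording a single enum-based graph node.

#[derive(Debug, Clone, Copy)]
enum UnaryOp {
    Exp,
    Sqr,
}

#[derive(Debug)]
enum Expr {
    Value(f64),
    // Analogous to candle's Op::Unary(Tensor, UnaryOp).
    Unary(Box<Expr>, UnaryOp),
}

impl Expr {
    // Tiny "forward pass": evaluate the recorded graph.
    fn value(&self) -> f64 {
        match self {
            Expr::Value(v) => *v,
            Expr::Unary(e, UnaryOp::Exp) => e.value().exp(),
            Expr::Unary(e, UnaryOp::Sqr) => {
                let v = e.value();
                v * v
            }
        }
    }
}

// Mirrors the shape of tensor.rs's `unary_op!`: the method name and the enum
// variant it records are supplied together.
macro_rules! unary_method {
    ($fn_name:ident, $variant:ident) => {
        impl Expr {
            fn $fn_name(self) -> Expr {
                Expr::Unary(Box::new(self), UnaryOp::$variant)
            }
        }
    };
}

unary_method!(exp, Exp);
unary_method!(sqr, Sqr);

fn main() {
    // Builds Unary(Unary(Value(3.0), Sqr), Exp), then evaluates it.
    let y = Expr::Value(3.0).sqr().exp();
    println!("{y:?} = {}", y.value());
}

As in the diff, adding an op then touches one enum variant and one macro invocation, while code that walks the recorded graph (such as the backward pass in backprop.rs) keeps matching on the single `Unary` variant.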