summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-20 13:28:45 +0200
committerGitHub <noreply@github.com>2023-07-20 12:28:45 +0100
commit2a8f28d687b5a33c6c91f40100a42baf6e2fc10a (patch)
treeee76b9df39f52212666ad6f2096be541973cdafd /candle-core/src
parente9c052bf94521b418852a1c5231c12ddce99a78f (diff)
downloadcandle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.gz
candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.tar.bz2
candle-2a8f28d687b5a33c6c91f40100a42baf6e2fc10a.zip
Op refactor (#208)
* Add the binary and unary op enums to factorize some code. * Bugfix.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backend.rs7
-rw-r--r--candle-core/src/backprop.rs54
-rw-r--r--candle-core/src/cpu_backend.rs11
-rw-r--r--candle-core/src/cuda_backend.rs10
-rw-r--r--candle-core/src/dummy_cuda_backend.rs11
-rw-r--r--candle-core/src/op.rs54
-rw-r--r--candle-core/src/storage.rs4
-rw-r--r--candle-core/src/tensor.rs6
8 files changed, 81 insertions, 76 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 307868dd..977dba69 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -1,4 +1,4 @@
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub(crate) trait BackendStorage: Sized {
@@ -25,10 +25,9 @@ pub(crate) trait BackendStorage: Sized {
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
- fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
+ fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
- fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
- -> Result<Self>;
+ fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4d968e7f..aab3caf0 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,4 +1,4 @@
-use crate::op::{Op, ReduceOp};
+use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@@ -39,10 +39,7 @@ impl Tensor {
kernel: rhs,
..
}
- | Op::Add(lhs, rhs)
- | Op::Mul(lhs, rhs)
- | Op::Sub(lhs, rhs)
- | Op::Div(lhs, rhs)
+ | Op::Binary(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -74,17 +71,8 @@ impl Tensor {
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
- | Op::Sqr(node)
- | Op::Sqrt(node)
- | Op::Gelu(node)
- | Op::Relu(node)
- | Op::Elu(node, _)
- | Op::Exp(node)
- | Op::Log(node)
- | Op::Sin(node)
- | Op::Cos(node)
- | Op::Abs(node)
- | Op::Neg(node) => {
+ | Op::Unary(node, _)
+ | Op::Elu(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@@ -118,19 +106,19 @@ impl Tensor {
// this is out of scope.
if let Some(op) = node.op() {
match op {
- Op::Add(lhs, rhs) => {
+ Op::Binary(lhs, rhs, BinaryOp::Add) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
- Op::Sub(lhs, rhs) => {
+ Op::Binary(lhs, rhs, BinaryOp::Sub) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
}
- Op::Mul(lhs, rhs) => {
+ Op::Binary(lhs, rhs, BinaryOp::Mul) => {
let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -138,7 +126,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
- Op::Div(lhs, rhs) => {
+ Op::Binary(lhs, rhs, BinaryOp::Div) => {
let lhs_grad = grad.div(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -221,24 +209,26 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Log(arg) => {
+ Op::Unary(arg, UnaryOp::Log) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
}
- Op::Sin(arg) => {
+ Op::Unary(arg, UnaryOp::Sin) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
}
- Op::Cos(arg) => {
+ Op::Unary(arg, UnaryOp::Cos) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
- Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
- Op::Exp(arg) => {
+ Op::Unary(_, UnaryOp::Abs) => {
+ return Err(Error::BackwardNotSupported { op: "abs" })
+ }
+ Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad / arg)?)?
}
- Op::Neg(arg) => {
+ Op::Unary(arg, UnaryOp::Neg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
@@ -276,15 +266,19 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
- Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
+ Op::Unary(_, UnaryOp::Gelu) => {
+ return Err(Error::BackwardNotSupported { op: "gelu" })
+ }
+ Op::Unary(_, UnaryOp::Relu) => {
+ return Err(Error::BackwardNotSupported { op: "relu" })
+ }
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
- Op::Sqr(arg) => {
+ Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Sqrt(arg) => {
+ Op::Unary(arg, UnaryOp::Sqrt) => {
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index b12e0702..1ce1d73a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
@@ -1158,7 +1158,7 @@ impl BackendStorage for CpuStorage {
}
}
- fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+ fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self {
Self::BF16(storage) => {
if B::BF16_VEC {
@@ -1207,7 +1207,12 @@ impl BackendStorage for CpuStorage {
}
}
- fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+ fn binary_impl<B: BinaryOpT>(
+ &self,
+ rhs: &Self,
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ ) -> Result<Self> {
match (self, rhs) {
(Self::BF16(lhs), Self::BF16(rhs)) => {
let data = if B::BF16_VEC {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9e47c133..c9bb1fba 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -573,7 +573,7 @@ impl<'a> Map1 for FastReduce<'a> {
}
}
-impl<U: crate::op::UnaryOp> Map1 for U {
+impl<U: UnaryOpT> Map1 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@@ -716,7 +716,7 @@ impl<'a> Map2 for WhereCond<'a> {
}
}
-impl<U: crate::op::BinaryOp> Map2 for U {
+impl<U: crate::op::BinaryOpT> Map2 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
@@ -976,13 +976,13 @@ impl BackendStorage for CudaStorage {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
- fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+ fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
- fn binary_impl<B: crate::op::BinaryOp>(
+ fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 942e82ed..8f1a9916 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,5 +1,5 @@
#![allow(dead_code)]
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
@@ -57,16 +57,11 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
- fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+ fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- fn binary_impl<B: crate::op::BinaryOp>(
- &self,
- _: &Self,
- _: &Layout,
- _: &Layout,
- ) -> Result<Self> {
+ fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ece6969c..3dd3de68 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -19,12 +19,34 @@ pub enum ReduceOp {
Max,
}
+// These ops return the same type as their input type.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BinaryOp {
+ Add,
+ Mul,
+ Sub,
+ Div,
+}
+
+// Unary ops with no argument
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum UnaryOp {
+ Exp,
+ Log,
+ Sin,
+ Cos,
+ Abs,
+ Neg,
+ Sqr,
+ Sqrt,
+ Gelu,
+ Relu,
+}
+
#[derive(Clone)]
pub(crate) enum Op {
- Add(Tensor, Tensor),
- Mul(Tensor, Tensor),
- Sub(Tensor, Tensor),
- Div(Tensor, Tensor),
+ Binary(Tensor, Tensor, BinaryOp),
+ Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
@@ -49,26 +71,16 @@ pub(crate) enum Op {
},
ToDType(Tensor),
Broadcast(Tensor),
- Exp(Tensor),
- Log(Tensor),
- Sin(Tensor),
- Cos(Tensor),
- Abs(Tensor),
Narrow(Tensor, usize, usize, usize),
- Neg(Tensor),
Reshape(Tensor),
Softmax(Tensor, usize),
- Sqr(Tensor),
- Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
- Gelu(Tensor),
- Relu(Tensor),
Elu(Tensor, f64),
// TODO: Support for custom ops.
}
-pub(crate) trait UnaryOp {
+pub(crate) trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -91,7 +103,7 @@ pub(crate) trait UnaryOp {
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
-pub(crate) trait BinaryOp {
+pub(crate) trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -133,7 +145,7 @@ pub(crate) struct Relu;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
- impl BinaryOp for $op {
+ impl BinaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
@@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
- impl UnaryOp for $op {
+ impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -219,7 +231,7 @@ macro_rules! unary_op {
};
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
- impl UnaryOp for $op {
+ impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
-impl UnaryOp for Gelu {
+impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
@@ -343,7 +355,7 @@ impl UnaryOp for Gelu {
}
}
-impl UnaryOp for Relu {
+impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index fb72322c..30232ba0 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -147,7 +147,7 @@ impl Storage {
}
}
- pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+ pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
@@ -161,7 +161,7 @@ impl Storage {
}
}
- pub(crate) fn binary_impl<B: op::BinaryOp>(
+ pub(crate) fn binary_impl<B: op::BinaryOpT>(
&self,
rhs: &Self,
lhs_layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d6c3e9cb..087a2ff5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{CmpOp, Op, ReduceOp};
+use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -80,7 +80,7 @@ macro_rules! unary_op {
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
- Some(Op::$op_name(self.clone()))
+ Some(Op::Unary(self.clone(), UnaryOp::$op_name))
} else {
None
};
@@ -99,7 +99,7 @@ macro_rules! binary_op {
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
- Some(Op::$op_name(self.clone(), rhs.clone()))
+ Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
} else {
None
};