#![allow(clippy::redundant_closure_call)] use crate::Tensor; use half::{bf16, f16}; use num_traits::float::Float; #[derive(Clone, Copy, PartialEq, Eq)] pub enum CmpOp { Eq, Ne, Le, Ge, Lt, Gt, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ReduceOp { Sum, Min, Max, ArgMin, ArgMax, } impl ReduceOp { pub(crate) fn name(&self) -> &'static str { match self { Self::ArgMax => "argmax", Self::ArgMin => "argmin", Self::Min => "min", Self::Max => "max", Self::Sum => "sum", } } } // These ops return the same type as their input type. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BinaryOp { Add, Mul, Sub, Div, Maximum, Minimum, } // Unary ops with no argument #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum UnaryOp { Exp, Log, Sin, Cos, Abs, Neg, Recip, Sqr, Sqrt, Gelu, GeluErf, Erf, Relu, Silu, Tanh, Floor, Ceil, Round, Sign, } #[derive(Clone)] pub enum Op { Binary(Tensor, Tensor, BinaryOp), Unary(Tensor, UnaryOp), Cmp(Tensor, CmpOp), // The third argument is the reduced shape with `keepdim=true`. Reduce(Tensor, ReduceOp, Vec), Matmul(Tensor, Tensor), Gather(Tensor, Tensor, usize), ScatterAdd(Tensor, Tensor, Tensor, usize), IndexSelect(Tensor, Tensor, usize), IndexAdd(Tensor, Tensor, Tensor, usize), WhereCond(Tensor, Tensor, Tensor), #[allow(dead_code)] Conv1D { arg: Tensor, kernel: Tensor, padding: usize, stride: usize, dilation: usize, }, #[allow(dead_code)] ConvTranspose1D { arg: Tensor, kernel: Tensor, padding: usize, output_padding: usize, stride: usize, dilation: usize, }, #[allow(dead_code)] Conv2D { arg: Tensor, kernel: Tensor, padding: usize, stride: usize, dilation: usize, }, #[allow(dead_code)] ConvTranspose2D { arg: Tensor, kernel: Tensor, padding: usize, output_padding: usize, stride: usize, dilation: usize, }, AvgPool2D { arg: Tensor, kernel_size: (usize, usize), stride: (usize, usize), }, MaxPool2D { arg: Tensor, kernel_size: (usize, usize), stride: (usize, usize), }, UpsampleNearest1D { arg: Tensor, target_size: usize, }, UpsampleNearest2D { arg: Tensor, target_h: usize, target_w: usize, }, Cat(Vec, usize), #[allow(dead_code)] // add is currently unused. Affine { arg: Tensor, mul: f64, add: f64, }, ToDType(Tensor), Copy(Tensor), Broadcast(Tensor), Narrow(Tensor, usize, usize, usize), SliceScatter0(Tensor, Tensor, usize), Reshape(Tensor), ToDevice(Tensor), Transpose(Tensor, usize, usize), Permute(Tensor, Vec), Elu(Tensor, f64), Powf(Tensor, f64), CustomOp1( Tensor, std::sync::Arc>, ), CustomOp2( Tensor, Tensor, std::sync::Arc>, ), CustomOp3( Tensor, Tensor, Tensor, std::sync::Arc>, ), } pub trait UnaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; fn bf16(v1: bf16) -> bf16; fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; fn i64(v1: i64) -> i64; // There is no very good way to represent optional function in traits so we go for an explicit // boolean flag to mark the function as existing. const BF16_VEC: bool = false; fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } pub trait BinaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; fn bf16(v1: bf16, v2: bf16) -> bf16; fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; fn i64(v1: i64, v2: i64) -> i64; const BF16_VEC: bool = false; fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {} const F32_VEC: bool = false; fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} const I64_VEC: bool = false; fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} } pub(crate) struct Add; pub(crate) struct Div; pub(crate) struct Mul; pub(crate) struct Sub; pub(crate) struct Maximum; pub(crate) struct Minimum; pub(crate) struct Exp; pub(crate) struct Log; pub(crate) struct Sin; pub(crate) struct Cos; pub(crate) struct Abs; pub(crate) struct Neg; pub(crate) struct Recip; pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; pub(crate) struct GeluErf; pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Silu; pub(crate) struct Tanh; pub(crate) struct Floor; pub(crate) struct Ceil; pub(crate) struct Round; pub(crate) struct Sign; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { impl BinaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("b", $name); const V: Self = $op; #[inline(always)] fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) } #[inline(always)] fn f16(v1: f16, v2: f16) -> f16 { $e(v1, v2) } #[inline(always)] fn f32(v1: f32, v2: f32) -> f32 { $e(v1, v2) } #[inline(always)] fn f64(v1: f64, v2: f64) -> f64 { $e(v1, v2) } #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } #[inline(always)] fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) } #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } #[cfg(feature = "mkl")] const F32_VEC: bool = true; #[cfg(feature = "mkl")] const F64_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) { crate::mkl::$f32_vec(xs1, xs2, ys) } #[cfg(feature = "mkl")] #[inline(always)] fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) { crate::mkl::$f64_vec(xs1, xs2, ys) } #[cfg(feature = "accelerate")] const F32_VEC: bool = true; #[cfg(feature = "accelerate")] const F64_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) { crate::accelerate::$f32_vec(xs1, xs2, ys) } #[cfg(feature = "accelerate")] #[inline(always)] fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) { crate::accelerate::$f64_vec(xs1, xs2, ys) } } }; } bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add); bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub); bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul); bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div); bin_op!( Minimum, "minimum", |v1, v2| if v1 > v2 { v2 } else { v1 }, vs_min, vd_min ); bin_op!( Maximum, "maximum", |v1, v2| if v1 < v2 { v2 } else { v1 }, vs_max, vd_max ); #[allow(clippy::redundant_closure_call)] macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op; #[inline(always)] fn bf16($a: bf16) -> bf16 { $e } #[inline(always)] fn f16($a: f16) -> f16 { $e } #[inline(always)] fn f32($a: f32) -> f32 { $e } #[inline(always)] fn f64($a: f64) -> f64 { $e } #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } #[inline(always)] fn u32(_: u32) -> u32 { todo!("no unary function for u32") } #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } } }; ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => { impl UnaryOpT for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op; #[inline(always)] fn bf16($a: bf16) -> bf16 { $e } #[inline(always)] fn f16($a: f16) -> f16 { $e } #[inline(always)] fn f32($a: f32) -> f32 { $e } #[inline(always)] fn f64($a: f64) -> f64 { $e } #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } #[inline(always)] fn u32(_: u32) -> u32 { todo!("no unary function for u32") } #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } #[cfg(feature = "mkl")] const F32_VEC: bool = true; #[cfg(feature = "mkl")] const F64_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::mkl::$f32_vec(xs, ys) } #[cfg(feature = "mkl")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::mkl::$f64_vec(xs, ys) } #[cfg(feature = "accelerate")] const F32_VEC: bool = true; #[cfg(feature = "accelerate")] const F64_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::accelerate::$f32_vec(xs, ys) } #[cfg(feature = "accelerate")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::accelerate::$f64_vec(xs, ys) } } }; } unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp); unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); unary_op!(Neg, "neg", v, -v); unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); // Hardcode the value for sqrt(2/pi) // https://github.com/huggingface/candle/issues/1982 #[allow(clippy::excessive_precision)] const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373; #[allow(clippy::excessive_precision)] const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373; /// Tanh based approximation of the `gelu` operation /// GeluErf is the more precise one. /// impl UnaryOpT for Gelu { const NAME: &'static str = "gelu"; const V: Self = Gelu; #[inline(always)] fn bf16(v: bf16) -> bf16 { bf16::from_f32_const(0.5) * v * (bf16::ONE + bf16::tanh( bf16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), )) } #[inline(always)] fn f16(v: f16) -> f16 { f16::from_f32_const(0.5) * v * (f16::ONE + f16::tanh( f16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (f16::ONE + f16::from_f32_const(0.044715) * v * v), )) } #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn f64(v: f64) -> f64 { 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn u8(_: u8) -> u8 { 0 } #[inline(always)] fn u32(_: u32) -> u32 { 0 } #[inline(always)] fn i64(_: i64) -> i64 { 0 } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] const F32_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::mkl::vs_gelu(xs, ys) } #[cfg(feature = "mkl")] const F64_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::mkl::vd_gelu(xs, ys) } #[cfg(feature = "accelerate")] const F32_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::accelerate::vs_gelu(xs, ys) } #[cfg(feature = "accelerate")] const F64_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::accelerate::vd_gelu(xs, ys) } } /// `erf` operation /// impl UnaryOpT for Erf { const NAME: &'static str = "erf"; const KERNEL: &'static str = "uerf"; const V: Self = Erf; #[inline(always)] fn bf16(v: bf16) -> bf16 { bf16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] fn f16(v: f16) -> f16 { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } #[inline(always)] fn f64(v: f64) -> f64 { crate::cpu::erf::erf(v) } #[inline(always)] fn u8(_: u8) -> u8 { 0 } #[inline(always)] fn u32(_: u32) -> u32 { 0 } #[inline(always)] fn i64(_: i64) -> i64 { 0 } } /// Silu operation impl UnaryOpT for Silu { const NAME: &'static str = "silu"; const V: Self = Silu; #[inline(always)] fn bf16(v: bf16) -> bf16 { v / (bf16::ONE + (-v).exp()) } #[inline(always)] fn f16(v: f16) -> f16 { v / (f16::ONE + (-v).exp()) } #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } #[inline(always)] fn f64(v: f64) -> f64 { v / (1.0 + (-v).exp()) } #[inline(always)] fn u8(_: u8) -> u8 { 0 } #[inline(always)] fn u32(_: u32) -> u32 { 0 } #[inline(always)] fn i64(_: i64) -> i64 { 0 } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] const F32_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::mkl::vs_silu(xs, ys) } #[cfg(feature = "mkl")] const F64_VEC: bool = true; #[cfg(feature = "mkl")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::mkl::vd_silu(xs, ys) } #[cfg(feature = "accelerate")] const F32_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f32_vec(xs: &[f32], ys: &mut [f32]) { crate::accelerate::vs_silu(xs, ys) } #[cfg(feature = "accelerate")] const F64_VEC: bool = true; #[cfg(feature = "accelerate")] #[inline(always)] fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::accelerate::vd_silu(xs, ys) } } impl UnaryOpT for Abs { const NAME: &'static str = "abs"; const KERNEL: &'static str = "uabs"; const V: Self = Abs; #[inline(always)] fn bf16(v: bf16) -> bf16 { v.abs() } #[inline(always)] fn f16(v: f16) -> f16 { v.abs() } #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } #[inline(always)] fn f64(v: f64) -> f64 { v.abs() } #[inline(always)] fn u8(v: u8) -> u8 { v } #[inline(always)] fn u32(v: u32) -> u32 { v } #[inline(always)] fn i64(v: i64) -> i64 { v.abs() } } impl UnaryOpT for Ceil { const NAME: &'static str = "ceil"; const KERNEL: &'static str = "uceil"; const V: Self = Ceil; #[inline(always)] fn bf16(v: bf16) -> bf16 { v.ceil() } #[inline(always)] fn f16(v: f16) -> f16 { v.ceil() } #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } #[inline(always)] fn f64(v: f64) -> f64 { v.ceil() } #[inline(always)] fn u8(v: u8) -> u8 { v } #[inline(always)] fn u32(v: u32) -> u32 { v } #[inline(always)] fn i64(v: i64) -> i64 { v } } impl UnaryOpT for Floor { const NAME: &'static str = "floor"; const KERNEL: &'static str = "ufloor"; const V: Self = Floor; #[inline(always)] fn bf16(v: bf16) -> bf16 { v.floor() } #[inline(always)] fn f16(v: f16) -> f16 { v.floor() } #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } #[inline(always)] fn f64(v: f64) -> f64 { v.floor() } #[inline(always)] fn u8(v: u8) -> u8 { v } #[inline(always)] fn u32(v: u32) -> u32 { v } #[inline(always)] fn i64(v: i64) -> i64 { v } } impl UnaryOpT for Round { const NAME: &'static str = "round"; const KERNEL: &'static str = "uround"; const V: Self = Round; #[inline(always)] fn bf16(v: bf16) -> bf16 { v.round() } #[inline(always)] fn f16(v: f16) -> f16 { v.round() } #[inline(always)] fn f32(v: f32) -> f32 { v.round() } #[inline(always)] fn f64(v: f64) -> f64 { v.round() } #[inline(always)] fn u8(v: u8) -> u8 { v } #[inline(always)] fn u32(v: u32) -> u32 { v } #[inline(always)] fn i64(v: i64) -> i64 { v } } impl UnaryOpT for GeluErf { const NAME: &'static str = "gelu_erf"; const KERNEL: &'static str = "ugelu_erf"; const V: Self = GeluErf; #[inline(always)] fn bf16(v: bf16) -> bf16 { bf16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] fn f16(v: f16) -> f16 { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } #[inline(always)] fn f64(v: f64) -> f64 { (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v } #[inline(always)] fn u8(_: u8) -> u8 { 0 } #[inline(always)] fn u32(_: u32) -> u32 { 0 } #[inline(always)] fn i64(_: i64) -> i64 { 0 } } impl UnaryOpT for Relu { const NAME: &'static str = "relu"; const KERNEL: &'static str = "urelu"; const V: Self = Relu; #[inline(always)] fn bf16(v: bf16) -> bf16 { v.max(bf16::ZERO) } #[inline(always)] fn f16(v: f16) -> f16 { v.max(f16::ZERO) } #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } #[inline(always)] fn f64(v: f64) -> f64 { v.max(0f64) } #[inline(always)] fn u8(v: u8) -> u8 { v } #[inline(always)] fn u32(v: u32) -> u32 { v } #[inline(always)] fn i64(v: i64) -> i64 { v } } /// `BackpropOp` is a wrapper around `Option`. The main goal is to ensure that dependencies are /// properly checked when creating a new value #[derive(Clone)] pub struct BackpropOp(Option); impl BackpropOp { pub(crate) fn none() -> Self { BackpropOp(None) } pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self { let op = if arg.track_op() { Some(f(arg.clone())) } else { None }; Self(op) } pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self { let op = if arg1.track_op() || arg2.track_op() { Some(f(arg1.clone(), arg2.clone())) } else { None }; Self(op) } pub(crate) fn new3( arg1: &Tensor, arg2: &Tensor, arg3: &Tensor, f: impl Fn(Tensor, Tensor, Tensor) -> Op, ) -> Self { let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() { Some(f(arg1.clone(), arg2.clone(), arg3.clone())) } else { None }; Self(op) } pub(crate) fn new>(args: &[A], f: impl Fn(Vec) -> Op) -> Self { let op = if args.iter().any(|arg| arg.as_ref().track_op()) { let args: Vec = args.iter().map(|arg| arg.as_ref().clone()).collect(); Some(f(args)) } else { None }; Self(op) } pub(crate) fn is_none(&self) -> bool { self.0.is_none() } } impl std::ops::Deref for BackpropOp { type Target = Option; fn deref(&self) -> &Self::Target { &self.0 } } impl UnaryOpT for Sign { const NAME: &'static str = "sign"; const KERNEL: &'static str = "usign"; const V: Self = Sign; #[inline(always)] fn bf16(v: bf16) -> bf16 { bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8) } #[inline(always)] fn f16(v: f16) -> f16 { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) } #[inline(always)] fn f64(v: f64) -> f64 { f64::from(v > 0.) - f64::from(v < 0.) } #[inline(always)] fn u8(v: u8) -> u8 { u8::min(1, v) } #[inline(always)] fn u32(v: u32) -> u32 { u32::min(1, v) } #[inline(always)] fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } }