Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/accelerate.rs | 32
-rw-r--r-- | candle-core/src/backend.rs | 1
-rw-r--r-- | candle-core/src/backprop.rs | 12
-rw-r--r-- | candle-core/src/cpu/erf.rs | 763
-rw-r--r-- | candle-core/src/cpu/kernels.rs | 95
-rw-r--r-- | candle-core/src/cpu/mod.rs | 1
-rw-r--r-- | candle-core/src/cpu_backend.rs | 265
-rw-r--r-- | candle-core/src/cuda_backend.rs | 223
-rw-r--r-- | candle-core/src/cudnn.rs | 6
-rw-r--r-- | candle-core/src/dtype.rs | 11
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 4
-rw-r--r-- | candle-core/src/error.rs | 6
-rw-r--r-- | candle-core/src/indexer.rs | 39
-rw-r--r-- | candle-core/src/lib.rs | 15
-rw-r--r-- | candle-core/src/op.rs | 91
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 8
-rw-r--r-- | candle-core/src/quantized/mod.rs | 2
-rw-r--r-- | candle-core/src/safetensors.rs | 8
-rw-r--r-- | candle-core/src/scalar.rs | 23
-rw-r--r-- | candle-core/src/shape.rs | 205
-rw-r--r-- | candle-core/src/storage.rs | 13
-rw-r--r-- | candle-core/src/tensor.rs | 133
22 files changed, 1871 insertions, 85 deletions
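A note on the GELU changes in the diff below: the new vs_gelu / vd_gelu helpers in accelerate.rs (and the Gelu unary op's vectorized path) use the tanh-based GELU approximation, with the tanh itself vectorized through vvtanhf / vvtanh, while the new GeluErf op computes the exact form 0.5 * x * (1 + erf(x / sqrt(2))) via the erf implementation added in cpu/erf.rs. For reference only, a minimal scalar sketch of the approximation in plain Rust (not the vectorized FFI code in the diff):

    // Tanh-based GELU approximation, matching the per-element math in vs_gelu/vd_gelu:
    //   gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    fn gelu_tanh(x: f32) -> f32 {
        let c = (2.0f32 / std::f32::consts::PI).sqrt();
        0.5 * x * (1.0 + (c * x * (1.0 + 0.044715 * x * x)).tanh())
    }

    fn main() {
        // gelu(0) = 0; gelu(x) approaches x for large positive x and 0 for large negative x.
        for x in [-3.0f32, -1.0, 0.0, 1.0, 3.0] {
            println!("gelu({x}) ≈ {}", gelu_tanh(x));
        }
    }

The GeluErf variant added in op.rs instead evaluates (erf(v / sqrt(2)) + 1) * 0.5 * v using the polynomial erf added in cpu/erf.rs, which is exact rather than approximate.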
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs index 87e0ee8d..1cb34e19 100644 --- a/candle-core/src/accelerate.rs +++ b/candle-core/src/accelerate.rs @@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) { y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) } +#[inline] +pub fn vs_tanh_inplace(y: &mut [f32]) { + unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vd_tanh_inplace(y: &mut [f64]) { + unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vs_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vd_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + macro_rules! binary_op { ($fn_name:ident, $ty:ty, $accelerate_name:ident) => { #[inline] diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 67a08714..03a07434 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -57,6 +57,7 @@ pub trait BackendStorage: Sized { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index d2099df7..a2548198 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -91,13 +91,14 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest1D(node) | Op::UpsampleNearest2D(node) | Op::AvgPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) - | Op::Reduce(node, _, _) + | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -111,6 +112,7 @@ impl Tensor { track_grad |= tg; nodes } + Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes, } } else { nodes @@ -262,6 +264,9 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; } + Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest1d", + })?, Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, @@ -437,6 +442,10 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, + Op::Unary(_, UnaryOp::GeluErf) => { + Err(Error::BackwardNotSupported { op: "gelu-erf" })? 
+ } Op::Unary(arg, UnaryOp::Relu) => { let sum_grad = grads.or_insert(arg)?; let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; @@ -517,6 +526,7 @@ impl Tensor { } } +#[derive(Debug)] pub struct GradStore(HashMap<TensorId, Tensor>); impl GradStore { diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs new file mode 100644 index 00000000..ca6be53f --- /dev/null +++ b/candle-core/src/cpu/erf.rs @@ -0,0 +1,763 @@ +#![allow(clippy::excessive_precision)] +// Code taken from https://github.com/statrs-dev/statrs +//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and +//! related functions + +mod evaluate { + //! Provides functions that don't have a numerical solution and must + //! be solved computationally (e.g. evaluation of a polynomial) + + /// evaluates a polynomial at `z` where `coeff` are the coeffecients + /// to a polynomial of order `k` where `k` is the length of `coeff` and the + /// coeffecient + /// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to + /// `2z^2 - z + 3` + /// + /// # Remarks + /// + /// Returns 0 for a 0 length coefficient slice + pub fn polynomial(z: f64, coeff: &[f64]) -> f64 { + let n = coeff.len(); + if n == 0 { + return 0.0; + } + + let mut sum = *coeff.last().unwrap(); + for c in coeff[0..n - 1].iter().rev() { + sum = *c + z * sum; + } + sum + } +} +use std::f64; + +/// `erf` calculates the error function at `x`. +pub fn erf(x: f64) -> f64 { + if x.is_nan() { + f64::NAN + } else if x >= 0.0 && x.is_infinite() { + 1.0 + } else if x <= 0.0 && x.is_infinite() { + -1.0 + } else if x == 0. { + 0.0 + } else { + erf_impl(x, false) + } +} + +/// `erf_inv` calculates the inverse error function +/// at `x`. +pub fn erf_inv(x: f64) -> f64 { + if x == 0.0 { + 0.0 + } else if x >= 1.0 { + f64::INFINITY + } else if x <= -1.0 { + f64::NEG_INFINITY + } else if x < 0.0 { + erf_inv_impl(-x, 1.0 + x, -1.0) + } else { + erf_inv_impl(x, 1.0 - x, 1.0) + } +} + +/// `erfc` calculates the complementary error function +/// at `x`. +pub fn erfc(x: f64) -> f64 { + if x.is_nan() { + f64::NAN + } else if x == f64::INFINITY { + 0.0 + } else if x == f64::NEG_INFINITY { + 2.0 + } else { + erf_impl(x, true) + } +} + +/// `erfc_inv` calculates the complementary inverse +/// error function at `x`. +pub fn erfc_inv(x: f64) -> f64 { + if x <= 0.0 { + f64::INFINITY + } else if x >= 2.0 { + f64::NEG_INFINITY + } else if x > 1.0 { + erf_inv_impl(-1.0 + x, 2.0 - x, -1.0) + } else { + erf_inv_impl(1.0 - x, x, 1.0) + } +} + +// ********************************************************** +// ********** Coefficients for erf_impl polynomial ********** +// ********************************************************** + +/// Polynomial coefficients for a numerator of `erf_impl` +/// in the interval [1e-10, 0.5]. 
+const ERF_IMPL_AN: &[f64] = &[ + 0.00337916709551257388990745, + -0.00073695653048167948530905, + -0.374732337392919607868241, + 0.0817442448733587196071743, + -0.0421089319936548595203468, + 0.0070165709512095756344528, + -0.00495091255982435110337458, + 0.000871646599037922480317225, +]; + +/// Polynomial coefficients for a denominator of `erf_impl` +/// in the interval [1e-10, 0.5] +const ERF_IMPL_AD: &[f64] = &[ + 1.0, + -0.218088218087924645390535, + 0.412542972725442099083918, + -0.0841891147873106755410271, + 0.0655338856400241519690695, + -0.0120019604454941768171266, + 0.00408165558926174048329689, + -0.000615900721557769691924509, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.5, 0.75]. +const ERF_IMPL_BN: &[f64] = &[ + -0.0361790390718262471360258, + 0.292251883444882683221149, + 0.281447041797604512774415, + 0.125610208862766947294894, + 0.0274135028268930549240776, + 0.00250839672168065762786937, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.5, 0.75]. +const ERF_IMPL_BD: &[f64] = &[ + 1.0, + 1.8545005897903486499845, + 1.43575803037831418074962, + 0.582827658753036572454135, + 0.124810476932949746447682, + 0.0113724176546353285778481, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CN: &[f64] = &[ + -0.0397876892611136856954425, + 0.153165212467878293257683, + 0.191260295600936245503129, + 0.10276327061989304213645, + 0.029637090615738836726027, + 0.0046093486780275489468812, + 0.000307607820348680180548455, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CD: &[f64] = &[ + 1.0, + 1.95520072987627704987886, + 1.64762317199384860109595, + 0.768238607022126250082483, + 0.209793185936509782784315, + 0.0319569316899913392596356, + 0.00213363160895785378615014, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DN: &[f64] = &[ + -0.0300838560557949717328341, + 0.0538578829844454508530552, + 0.0726211541651914182692959, + 0.0367628469888049348429018, + 0.00964629015572527529605267, + 0.00133453480075291076745275, + 0.778087599782504251917881e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DD: &[f64] = &[ + 1.0, + 1.75967098147167528287343, + 1.32883571437961120556307, + 0.552528596508757581287907, + 0.133793056941332861912279, + 0.0179509645176280768640766, + 0.00104712440019937356634038, + -0.106640381820357337177643e-7, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_EN: &[f64] = &[ + -0.0117907570137227847827732, + 0.014262132090538809896674, + 0.0202234435902960820020765, + 0.00930668299990432009042239, + 0.00213357802422065994322516, + 0.00025022987386460102395382, + 0.120534912219588189822126e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_ED: &[f64] = &[ + 1.0, + 1.50376225203620482047419, + 0.965397786204462896346934, + 0.339265230476796681555511, + 0.0689740649541569716897427, + 0.00771060262491768307365526, + 0.000371421101531069302990367, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [3.5, 5.25]. 
+const ERF_IMPL_FN: &[f64] = &[ + -0.00546954795538729307482955, + 0.00404190278731707110245394, + 0.0054963369553161170521356, + 0.00212616472603945399437862, + 0.000394984014495083900689956, + 0.365565477064442377259271e-4, + 0.135485897109932323253786e-5, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [3.5, 5.25]. +const ERF_IMPL_FD: &[f64] = &[ + 1.0, + 1.21019697773630784832251, + 0.620914668221143886601045, + 0.173038430661142762569515, + 0.0276550813773432047594539, + 0.00240625974424309709745382, + 0.891811817251336577241006e-4, + -0.465528836283382684461025e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [5.25, 8]. +const ERF_IMPL_GN: &[f64] = &[ + -0.00270722535905778347999196, + 0.0013187563425029400461378, + 0.00119925933261002333923989, + 0.00027849619811344664248235, + 0.267822988218331849989363e-4, + 0.923043672315028197865066e-6, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [5.25, 8]. +const ERF_IMPL_GD: &[f64] = &[ + 1.0, + 0.814632808543141591118279, + 0.268901665856299542168425, + 0.0449877216103041118694989, + 0.00381759663320248459168994, + 0.000131571897888596914350697, + 0.404815359675764138445257e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HN: &[f64] = &[ + -0.00109946720691742196814323, + 0.000406425442750422675169153, + 0.000274499489416900707787024, + 0.465293770646659383436343e-4, + 0.320955425395767463401993e-5, + 0.778286018145020892261936e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HD: &[f64] = &[ + 1.0, + 0.588173710611846046373373, + 0.139363331289409746077541, + 0.0166329340417083678763028, + 0.00100023921310234908642639, + 0.24254837521587225125068e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_IN: &[f64] = &[ + -0.00056907993601094962855594, + 0.000169498540373762264416984, + 0.518472354581100890120501e-4, + 0.382819312231928859704678e-5, + 0.824989931281894431781794e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_ID: &[f64] = &[ + 1.0, + 0.339637250051139347430323, + 0.043472647870310663055044, + 0.00248549335224637114641629, + 0.535633305337152900549536e-4, + -0.117490944405459578783846e-12, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JN: &[f64] = &[ + -0.000241313599483991337479091, + 0.574224975202501512365975e-4, + 0.115998962927383778460557e-4, + 0.581762134402593739370875e-6, + 0.853971555085673614607418e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JD: &[f64] = &[ + 1.0, + 0.233044138299687841018015, + 0.0204186940546440312625597, + 0.000797185647564398289151125, + 0.117019281670172327758019e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [24, 38]. +const ERF_IMPL_KN: &[f64] = &[ + -0.000146674699277760365803642, + 0.162666552112280519955647e-4, + 0.269116248509165239294897e-5, + 0.979584479468091935086972e-7, + 0.101994647625723465722285e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [24, 38]. 
+const ERF_IMPL_KD: &[f64] = &[ + 1.0, + 0.165907812944847226546036, + 0.0103361716191505884359634, + 0.000286593026373868366935721, + 0.298401570840900340874568e-5, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LN: &[f64] = &[ + -0.583905797629771786720406e-4, + 0.412510325105496173512992e-5, + 0.431790922420250949096906e-6, + 0.993365155590013193345569e-8, + 0.653480510020104699270084e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LD: &[f64] = &[ + 1.0, + 0.105077086072039915406159, + 0.00414278428675475620830226, + 0.726338754644523769144108e-4, + 0.477818471047398785369849e-6, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [60, 85]. +const ERF_IMPL_MN: &[f64] = &[ + -0.196457797609229579459841e-4, + 0.157243887666800692441195e-5, + 0.543902511192700878690335e-7, + 0.317472492369117710852685e-9, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [60, 85]. +const ERF_IMPL_MD: &[f64] = &[ + 1.0, + 0.052803989240957632204885, + 0.000926876069151753290378112, + 0.541011723226630257077328e-5, + 0.535093845803642394908747e-15, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_NN: &[f64] = &[ + -0.789224703978722689089794e-5, + 0.622088451660986955124162e-6, + 0.145728445676882396797184e-7, + 0.603715505542715364529243e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_ND: &[f64] = &[ + 1.0, + 0.0375328846356293715248719, + 0.000467919535974625308126054, + 0.193847039275845656900547e-5, +]; + +// ********************************************************** +// ********** Coefficients for erf_inv_impl polynomial ****** +// ********************************************************** + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AN: &[f64] = &[ + -0.000508781949658280665617, + -0.00836874819741736770379, + 0.0334806625409744615033, + -0.0126926147662974029034, + -0.0365637971411762664006, + 0.0219878681111168899165, + 0.00822687874676915743155, + -0.00538772965071242932965, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AD: &[f64] = &[ + 1.0, + -0.970005043303290640362, + -1.56574558234175846809, + 1.56221558398423026363, + 0.662328840472002992063, + -0.71228902341542847553, + -0.0527396382340099713954, + 0.0795283687341571680018, + -0.00233393759374190016776, + 0.000886216390456424707504, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. +const ERF_INV_IMPL_BN: &[f64] = &[ + -0.202433508355938759655, + 0.105264680699391713268, + 8.37050328343119927838, + 17.6447298408374015486, + -18.8510648058714251895, + -44.6382324441786960818, + 17.445385985570866523, + 21.1294655448340526258, + -3.67192254707729348546, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. 
+const ERF_INV_IMPL_BD: &[f64] = &[ + 1.0, + 6.24264124854247537712, + 3.9713437953343869095, + -28.6608180499800029974, + -20.1432634680485188801, + 48.5609213108739935468, + 10.8268667355460159008, + -22.6436933413139721736, + 1.72114765761200282724, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. +const ERF_INV_IMPL_CN: &[f64] = &[ + -0.131102781679951906451, + -0.163794047193317060787, + 0.117030156341995252019, + 0.387079738972604337464, + 0.337785538912035898924, + 0.142869534408157156766, + 0.0290157910005329060432, + 0.00214558995388805277169, + -0.679465575181126350155e-6, + 0.285225331782217055858e-7, + -0.681149956853776992068e-9, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. +const ERF_INV_IMPL_CD: &[f64] = &[ + 1.0, + 3.46625407242567245975, + 5.38168345707006855425, + 4.77846592945843778382, + 2.59301921623620271374, + 0.848854343457902036425, + 0.152264338295331783612, + 0.01105924229346489121, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DN: &[f64] = &[ + -0.0350353787183177984712, + -0.00222426529213447927281, + 0.0185573306514231072324, + 0.00950804701325919603619, + 0.00187123492819559223345, + 0.000157544617424960554631, + 0.460469890584317994083e-5, + -0.230404776911882601748e-9, + 0.266339227425782031962e-11, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DD: &[f64] = &[ + 1.0, + 1.3653349817554063097, + 0.762059164553623404043, + 0.220091105764131249824, + 0.0341589143670947727934, + 0.00263861676657015992959, + 0.764675292302794483503e-4, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_EN: &[f64] = &[ + -0.0167431005076633737133, + -0.00112951438745580278863, + 0.00105628862152492910091, + 0.000209386317487588078668, + 0.149624783758342370182e-4, + 0.449696789927706453732e-6, + 0.462596163522878599135e-8, + -0.281128735628831791805e-13, + 0.99055709973310326855e-16, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_ED: &[f64] = &[ + 1.0, + 0.591429344886417493481, + 0.138151865749083321638, + 0.0160746087093676504695, + 0.000964011807005165528527, + 0.275335474764726041141e-4, + 0.282243172016108031869e-6, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FN: &[f64] = &[ + -0.0024978212791898131227, + -0.779190719229053954292e-5, + 0.254723037413027451751e-4, + 0.162397777342510920873e-5, + 0.396341011304801168516e-7, + 0.411632831190944208473e-9, + 0.145596286718675035587e-11, + -0.116765012397184275695e-17, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FD: &[f64] = &[ + 1.0, + 0.207123112214422517181, + 0.0169410838120975906478, + 0.000690538265622684595676, + 0.145007359818232637924e-4, + 0.144437756628144157666e-6, + 0.509761276599778486139e-9, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. 
+const ERF_INV_IMPL_GN: &[f64] = &[ + -0.000539042911019078575891, + -0.28398759004727721098e-6, + 0.899465114892291446442e-6, + 0.229345859265920864296e-7, + 0.225561444863500149219e-9, + 0.947846627503022684216e-12, + 0.135880130108924861008e-14, + -0.348890393399948882918e-21, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. +const ERF_INV_IMPL_GD: &[f64] = &[ + 1.0, + 0.0845746234001899436914, + 0.00282092984726264681981, + 0.468292921940894236786e-4, + 0.399968812193862100054e-6, + 0.161809290887904476097e-8, + 0.231558608310259605225e-11, +]; + +/// `erf_impl` computes the error function at `z`. +/// If `inv` is true, `1 - erf` is calculated as opposed to `erf` +fn erf_impl(z: f64, inv: bool) -> f64 { + if z < 0.0 { + if !inv { + return -erf_impl(-z, false); + } + if z < -0.5 { + return 2.0 - erf_impl(-z, true); + } + return 1.0 + erf_impl(-z, false); + } + + let result = if z < 0.5 { + if z < 1e-10 { + z * 1.125 + z * 0.003379167095512573896158903121545171688 + } else { + z * 1.125 + + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD) + } + } else if z < 110.0 { + let (r, b) = if z < 0.75 { + ( + evaluate::polynomial(z - 0.5, ERF_IMPL_BN) + / evaluate::polynomial(z - 0.5, ERF_IMPL_BD), + 0.3440242112, + ) + } else if z < 1.25 { + ( + evaluate::polynomial(z - 0.75, ERF_IMPL_CN) + / evaluate::polynomial(z - 0.75, ERF_IMPL_CD), + 0.419990927, + ) + } else if z < 2.25 { + ( + evaluate::polynomial(z - 1.25, ERF_IMPL_DN) + / evaluate::polynomial(z - 1.25, ERF_IMPL_DD), + 0.4898625016, + ) + } else if z < 3.5 { + ( + evaluate::polynomial(z - 2.25, ERF_IMPL_EN) + / evaluate::polynomial(z - 2.25, ERF_IMPL_ED), + 0.5317370892, + ) + } else if z < 5.25 { + ( + evaluate::polynomial(z - 3.5, ERF_IMPL_FN) + / evaluate::polynomial(z - 3.5, ERF_IMPL_FD), + 0.5489973426, + ) + } else if z < 8.0 { + ( + evaluate::polynomial(z - 5.25, ERF_IMPL_GN) + / evaluate::polynomial(z - 5.25, ERF_IMPL_GD), + 0.5571740866, + ) + } else if z < 11.5 { + ( + evaluate::polynomial(z - 8.0, ERF_IMPL_HN) + / evaluate::polynomial(z - 8.0, ERF_IMPL_HD), + 0.5609807968, + ) + } else if z < 17.0 { + ( + evaluate::polynomial(z - 11.5, ERF_IMPL_IN) + / evaluate::polynomial(z - 11.5, ERF_IMPL_ID), + 0.5626493692, + ) + } else if z < 24.0 { + ( + evaluate::polynomial(z - 17.0, ERF_IMPL_JN) + / evaluate::polynomial(z - 17.0, ERF_IMPL_JD), + 0.5634598136, + ) + } else if z < 38.0 { + ( + evaluate::polynomial(z - 24.0, ERF_IMPL_KN) + / evaluate::polynomial(z - 24.0, ERF_IMPL_KD), + 0.5638477802, + ) + } else if z < 60.0 { + ( + evaluate::polynomial(z - 38.0, ERF_IMPL_LN) + / evaluate::polynomial(z - 38.0, ERF_IMPL_LD), + 0.5640528202, + ) + } else if z < 85.0 { + ( + evaluate::polynomial(z - 60.0, ERF_IMPL_MN) + / evaluate::polynomial(z - 60.0, ERF_IMPL_MD), + 0.5641309023, + ) + } else { + ( + evaluate::polynomial(z - 85.0, ERF_IMPL_NN) + / evaluate::polynomial(z - 85.0, ERF_IMPL_ND), + 0.5641584396, + ) + }; + let g = (-z * z).exp() / z; + g * b + g * r + } else { + 0.0 + }; + + if inv && z >= 0.5 { + result + } else if z >= 0.5 || inv { + 1.0 - result + } else { + result + } +} + +// `erf_inv_impl` computes the inverse error function where +// `p`,`q`, and `s` are the first, second, and third intermediate +// parameters respectively +fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 { + let result = if p <= 0.5 { + let y = 0.0891314744949340820313; + let g = p * (p + 10.0); + let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / 
evaluate::polynomial(p, ERF_INV_IMPL_AD); + g * y + g * r + } else if q >= 0.25 { + let y = 2.249481201171875; + let g = (-2.0 * q.ln()).sqrt(); + let xs = q - 0.25; + let r = + evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD); + g / (y + r) + } else { + let x = (-q.ln()).sqrt(); + if x < 3.0 { + let y = 0.807220458984375; + let xs = x - 1.125; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN) + / evaluate::polynomial(xs, ERF_INV_IMPL_CD); + y * x + r * x + } else if x < 6.0 { + let y = 0.93995571136474609375; + let xs = x - 3.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN) + / evaluate::polynomial(xs, ERF_INV_IMPL_DD); + y * x + r * x + } else if x < 18.0 { + let y = 0.98362827301025390625; + let xs = x - 6.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN) + / evaluate::polynomial(xs, ERF_INV_IMPL_ED); + y * x + r * x + } else if x < 44.0 { + let y = 0.99714565277099609375; + let xs = x - 18.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN) + / evaluate::polynomial(xs, ERF_INV_IMPL_FD); + y * x + r * x + } else { + let y = 0.99941349029541015625; + let xs = x - 44.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN) + / evaluate::polynomial(xs, ERF_INV_IMPL_GD); + y * x + r * x + } + }; + s * result +} diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 97e195ef..527646d6 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -1,4 +1,7 @@ -pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { +pub trait VecOps: num_traits::NumAssign + Copy { + fn min(self, rhs: Self) -> Self; + fn max(self, rhs: Self) -> Self; + /// Dot-product of two vectors. /// /// # Safety @@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) { *res = *xs; for i in 1..len { - let x = *xs.add(i); - if x > *res { - *res = x - } + *res = (*res).max(*xs.add(i)) } } @@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) { *res = *xs; for i in 1..len { - let x = *xs.add(i); - if x < *res { - *res = x - } + *res = (*res).min(*xs.add(i)) } } } impl VecOps for f32 { #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { super::vec_dot_f32(lhs, rhs, res, len) } @@ -76,6 +83,16 @@ impl VecOps for f32 { impl VecOps for half::f16 { #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { let mut res_f32 = 0f32; super::vec_dot_f16(lhs, rhs, &mut res_f32, len); @@ -83,11 +100,61 @@ impl VecOps for half::f16 { } } -impl VecOps for f64 {} -impl VecOps for half::bf16 {} -impl VecOps for u8 {} -impl VecOps for u32 {} -impl VecOps for i64 {} +impl VecOps for f64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } +} +impl VecOps for half::bf16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn 
max(self, other: Self) -> Self { + Self::max(self, other) + } +} +impl VecOps for u8 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} +impl VecOps for u32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} +impl VecOps for i64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} #[inline(always)] pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index 9a8e6317..50afb30f 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -1,3 +1,4 @@ +pub mod erf; pub mod kernels; trait Cpu<const ARR: usize> { diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index ed3dd3fc..4e808b34 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; +use rayon::prelude::*; + +const USE_IM2COL_CONV1D: bool = true; +const USE_IM2COL_CONV2D: bool = true; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // intercept the oom errors to avoid panicking and provide a proper error. @@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U } // This function maps over two strided index sequences. -fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( +pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], @@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( } // Similar to binary_map but with vectorized variants. -fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>( +pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], @@ -723,6 +727,36 @@ impl Map1 for MaxPool2D { } } +struct UpsampleNearest1D(usize); + +impl Map1 for UpsampleNearest1D { + fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { + // TODO: Specialized implementation for the case 2*sz? 
+ let dst_sz = self.0; + let (b_sz, c, src_sz) = layout.shape().dims3()?; + let stride = layout.stride(); + let stride_sz = stride[2]; + let src_index = layout.start_offset(); + let scale_sz = src_sz as f64 / dst_sz as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_sz]; + let src_idxs = (0..dst_sz) + .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize)) + .collect::<Vec<_>>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_sz..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_sz..]; + let src_index = src_index + c_idx * stride[1]; + for (idx, src_idx) in src_idxs.iter().enumerate() { + dst[idx] = src[src_index + src_idx * stride_sz] + } + } + } + Ok(dst) + } +} + struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { @@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> { } } - let num_threads = crate::utils::get_num_threads(); - for offset in 0..p.k_size { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let dst_idx = dst_c_idx * l_out; let k_cont = (0..p.c_in) .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2]) @@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> { } } +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> { + let &Self { + l_k, + stride, + dilation, + padding, + } = self; + let (b, c, l) = layout.shape().dims3()?; + let l_out = self.l_out(l); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * l_out * c * l_k]; + let (src_s0, src_s1, src_s2) = { + let s = layout.stride(); + (s[0], s[1], s[2]) + }; + // TODO: provide specialized kernels for the common use cases. 
+ // - l_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * l_out * c * l_k; + for l_idx in 0..l_out { + let dst_idx = dst_idx + l_idx * c * l_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * l_k; + let src_idx = c_idx * src_s1 + src_idx; + for l_k_idx in 0..l_k { + let src_l = l_idx * stride + l_k_idx * dilation; + if padding != 0 && (src_l < padding || src_l >= l + padding) { + continue; + } + let src_l = src_l - padding; + let src_idx = src_idx + src_l * src_s2; + let dst_idx = dst_idx + l_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> { + let &Self { + h_k, + w_k, + stride, + dilation, + padding, + } = self; + let (b, c, h, w) = layout.shape().dims4()?; + let (h_out, w_out) = self.hw_out(h, w); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k]; + let (src_s0, src_s1, src_s2, src_s3) = { + let s = layout.stride(); + (s[0], s[1], s[2], s[3]) + }; + // TODO: provide specialized kernels for the common use cases. + // - h_k = w_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * h_out * w_out * c * h_k * w_k; + for h_idx in 0..h_out { + let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k; + for w_idx in 0..w_out { + let dst_idx = dst_idx + w_idx * c * h_k * w_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * h_k * w_k; + let src_idx = c_idx * src_s1 + src_idx; + for h_k_idx in 0..h_k { + let src_h = h_idx * stride + h_k_idx * dilation; + if padding != 0 && (src_h < padding || src_h >= h + padding) { + continue; + } + let src_h = src_h - padding; + let src_idx = src_idx + src_h * src_s2; + let dst_idx = dst_idx + h_k_idx * w_k; + for w_k_idx in 0..w_k { + let src_w = w_idx * stride + w_k_idx * dilation; + if padding != 0 && (src_w < padding || src_w >= w + padding) { + continue; + } + let src_w = src_w - padding; + let src_idx = src_idx + src_w * src_s3; + let dst_idx = dst_idx + w_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + } + } + Ok(dst) + } +} + struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); impl<'a> Map2 for Conv2D<'a> { @@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> { } } - let num_threads = crate::utils::get_num_threads(); - for offset_h in 0..p.k_h { for offset_w in 0..p.k_w { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let dst_idx = dst_c_idx * out_w * out_h; let k_cont = (0..p.c_in) .map(|c_in_idx| { @@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> { } } } - let num_threads = crate::utils::get_num_threads(); for k_y in 0..p.k_h { for k_x in 0..p.k_w { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let k_cont = (0..p.c_in) .map(|c_in_idx| { k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3] @@ -1298,8 +1461,9 @@ impl Map2 
for MatMul { ) -> Result<Vec<T>> { use gemm::{gemm, Parallelism}; - if T::DTYPE == DType::BF16 { - return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?; + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, } let (b, m, n, k) = self.0; @@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage { MaxPool2D(kernel_size, stride).map(self, layout) } + fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> { + UpsampleNearest1D(sz).map(self, layout) + } + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { UpsampleNearest2D(h, w).map(self, layout) } @@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result<Self> { - Conv1D(params).map(self, l, kernel, kernel_l) + if !USE_IM2COL_CONV1D { + return Conv1D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col1D { + l_k: params.k_size, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let l_out = params.l_out(); + let k = op.l_k * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } fn conv2d( @@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result<Self> { - Conv2D(params).map(self, l, kernel, kernel_l) + if !USE_IM2COL_CONV2D { + return Conv2D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let (h_out, w_out) = (params.out_h(), params.out_w()); + let k = op.h_k * op.w_k * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) + .transpose(1, 2)? 
+ .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } fn conv_transpose2d( diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 663f2319..00fd1d04 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; -use candle_kernels as kernels; +pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ @@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice { // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; let slice = match dtype { DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { @@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice { .w()? } DType::F32 => { - let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?; curand .0 .fill_with_normal(&mut data, mean as f32, std as f32) @@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F32(data) } DType::F64 => { - let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?; curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } @@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice { } #[derive(Debug)] -enum CudaStorageSlice { +pub enum CudaStorageSlice { U8(CudaSlice<u8>), U32(CudaSlice<u32>), I64(CudaSlice<i64>), @@ -394,7 +401,7 @@ enum CudaStorageSlice { } type S = CudaStorageSlice; -trait Map1 { +pub trait Map1 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src: &CudaSlice<T>, @@ -416,7 +423,7 @@ trait Map1 { } } -trait Map2 { +pub trait Map2 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -441,7 +448,7 @@ trait Map2 { } } -trait Map2InPlace { +pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, dst: &mut CudaSlice<T>, @@ -472,7 +479,7 @@ trait Map2InPlace { } } -trait Map1Any { +pub trait Map1Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( &self, src: &CudaSlice<T>, @@ -495,7 +502,7 @@ trait Map1Any { } } -trait Map2Any { +pub trait Map2Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -532,7 +539,7 @@ impl Map1 for Clone { } } -fn kernel_name<T: WithDType>(root: &str) -> String { +pub fn kernel_name<T: WithDType>(root: &str) -> String { let dtype = T::DTYPE.as_str(); format!("{root}_{dtype}") } @@ -593,6 +600,105 @@ impl Map1 for Elu { } } +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); 
+ let l_out = self.l_out(dims[2]); + let dst_el = dims[0] * l_out * dims[1] * self.l_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + l_out, + self.l_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); + let (h_out, w_out) = self.hw_out(dims[2], dims[3]); + let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + h_out, + w_out, + self.h_k, + self.w_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + struct Powf(f64); impl Map1 for Powf { fn f<T: DeviceRepr + WithDType>( @@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>( #[derive(Debug)] pub struct CudaStorage { - slice: CudaStorageSlice, - device: CudaDevice, + pub slice: CudaStorageSlice, + pub device: CudaDevice, } pub trait CudaDType: Sized { @@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result<Self> { + const USE_IM2COL_CONV1D: bool = true; + let device = self.device().clone(); - let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV1D { + let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col1D { + l_k: params.k_size, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let l_out = params.l_out(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_size * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. 
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(not(feature = "cudnn"))] @@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result<Self> { + const USE_IM2COL_CONV2D: bool = true; + let device = self.device().clone(); - let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV2D { + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(feature = "cudnn")] @@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> { + crate::bail!("upsample-nearest1d is not supported on cuda") + } + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> { let device = self.device().clone(); let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; @@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage { let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?; diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs index 235ad6e3..dd466ba2 100644 --- a/candle-core/src/cudnn.rs +++ b/candle-core/src/cudnn.rs @@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d< let x_shape = [ params.b_size as i32, params.c_in as i32, - params.i_w as i32, params.i_h as i32, + params.i_w as i32, ]; // Note that `src` already starts at the proper offset. 
let x = if src_l.is_contiguous() { @@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d< [ params.c_out as i32, params.c_in as i32, - params.k_w as i32, params.k_h as i32, + params.k_w as i32, ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); let y = cudnn.create_4d_tensor( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, - [params.b_size as i32, params.c_out as i32, w_out, h_out], + [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; let conv2d = Conv2dForward { conv: &conv, diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index adfc4a3c..c7a1567f 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,15 +1,24 @@ +//! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; use crate::{CpuStorage, Error, Result}; +/// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // Unsigned 8 bits integer. U8, + // Unsigned 32 bits integer. U32, + // Signed 64 bits integer. I64, + // Brain floating-point using half precision (16 bits). BF16, + // Floating-point using half precision (16 bits). F16, + // Floating-point using single precision (32 bits). F32, + // Floating-point using double precision (64 bits). F64, } @@ -33,6 +42,7 @@ impl std::str::FromStr for DType { } impl DType { + /// String representation for dtypes. pub fn as_str(&self) -> &'static str { match self { Self::U8 => "u8", @@ -45,6 +55,7 @@ impl DType { } } + /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`. pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 6c896653..5cc9c6d8 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 1cf20a84..be8f7b07 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -30,7 +30,7 @@ pub enum Error { UnsupportedDTypeForOp(DType, &'static str), // === Dimension Index Errors === - #[error("{op}: dimension index {dim} out of range for {shape:?}")] + #[error("{op}: dimension index {dim} out of range for shape {shape:?}")] DimOutOfRange { shape: Shape, dim: i32, @@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>; impl Error { pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Wrapped(Box::new(err)) + Self::Wrapped(Box::new(err)).bt() } pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Msg(err.to_string()) + Self::Msg(err.to_string()).bt() } pub fn bt(self) -> Self { diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 2b6d694b..7b84d316 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -46,19 +46,31 @@ impl Tensor { current_dim += 1; out } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + crate::bail!("multi-dimensional tensor indexing is not supported") + } + let out = 
x.index_select(&indexes.to_device(x.device())?, current_dim)?; + current_dim += 1; + out + } + TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"), }; } Ok(x) } } -#[derive(Debug, Clone)] +#[derive(Debug)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { /// This selects the elemnts for which an index has some specific value. Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound<usize>, Bound<usize>), + /// Indexing via a 1d tensor + IndexSelect(Tensor), + Err(Error), } impl From<usize> for TensorIndexer { @@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer { } } +impl From<&[u32]> for TensorIndexer { + fn from(index: &[u32]) -> Self { + match Tensor::new(index, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<Vec<u32>> for TensorIndexer { + fn from(index: Vec<u32>) -> Self { + let len = index.len(); + match Tensor::from_vec(index, len, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<&Tensor> for TensorIndexer { + fn from(tensor: &Tensor) -> Self { + TensorIndexer::IndexSelect(tensor.clone()) + } +} + macro_rules! impl_from_range { ($range_type:ty) => { impl From<$range_type> for TensorIndexer { diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index a0347416..52effdcf 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -59,6 +59,7 @@ mod op; pub mod pickle; pub mod quantized; pub mod safetensors; +pub mod scalar; pub mod shape; mod storage; mod strided_index; @@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) { } // A simple trait defining a module with forward method using a single argument. -pub trait Module: std::fmt::Debug { +pub trait Module { fn forward(&self, xs: &Tensor) -> Result<Tensor>; - - /// Change the module to use training mode vs eval mode. - /// - /// The default implementation does nothing as this is only used for a couple modules such as - /// dropout or batch-normalization. 
- fn set_training(&mut self, _training: bool) {} } impl Module for quantized::QMatMul { @@ -124,3 +119,9 @@ impl Module for quantized::QMatMul { self.forward(xs) } } + +impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self(xs) + } +} diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index fbfc9c1a..4882a205 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -58,6 +58,8 @@ pub enum UnaryOp { Sqr, Sqrt, Gelu, + GeluErf, + Erf, Relu, Tanh, } @@ -116,6 +118,7 @@ pub enum Op { stride: (usize, usize), }, + UpsampleNearest1D(Tensor), UpsampleNearest2D(Tensor), Cat(Vec<Tensor>, usize), @@ -324,6 +327,8 @@ pub(crate) struct Recip; pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; +pub(crate) struct GeluErf; +pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; @@ -600,6 +605,92 @@ impl UnaryOpT for Gelu { fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::mkl::vd_gelu(xs, ys) } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::accelerate::vs_gelu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::accelerate::vd_gelu(xs, ys) + } +} + +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } +} + +impl UnaryOpT for GeluErf { + const NAME: &'static str = "gelu_erf"; + const KERNEL: &'static str = "ugelu_erf"; + const V: Self = GeluErf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) 
* 0.5 * v + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } } impl UnaryOpT for Relu { diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 65fd6a6e..a0fe455c 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34); pub struct BlockQ8_1 { pub(crate) d: f16, pub(crate) s: f16, - pub(crate) qs: [u8; QK8_1], + pub(crate) qs: [i8; QK8_1], } const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36); @@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 { } sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) } Ok(sumf) } @@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 { } sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) } Ok(sumf) } @@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 { for j in 0..Self::BLCK_SIZE / 2 { let v0 = xs[j] * id; let v1 = xs[j + Self::BLCK_SIZE / 2] * id; - ys.qs[j] = f32::round(v0) as u8; - ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8; + ys.qs[j] = f32::round(v0) as i8; + ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8; sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32; } ys.s = f16::from_f32(sum as f32) * ys.d; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 5c2bb2b2..f627f0f6 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -229,7 +229,7 @@ impl QTensor { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct QMatMul(std::sync::Arc<QTensor>); impl QMatMul { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index f37bb8ef..d588ea67 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -78,11 +78,7 @@ impl st::View for &Tensor { } impl Tensor { - pub fn save_safetensors<P: AsRef<std::path::Path>>( - &self, - name: &str, - filename: P, - ) -> Result<()> { + pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } @@ -267,7 +263,7 @@ impl MmapedFile { /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. 
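The Erf and GeluErf kernels above implement the exact GELU, 0.5 * v * (1 + erf(v / sqrt(2))), next to the existing tanh-based approximation; the tensor.rs hunk further down exposes them as the erf and gelu_erf tensor methods. A small sketch of the expected calls:

use candle_core::{Device, Result, Tensor};

fn gelu_variants() -> Result<()> {
    let xs = Tensor::new(&[-1f32, 0., 1., 2.], &Device::Cpu)?;
    // Exact GELU through the error function: 0.5 * x * (1 + erf(x / sqrt(2))).
    let exact = xs.gelu_erf()?;
    // The pre-existing gelu keeps the tanh approximation, so the two outputs are
    // close but not bit-identical.
    let approx = xs.gelu()?;
    // erf itself is also exposed as a plain unary op.
    let erf = xs.erf()?;
    println!("{exact}\n{approx}\n{erf}");
    Ok(())
}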
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let inner = memmap2::MmapOptions::new() diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs new file mode 100644 index 00000000..43e1f4c8 --- /dev/null +++ b/candle-core/src/scalar.rs @@ -0,0 +1,23 @@ +use crate::{Result, Tensor, WithDType}; + +pub enum TensorScalar { + Tensor(Tensor), + Scalar(Tensor), +} + +pub trait TensorOrScalar { + fn to_tensor_scalar(self) -> Result<TensorScalar>; +} + +impl TensorOrScalar for &Tensor { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + Ok(TensorScalar::Tensor(self.clone())) + } +} + +impl<T: WithDType> TensorOrScalar for T { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + let scalar = Tensor::new(self, &crate::Device::Cpu)?; + Ok(TensorScalar::Scalar(scalar)) + } +} diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index aea8b887..4d500e7f 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -1,3 +1,4 @@ +//! The shape of a tensor is a tuple with the size of each of its dimensions. #![allow(clippy::redundant_closure_call)] use crate::{Error, Result}; @@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape { } } +impl From<(usize, usize, usize, usize, usize, usize)> for Shape { + fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { + Self(vec![ + d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, + ]) + } +} + impl From<Vec<usize>> for Shape { fn from(dims: Vec<usize>) -> Self { Self(dims) @@ -119,6 +128,7 @@ impl Shape { Self(dims.to_vec()) } + /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc. pub fn rank(&self) -> usize { self.0.len() } @@ -127,10 +137,12 @@ impl Shape { self.0 } + /// The dimensions as a slice of `usize`. pub fn dims(&self) -> &[usize] { &self.0 } + /// The total number of elements, this is the product of all dimension sizes. pub fn elem_count(&self) -> usize { self.0.iter().product() } @@ -182,6 +194,8 @@ impl Shape { true } + /// Modifies the shape by adding a list of additional dimensions at the end of the existing + /// dimensions. 
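The new scalar module above introduces TensorOrScalar so binary helpers can take either a tensor or a plain number; a scalar is wrapped into a CPU tensor here, and the calling op handles the dtype/device cast and the broadcast. A sketch under the assumption that maximum/minimum and the comparison ops are switched to this bound, as the tensor.rs hunks later in this diff do:

use candle_core::{Device, Result, Tensor};

fn scalar_rhs() -> Result<()> {
    let xs = Tensor::new(&[1f32, 5., 3.], &Device::Cpu)?;
    // A plain f32 goes through TensorOrScalar::to_tensor_scalar and is then cast to
    // the dtype/device of xs and broadcast to its shape by the calling op.
    let floored = xs.maximum(2f32)?;
    // A &Tensor right-hand side still takes the regular tensor path.
    let other = Tensor::new(&[4f32, 0., 10.], &Device::Cpu)?;
    let elementwise = xs.maximum(&other)?;
    println!("{floored}\n{elementwise}");
    Ok(())
}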
pub fn extend(mut self, additional_dims: &[usize]) -> Self { self.0.extend(additional_dims); self @@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) { } } +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4]) + } +} + +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + let d5 = self.5.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4, d5]) + } +} + extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); @@ -457,3 +494,171 @@ mod tests { assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } } + +pub trait ShapeWithOneHole { + fn into_shape(self, el_count: usize) -> Result<Shape>; +} + +impl<S: Into<Shape>> ShapeWithOneHole for S { + fn into_shape(self, _el_count: usize) -> Result<Shape> { + Ok(self.into()) + } +} + +impl ShapeWithOneHole for ((),) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + Ok(el_count.into()) + } +} + +impl ShapeWithOneHole for ((), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((el_count / d1, d1).into()) + } +} + +impl ShapeWithOneHole for (usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, ()) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((d1, el_count / d1).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, ()) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize) { + fn into_shape(self, 
el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, ()) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, (), d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, d4, ()) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, d4, el_count / d).into()) + } +} diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 8bd14ea9..9bd1fed6 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -369,6 +369,19 @@ impl Storage { } } + pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { match self { Storage::Cpu(storage) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e181f240..9dccf2b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,8 +1,10 @@ +//! 
Tensors are N-dimensional matrices of elements using a single data type. #![allow(clippy::redundant_closure_call)] use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{ BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp, }; +use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -103,6 +105,28 @@ macro_rules! binary_op { }; } +macro_rules! binary_op_scalar { + ($fn_name:ident, $op_name:ident) => { + pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?; + let storage = self.storage().binary_impl::<crate::op::$op_name>( + &*rhs.storage(), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name)); + Ok(from_storage(storage, shape.clone(), op, false)) + } + }; +} + macro_rules! broadcast_binary_op { ($fn_name:ident, $inner_fn_name:ident) => { pub fn $fn_name(&self, rhs: &Self) -> Result<Self> { @@ -445,8 +469,8 @@ impl Tensor { binary_op!(mul, Mul); binary_op!(sub, Sub); binary_op!(div, Div); - binary_op!(maximum, Maximum); - binary_op!(minimum, Minimum); + binary_op_scalar!(maximum, Maximum); + binary_op_scalar!(minimum, Minimum); broadcast_binary_op!(broadcast_add, add); broadcast_binary_op!(broadcast_mul, mul); broadcast_binary_op!(broadcast_sub, sub); @@ -465,6 +489,8 @@ impl Tensor { unary_op!(sqr, Sqr); unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); + unary_op!(gelu_erf, GeluErf); + unary_op!(erf, Erf); unary_op!(relu, Relu); /// Retrieves the single scalar value held in the tensor. If the tensor contains multiple @@ -642,7 +668,12 @@ impl Tensor { let storage = self.storage().reduce_op(op, self.layout(), &[dim])?; let mut dims = self.dims().to_vec(); dims[dim] = 1; - let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec())); + let op = match op { + ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => { + BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec())) + } + ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(), + }; let res = from_storage(storage, dims, op, false); if keepdim { Ok(res) } else { @@ -775,8 +806,15 @@ impl Tensor { /// comparison operation is specified by the `op` argument. /// /// The returned tensor has the same shape as the original tensors and uses `u8` elements. - pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> { - let shape = self.same_shape_binary_op(rhs, "cmp")?; + pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, "cmp")?; let storage = self .storage() .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?; @@ -785,45 +823,68 @@ impl Tensor { } /// Element-wise equality. - pub fn eq(&self, rhs: &Self) -> Result<Self> { + pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Eq) } /// Element-wise non-equality.
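With cmp now generic over TensorOrScalar (the eq wrapper above and the remaining comparison wrappers below follow the same pattern), building a mask against a constant no longer requires a manually broadcast tensor. A minimal sketch:

use candle_core::{Device, Result, Tensor};

fn masks() -> Result<()> {
    let xs = Tensor::new(&[1u32, 2, 3, 2], &Device::Cpu)?;
    // The scalar is cast to the dtype of xs and broadcast before the comparison;
    // the result is a u8 mask with the same shape as xs.
    let eq_two = xs.eq(2u32)?;
    // A tensor right-hand side behaves as before.
    let ys = Tensor::new(&[1u32, 0, 3, 5], &Device::Cpu)?;
    let same = xs.eq(&ys)?;
    println!("{eq_two}\n{same}");
    Ok(())
}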
- pub fn ne(&self, rhs: &Self) -> Result<Self> { + pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ne) } /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self < /// rhs` and 0 otherwise. - pub fn lt(&self, rhs: &Self) -> Result<Self> { + pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Lt) } /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self > /// rhs` and 0 otherwise. - pub fn gt(&self, rhs: &Self) -> Result<Self> { + pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Gt) } /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >= /// rhs` and 0 otherwise. - pub fn ge(&self, rhs: &Self) -> Result<Self> { + pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ge) } /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <= /// rhs` and 0 otherwise. - pub fn le(&self, rhs: &Self) -> Result<Self> { + pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Le) } - /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the + /// Clamp the tensor values to be between `min` and `max`. + pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> { + self.maximum(min)?.minimum(max) + } + + /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element. + /// + /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned + /// tensor also has three dimensions, `(batch, channels, target_size)`. + pub fn interpolate1d(&self, target_size: usize) -> Result<Self> { + let (n, c, _l) = self.dims3()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest1D); + let storage = self + .storage() + .upsample_nearest1d(self.layout(), target_size)?; + Ok(from_storage(storage, (n, c, target_size), op, false)) + } + + /// Alias for `interpolate1d`. + pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> { + self.interpolate1d(target_size) + } + + /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the /// nearest element. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); let storage = self @@ -832,6 +893,11 @@ impl Tensor { Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) } + /// Alias for `interpolate2d`. + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + self.interpolate2d(target_h, target_w) + } + /// 2D average pooling over an input tensor with multiple channels. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned @@ -1684,12 +1750,15 @@ impl Tensor { Ok(from_storage(storage, shape, BackpropOp::none(), true)) } - // TODO: Do we want to allow target shape using -1 on some dimensions? /// Reshape returns a tensor with the target shape provided that the number of elements of the /// original tensor is the same. 
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses /// a new storage and copies the data over, the returned tensor is always contiguous. /// + /// The shape can be specified using a tuple of `usize` and at most one `()` in which case + /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so + /// as to match the number of elements in the tensor. + /// /// ```rust /// # use candle_core::{Tensor, DType, Device, D}; /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; @@ -1699,10 +1768,14 @@ impl Tensor { /// /// let c = a.reshape((3, 2))?; /// assert_eq!(c.shape().dims(), &[3, 2]); + /// + /// let c = a.reshape((2, (), 1))?; + /// assert_eq!(c.shape().dims(), &[2, 3, 1]); + /// /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> { - let shape = shape.into(); + pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> { + let shape = s.into_shape(self.elem_count())?; if shape.elem_count() != self.elem_count() { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), @@ -1836,6 +1909,34 @@ impl Tensor { for arg in args { arg.as_ref().check_dim(dim, "cat")?; } + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg0.rank() != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: arg0.rank(), + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + } if dim == 0 { Self::cat0(args) } else { |
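The tensor.rs additions above also cover clamp, the 1d nearest-element interpolation, and reshape with a () hole. A short combined sketch, assuming a CPU tensor:

use candle_core::{Device, Result, Tensor};

fn clamp_and_upsample() -> Result<()> {
    let xs = Tensor::new(&[-2f32, 0.5, 3.0], &Device::Cpu)?;
    // clamp is maximum(min) followed by minimum(max), so both bounds accept either
    // scalars or tensors through TensorOrScalar.
    let clamped = xs.clamp(0f32, 1f32)?;
    // interpolate1d expects (batch, channels, length); the () hole in reshape picks
    // the remaining dimension size, here 4.
    let seq = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?.reshape((1, 1, ()))?;
    let up = seq.interpolate1d(8)?;
    assert_eq!(up.dims(), &[1, 1, 8]);
    println!("{clamped}");
    Ok(())
}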