Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/accelerate.rs          |  32
-rw-r--r--  candle-core/src/backend.rs             |   1
-rw-r--r--  candle-core/src/backprop.rs            |  12
-rw-r--r--  candle-core/src/cpu/erf.rs             | 763
-rw-r--r--  candle-core/src/cpu/kernels.rs         |  95
-rw-r--r--  candle-core/src/cpu/mod.rs             |   1
-rw-r--r--  candle-core/src/cpu_backend.rs         | 265
-rw-r--r--  candle-core/src/cuda_backend.rs        | 223
-rw-r--r--  candle-core/src/cudnn.rs               |   6
-rw-r--r--  candle-core/src/dtype.rs               |  11
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  |   4
-rw-r--r--  candle-core/src/error.rs               |   6
-rw-r--r--  candle-core/src/indexer.rs             |  39
-rw-r--r--  candle-core/src/lib.rs                 |  15
-rw-r--r--  candle-core/src/op.rs                  |  91
-rw-r--r--  candle-core/src/quantized/k_quants.rs  |   8
-rw-r--r--  candle-core/src/quantized/mod.rs       |   2
-rw-r--r--  candle-core/src/safetensors.rs         |   8
-rw-r--r--  candle-core/src/scalar.rs              |  23
-rw-r--r--  candle-core/src/shape.rs               | 205
-rw-r--r--  candle-core/src/storage.rs             |  13
-rw-r--r--  candle-core/src/tensor.rs              | 133
22 files changed, 1871 insertions, 85 deletions
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index 87e0ee8d..1cb34e19 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+ unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+ unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vs_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vd_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
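
The two gelu helpers above use the tanh approximation of GELU computed in two passes: `ys` is first filled with sqrt(2/pi) * (v + 0.044715 * v^3) (note that v * (1 + 0.044715 * v^2) expands to exactly that), Accelerate's `vvtanhf`/`vvtanh` is then applied in place, and a final pass combines into 0.5 * v * (1 + tanh(...)). A minimal scalar sketch of the same approximation, handy as a reference when checking the vectorized path (not part of the diff):

fn gelu_tanh_scalar(v: f32) -> f32 {
    let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * v * (1.0 + f32::tanh(sqrt_2_over_pi * (v + 0.044715 * v * v * v)))
}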
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 67a08714..03a07434 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d2099df7..a2548198 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -91,13 +91,14 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
- | Op::Reduce(node, _, _)
+ | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -111,6 +112,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
+ Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@@ -262,6 +264,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
+ Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest1d",
+ })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
@@ -437,6 +442,10 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+ Op::Unary(_, UnaryOp::GeluErf) => {
+ Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+ }
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@@ -517,6 +526,7 @@ impl Tensor {
}
}
+#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
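
Both new unary ops are rejected in the backward pass above. For reference, erf does have a simple analytic derivative, d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2), so a backward rule could later be added along these lines (a hypothetical sketch reusing this file's gradient plumbing, not something this diff ships):

Op::Unary(arg, UnaryOp::Erf) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2)
    let erf_grad = arg.sqr()?.neg()?.exp()?.affine(2. / std::f64::consts::PI.sqrt(), 0.)?;
    *sum_grad = sum_grad.add(&grad.mul(&erf_grad)?)?;
}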
diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs
new file mode 100644
index 00000000..ca6be53f
--- /dev/null
+++ b/candle-core/src/cpu/erf.rs
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+ //! Provides functions that don't have a closed-form solution and must
+ //! be evaluated numerically (e.g. evaluation of a polynomial).
+
+ /// Evaluates a polynomial at `z`, where `coeff` holds the coefficients
+ /// of a polynomial of order `k`, `k` being the length of `coeff`; the
+ /// coefficient of the `k`th power is the `k`th element of `coeff`.
+ /// E.g. `[3, -1, 2]` equates to `2z^2 - z + 3`.
+ ///
+ /// # Remarks
+ ///
+ /// Returns 0 for a 0 length coefficient slice
+ pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+ let n = coeff.len();
+ if n == 0 {
+ return 0.0;
+ }
+
+ let mut sum = *coeff.last().unwrap();
+ for c in coeff[0..n - 1].iter().rev() {
+ sum = *c + z * sum;
+ }
+ sum
+ }
+}
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x >= 0.0 && x.is_infinite() {
+ 1.0
+ } else if x <= 0.0 && x.is_infinite() {
+ -1.0
+ } else if x == 0. {
+ 0.0
+ } else {
+ erf_impl(x, false)
+ }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+ if x == 0.0 {
+ 0.0
+ } else if x >= 1.0 {
+ f64::INFINITY
+ } else if x <= -1.0 {
+ f64::NEG_INFINITY
+ } else if x < 0.0 {
+ erf_inv_impl(-x, 1.0 + x, -1.0)
+ } else {
+ erf_inv_impl(x, 1.0 - x, 1.0)
+ }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x == f64::INFINITY {
+ 0.0
+ } else if x == f64::NEG_INFINITY {
+ 2.0
+ } else {
+ erf_impl(x, true)
+ }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+ if x <= 0.0 {
+ f64::INFINITY
+ } else if x >= 2.0 {
+ f64::NEG_INFINITY
+ } else if x > 1.0 {
+ erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+ } else {
+ erf_inv_impl(1.0 - x, x, 1.0)
+ }
+}
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[
+ 0.00337916709551257388990745,
+ -0.00073695653048167948530905,
+ -0.374732337392919607868241,
+ 0.0817442448733587196071743,
+ -0.0421089319936548595203468,
+ 0.0070165709512095756344528,
+ -0.00495091255982435110337458,
+ 0.000871646599037922480317225,
+];
+
+/// Polynomial coefficients for a denominator of `erf_impl`
+/// in the interval [1e-10, 0.5]
+const ERF_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.218088218087924645390535,
+ 0.412542972725442099083918,
+ -0.0841891147873106755410271,
+ 0.0655338856400241519690695,
+ -0.0120019604454941768171266,
+ 0.00408165558926174048329689,
+ -0.000615900721557769691924509,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BN: &[f64] = &[
+ -0.0361790390718262471360258,
+ 0.292251883444882683221149,
+ 0.281447041797604512774415,
+ 0.125610208862766947294894,
+ 0.0274135028268930549240776,
+ 0.00250839672168065762786937,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BD: &[f64] = &[
+ 1.0,
+ 1.8545005897903486499845,
+ 1.43575803037831418074962,
+ 0.582827658753036572454135,
+ 0.124810476932949746447682,
+ 0.0113724176546353285778481,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CN: &[f64] = &[
+ -0.0397876892611136856954425,
+ 0.153165212467878293257683,
+ 0.191260295600936245503129,
+ 0.10276327061989304213645,
+ 0.029637090615738836726027,
+ 0.0046093486780275489468812,
+ 0.000307607820348680180548455,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CD: &[f64] = &[
+ 1.0,
+ 1.95520072987627704987886,
+ 1.64762317199384860109595,
+ 0.768238607022126250082483,
+ 0.209793185936509782784315,
+ 0.0319569316899913392596356,
+ 0.00213363160895785378615014,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DN: &[f64] = &[
+ -0.0300838560557949717328341,
+ 0.0538578829844454508530552,
+ 0.0726211541651914182692959,
+ 0.0367628469888049348429018,
+ 0.00964629015572527529605267,
+ 0.00133453480075291076745275,
+ 0.778087599782504251917881e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.75967098147167528287343,
+ 1.32883571437961120556307,
+ 0.552528596508757581287907,
+ 0.133793056941332861912279,
+ 0.0179509645176280768640766,
+ 0.00104712440019937356634038,
+ -0.106640381820357337177643e-7,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_EN: &[f64] = &[
+ -0.0117907570137227847827732,
+ 0.014262132090538809896674,
+ 0.0202234435902960820020765,
+ 0.00930668299990432009042239,
+ 0.00213357802422065994322516,
+ 0.00025022987386460102395382,
+ 0.120534912219588189822126e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_ED: &[f64] = &[
+ 1.0,
+ 1.50376225203620482047419,
+ 0.965397786204462896346934,
+ 0.339265230476796681555511,
+ 0.0689740649541569716897427,
+ 0.00771060262491768307365526,
+ 0.000371421101531069302990367,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FN: &[f64] = &[
+ -0.00546954795538729307482955,
+ 0.00404190278731707110245394,
+ 0.0054963369553161170521356,
+ 0.00212616472603945399437862,
+ 0.000394984014495083900689956,
+ 0.365565477064442377259271e-4,
+ 0.135485897109932323253786e-5,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FD: &[f64] = &[
+ 1.0,
+ 1.21019697773630784832251,
+ 0.620914668221143886601045,
+ 0.173038430661142762569515,
+ 0.0276550813773432047594539,
+ 0.00240625974424309709745382,
+ 0.891811817251336577241006e-4,
+ -0.465528836283382684461025e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GN: &[f64] = &[
+ -0.00270722535905778347999196,
+ 0.0013187563425029400461378,
+ 0.00119925933261002333923989,
+ 0.00027849619811344664248235,
+ 0.267822988218331849989363e-4,
+ 0.923043672315028197865066e-6,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.814632808543141591118279,
+ 0.268901665856299542168425,
+ 0.0449877216103041118694989,
+ 0.00381759663320248459168994,
+ 0.000131571897888596914350697,
+ 0.404815359675764138445257e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HN: &[f64] = &[
+ -0.00109946720691742196814323,
+ 0.000406425442750422675169153,
+ 0.000274499489416900707787024,
+ 0.465293770646659383436343e-4,
+ 0.320955425395767463401993e-5,
+ 0.778286018145020892261936e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HD: &[f64] = &[
+ 1.0,
+ 0.588173710611846046373373,
+ 0.139363331289409746077541,
+ 0.0166329340417083678763028,
+ 0.00100023921310234908642639,
+ 0.24254837521587225125068e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_IN: &[f64] = &[
+ -0.00056907993601094962855594,
+ 0.000169498540373762264416984,
+ 0.518472354581100890120501e-4,
+ 0.382819312231928859704678e-5,
+ 0.824989931281894431781794e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_ID: &[f64] = &[
+ 1.0,
+ 0.339637250051139347430323,
+ 0.043472647870310663055044,
+ 0.00248549335224637114641629,
+ 0.535633305337152900549536e-4,
+ -0.117490944405459578783846e-12,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JN: &[f64] = &[
+ -0.000241313599483991337479091,
+ 0.574224975202501512365975e-4,
+ 0.115998962927383778460557e-4,
+ 0.581762134402593739370875e-6,
+ 0.853971555085673614607418e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JD: &[f64] = &[
+ 1.0,
+ 0.233044138299687841018015,
+ 0.0204186940546440312625597,
+ 0.000797185647564398289151125,
+ 0.117019281670172327758019e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KN: &[f64] = &[
+ -0.000146674699277760365803642,
+ 0.162666552112280519955647e-4,
+ 0.269116248509165239294897e-5,
+ 0.979584479468091935086972e-7,
+ 0.101994647625723465722285e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KD: &[f64] = &[
+ 1.0,
+ 0.165907812944847226546036,
+ 0.0103361716191505884359634,
+ 0.000286593026373868366935721,
+ 0.298401570840900340874568e-5,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LN: &[f64] = &[
+ -0.583905797629771786720406e-4,
+ 0.412510325105496173512992e-5,
+ 0.431790922420250949096906e-6,
+ 0.993365155590013193345569e-8,
+ 0.653480510020104699270084e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LD: &[f64] = &[
+ 1.0,
+ 0.105077086072039915406159,
+ 0.00414278428675475620830226,
+ 0.726338754644523769144108e-4,
+ 0.477818471047398785369849e-6,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MN: &[f64] = &[
+ -0.196457797609229579459841e-4,
+ 0.157243887666800692441195e-5,
+ 0.543902511192700878690335e-7,
+ 0.317472492369117710852685e-9,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MD: &[f64] = &[
+ 1.0,
+ 0.052803989240957632204885,
+ 0.000926876069151753290378112,
+ 0.541011723226630257077328e-5,
+ 0.535093845803642394908747e-15,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_NN: &[f64] = &[
+ -0.789224703978722689089794e-5,
+ 0.622088451660986955124162e-6,
+ 0.145728445676882396797184e-7,
+ 0.603715505542715364529243e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_ND: &[f64] = &[
+ 1.0,
+ 0.0375328846356293715248719,
+ 0.000467919535974625308126054,
+ 0.193847039275845656900547e-5,
+];
+
+// **********************************************************
+// ********** Coefficients for erf_inv_impl polynomial ******
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AN: &[f64] = &[
+ -0.000508781949658280665617,
+ -0.00836874819741736770379,
+ 0.0334806625409744615033,
+ -0.0126926147662974029034,
+ -0.0365637971411762664006,
+ 0.0219878681111168899165,
+ 0.00822687874676915743155,
+ -0.00538772965071242932965,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.970005043303290640362,
+ -1.56574558234175846809,
+ 1.56221558398423026363,
+ 0.662328840472002992063,
+ -0.71228902341542847553,
+ -0.0527396382340099713954,
+ 0.0795283687341571680018,
+ -0.00233393759374190016776,
+ 0.000886216390456424707504,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BN: &[f64] = &[
+ -0.202433508355938759655,
+ 0.105264680699391713268,
+ 8.37050328343119927838,
+ 17.6447298408374015486,
+ -18.8510648058714251895,
+ -44.6382324441786960818,
+ 17.445385985570866523,
+ 21.1294655448340526258,
+ -3.67192254707729348546,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BD: &[f64] = &[
+ 1.0,
+ 6.24264124854247537712,
+ 3.9713437953343869095,
+ -28.6608180499800029974,
+ -20.1432634680485188801,
+ 48.5609213108739935468,
+ 10.8268667355460159008,
+ -22.6436933413139721736,
+ 1.72114765761200282724,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CN: &[f64] = &[
+ -0.131102781679951906451,
+ -0.163794047193317060787,
+ 0.117030156341995252019,
+ 0.387079738972604337464,
+ 0.337785538912035898924,
+ 0.142869534408157156766,
+ 0.0290157910005329060432,
+ 0.00214558995388805277169,
+ -0.679465575181126350155e-6,
+ 0.285225331782217055858e-7,
+ -0.681149956853776992068e-9,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CD: &[f64] = &[
+ 1.0,
+ 3.46625407242567245975,
+ 5.38168345707006855425,
+ 4.77846592945843778382,
+ 2.59301921623620271374,
+ 0.848854343457902036425,
+ 0.152264338295331783612,
+ 0.01105924229346489121,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DN: &[f64] = &[
+ -0.0350353787183177984712,
+ -0.00222426529213447927281,
+ 0.0185573306514231072324,
+ 0.00950804701325919603619,
+ 0.00187123492819559223345,
+ 0.000157544617424960554631,
+ 0.460469890584317994083e-5,
+ -0.230404776911882601748e-9,
+ 0.266339227425782031962e-11,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.3653349817554063097,
+ 0.762059164553623404043,
+ 0.220091105764131249824,
+ 0.0341589143670947727934,
+ 0.00263861676657015992959,
+ 0.764675292302794483503e-4,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_EN: &[f64] = &[
+ -0.0167431005076633737133,
+ -0.00112951438745580278863,
+ 0.00105628862152492910091,
+ 0.000209386317487588078668,
+ 0.149624783758342370182e-4,
+ 0.449696789927706453732e-6,
+ 0.462596163522878599135e-8,
+ -0.281128735628831791805e-13,
+ 0.99055709973310326855e-16,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_ED: &[f64] = &[
+ 1.0,
+ 0.591429344886417493481,
+ 0.138151865749083321638,
+ 0.0160746087093676504695,
+ 0.000964011807005165528527,
+ 0.275335474764726041141e-4,
+ 0.282243172016108031869e-6,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FN: &[f64] = &[
+ -0.0024978212791898131227,
+ -0.779190719229053954292e-5,
+ 0.254723037413027451751e-4,
+ 0.162397777342510920873e-5,
+ 0.396341011304801168516e-7,
+ 0.411632831190944208473e-9,
+ 0.145596286718675035587e-11,
+ -0.116765012397184275695e-17,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FD: &[f64] = &[
+ 1.0,
+ 0.207123112214422517181,
+ 0.0169410838120975906478,
+ 0.000690538265622684595676,
+ 0.145007359818232637924e-4,
+ 0.144437756628144157666e-6,
+ 0.509761276599778486139e-9,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GN: &[f64] = &[
+ -0.000539042911019078575891,
+ -0.28398759004727721098e-6,
+ 0.899465114892291446442e-6,
+ 0.229345859265920864296e-7,
+ 0.225561444863500149219e-9,
+ 0.947846627503022684216e-12,
+ 0.135880130108924861008e-14,
+ -0.348890393399948882918e-21,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.0845746234001899436914,
+ 0.00282092984726264681981,
+ 0.468292921940894236786e-4,
+ 0.399968812193862100054e-6,
+ 0.161809290887904476097e-8,
+ 0.231558608310259605225e-11,
+];
+
+/// `erf_impl` computes the error function at `z`.
+/// If `inv` is true, `1 - erf(z)` is calculated instead of `erf(z)`.
+fn erf_impl(z: f64, inv: bool) -> f64 {
+ if z < 0.0 {
+ if !inv {
+ return -erf_impl(-z, false);
+ }
+ if z < -0.5 {
+ return 2.0 - erf_impl(-z, true);
+ }
+ return 1.0 + erf_impl(-z, false);
+ }
+
+ let result = if z < 0.5 {
+ if z < 1e-10 {
+ z * 1.125 + z * 0.003379167095512573896158903121545171688
+ } else {
+ z * 1.125
+ + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
+ }
+ } else if z < 110.0 {
+ let (r, b) = if z < 0.75 {
+ (
+ evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
+ / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
+ 0.3440242112,
+ )
+ } else if z < 1.25 {
+ (
+ evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
+ / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
+ 0.419990927,
+ )
+ } else if z < 2.25 {
+ (
+ evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
+ / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
+ 0.4898625016,
+ )
+ } else if z < 3.5 {
+ (
+ evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
+ / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
+ 0.5317370892,
+ )
+ } else if z < 5.25 {
+ (
+ evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
+ / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
+ 0.5489973426,
+ )
+ } else if z < 8.0 {
+ (
+ evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
+ / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
+ 0.5571740866,
+ )
+ } else if z < 11.5 {
+ (
+ evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
+ / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
+ 0.5609807968,
+ )
+ } else if z < 17.0 {
+ (
+ evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
+ / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
+ 0.5626493692,
+ )
+ } else if z < 24.0 {
+ (
+ evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
+ / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
+ 0.5634598136,
+ )
+ } else if z < 38.0 {
+ (
+ evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
+ / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
+ 0.5638477802,
+ )
+ } else if z < 60.0 {
+ (
+ evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
+ / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
+ 0.5640528202,
+ )
+ } else if z < 85.0 {
+ (
+ evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
+ / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
+ 0.5641309023,
+ )
+ } else {
+ (
+ evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
+ / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
+ 0.5641584396,
+ )
+ };
+ let g = (-z * z).exp() / z;
+ g * b + g * r
+ } else {
+ 0.0
+ };
+
+ if inv && z >= 0.5 {
+ result
+ } else if z >= 0.5 || inv {
+ 1.0 - result
+ } else {
+ result
+ }
+}
+
+// `erf_inv_impl` computes the inverse error function where
+// `p`, `q`, and `s` are the first, second, and third intermediate
+// parameters respectively.
+fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
+ let result = if p <= 0.5 {
+ let y = 0.0891314744949340820313;
+ let g = p * (p + 10.0);
+ let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
+ g * y + g * r
+ } else if q >= 0.25 {
+ let y = 2.249481201171875;
+ let g = (-2.0 * q.ln()).sqrt();
+ let xs = q - 0.25;
+ let r =
+ evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
+ g / (y + r)
+ } else {
+ let x = (-q.ln()).sqrt();
+ if x < 3.0 {
+ let y = 0.807220458984375;
+ let xs = x - 1.125;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
+ y * x + r * x
+ } else if x < 6.0 {
+ let y = 0.93995571136474609375;
+ let xs = x - 3.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
+ y * x + r * x
+ } else if x < 18.0 {
+ let y = 0.98362827301025390625;
+ let xs = x - 6.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
+ y * x + r * x
+ } else if x < 44.0 {
+ let y = 0.99714565277099609375;
+ let xs = x - 18.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
+ y * x + r * x
+ } else {
+ let y = 0.99941349029541015625;
+ let xs = x - 44.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
+ y * x + r * x
+ }
+ };
+ s * result
+}
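
`evaluate::polynomial` is plain Horner evaluation: for `coeff = [3, -1, 2]` and `z = 2` it computes 2*2^2 - 2 + 3 = 9. Since `erf`/`erf_inv` are mutual inverses on their domains, a round-trip check makes a cheap sanity test (a sketch; the tolerance assumes the near machine-precision accuracy of these boost-style rational approximations):

#[test]
fn erf_round_trip() {
    for x in [-0.9f64, -0.5, 0.1, 0.5, 0.9] {
        assert!((erf(erf_inv(x)) - x).abs() < 1e-12);
    }
}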
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 97e195ef..527646d6 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+ fn min(self, rhs: Self) -> Self;
+ fn max(self, rhs: Self) -> Self;
+
/// Dot-product of two vectors.
///
/// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x > *res {
- *res = x
- }
+ *res = (*res).max(*xs.add(i))
}
}
@@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x < *res {
- *res = x
- }
+ *res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
}
@@ -76,6 +83,16 @@ impl VecOps for f32 {
impl VecOps for half::f16 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
}
}
-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for half::bf16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for u8 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for u32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
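
Moving `min`/`max` into `VecOps` (and dropping the `PartialOrd` bound) lets each type pick its own semantics: the float impls defer to `f32::max` and friends, which return the non-NaN operand when one argument is NaN, whereas the old `if x > *res` comparison could silently keep a NaN accumulator. The `<Self as Ord>::min` spelling in the integer impls disambiguates from the trait method of the same name, avoiding infinite recursion. A small safe wrapper over the unsafe reduction (a sketch assuming a non-empty slice and `VecOps` in scope, not part of the diff):

fn reduce_max<T: VecOps>(xs: &[T]) -> T {
    assert!(!xs.is_empty());
    let mut res = xs[0];
    // SAFETY: `xs` is a valid, non-empty slice of length `xs.len()`.
    unsafe { T::vec_reduce_max(xs.as_ptr(), &mut res, xs.len()) };
    res
}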
diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs
index 9a8e6317..50afb30f 100644
--- a/candle-core/src/cpu/mod.rs
+++ b/candle-core/src/cpu/mod.rs
@@ -1,3 +1,4 @@
+pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index ed3dd3fc..4e808b34 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
+use rayon::prelude::*;
+
+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
}
// This function maps over two strided index sequences.
-fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
}
// Similar to binary_map but with vectorized variants.
-fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
}
}
+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*sz?
+ let dst_sz = self.0;
+ let (b_sz, c, src_sz) = layout.shape().dims3()?;
+ let stride = layout.stride();
+ let stride_sz = stride[2];
+ let src_index = layout.start_offset();
+ let scale_sz = src_sz as f64 / dst_sz as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+ let src_idxs = (0..dst_sz)
+ .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_sz..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_sz..];
+ let src_index = src_index + c_idx * stride[1];
+ for (idx, src_idx) in src_idxs.iter().enumerate() {
+ dst[idx] = src[src_index + src_idx * stride_sz]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset in 0..p.k_size {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ l_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, l) = layout.shape().dims3()?;
+ let l_out = self.l_out(l);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * l_out * c * l_k];
+ let (src_s0, src_s1, src_s2) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - l_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * l_out * c * l_k;
+ for l_idx in 0..l_out {
+ let dst_idx = dst_idx + l_idx * c * l_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * l_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for l_k_idx in 0..l_k {
+ let src_l = l_idx * stride + l_k_idx * dilation;
+ if padding != 0 && (src_l < padding || src_l >= l + padding) {
+ continue;
+ }
+ let src_l = src_l - padding;
+ let src_idx = src_idx + src_l * src_s2;
+ let dst_idx = dst_idx + l_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ h_k,
+ w_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, h, w) = layout.shape().dims4()?;
+ let (h_out, w_out) = self.hw_out(h, w);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+ let (src_s0, src_s1, src_s2, src_s3) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2], s[3])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - h_k = w_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+ for h_idx in 0..h_out {
+ let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+ for w_idx in 0..w_out {
+ let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * h_k * w_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for h_k_idx in 0..h_k {
+ let src_h = h_idx * stride + h_k_idx * dilation;
+ if padding != 0 && (src_h < padding || src_h >= h + padding) {
+ continue;
+ }
+ let src_h = src_h - padding;
+ let src_idx = src_idx + src_h * src_s2;
+ let dst_idx = dst_idx + h_k_idx * w_k;
+ for w_k_idx in 0..w_k {
+ let src_w = w_idx * stride + w_k_idx * dilation;
+ if padding != 0 && (src_w < padding || src_w >= w + padding) {
+ continue;
+ }
+ let src_w = src_w - padding;
+ let src_idx = src_idx + src_w * src_s3;
+ let dst_idx = dst_idx + w_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
}
}
}
- let num_threads = crate::utils::get_num_threads();
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
- if T::DTYPE == DType::BF16 {
- return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;
@@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout)
}
+ fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ UpsampleNearest1D(sz).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
- Conv1D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV1D {
+ return Conv1D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col1D {
+ l_k: params.k_size,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let l_out = params.l_out();
+ let k = op.l_k * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The contiguous copy starts at offset 0, so use a plain contiguous layout.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv2d(
@@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
- Conv2D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV2D {
+ return Conv2D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let (h_out, w_out) = (params.out_h(), params.out_w());
+ let k = op.h_k * op.w_k * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The contiguous copy starts at offset 0, so use a plain contiguous layout.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv_transpose2d(
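
Both CPU convolutions are now lowered to im2col followed by a single matmul: the input is unfolded into a (b, l_out, c_in * l_k) column matrix (respectively (b, h_out * w_out, c_in * h_k * w_k)), multiplied against the kernel viewed as a (k, c_out) matrix, and the result is transposed back to channels-first layout via `copy_strided_src`. The output-size formula shared by `Im2Col1D::l_out` and `Im2Col::hw_out`, with a worked check (sketch):

// l = 10, l_k = 3, stride = 1, padding = 0, dilation = 1
// -> (10 + 0 - 2 - 1) / 1 + 1 = 8 output positions
fn conv_out_len(l: usize, l_k: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1
}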
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 663f2319..00fd1d04 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
+ // curand can only generate an even number of values.
+ // https://github.com/huggingface/candle/issues/734
+ let elem_count_round = if elem_count % 2 == 1 {
+ elem_count + 1
+ } else {
+ elem_count
+ };
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
@@ -394,7 +401,7 @@ enum CudaStorageSlice {
}
type S = CudaStorageSlice;
-trait Map1 {
+pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@@ -416,7 +423,7 @@ trait Map1 {
}
}
-trait Map2 {
+pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -441,7 +448,7 @@ trait Map2 {
}
}
-trait Map2InPlace {
+pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@@ -472,7 +479,7 @@ trait Map2InPlace {
}
}
-trait Map1Any {
+pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@@ -495,7 +502,7 @@ trait Map1Any {
}
}
-trait Map2Any {
+pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -532,7 +539,7 @@ impl Map1 for Clone {
}
}
-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@@ -593,6 +600,105 @@ impl Map1 for Elu {
}
}
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let l_out = self.l_out(dims[2]);
+ let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ l_out,
+ self.l_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+ let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ h_out,
+ w_out,
+ self.h_k,
+ self.w_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
@@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
- slice: CudaStorageSlice,
- device: CudaDevice,
+ pub slice: CudaStorageSlice,
+ pub device: CudaDevice,
}
pub trait CudaDType: Sized {
@@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
+ const USE_IM2COL_CONV1D: bool = true;
+
let device = self.device().clone();
- let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV1D {
+ let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col1D {
+ l_k: params.k_size,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let l_out = params.l_out();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_size * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The contiguous copy starts at offset 0, so use a plain contiguous layout.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(not(feature = "cudnn"))]
@@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
+ const USE_IM2COL_CONV2D: bool = true;
+
let device = self.device().clone();
- let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV2D {
+ let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let h_out = params.out_h();
+ let w_out = params.out_w();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_h * params.k_w * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The contiguous copy starts at offset 0, so use a plain contiguous layout.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, n))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(feature = "cudnn")]
@@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
+ crate::bail!("upsample-nearest1d is not supported on cuda")
+ }
+
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
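
The curand workaround above over-allocates by one element because cuRAND's normal generator produces values in pairs, so requests for an odd number of values fail (see the linked issue #734); the extra element is simply never read. The rounding could equivalently be written branchlessly (sketch):

// Round `elem_count` up to the next even number.
let elem_count_round = (elem_count + 1) & !1;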
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 235ad6e3..dd466ba2 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
- params.i_w as i32,
params.i_h as i32,
+ params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
- params.k_w as i32,
params.k_h as i32,
+ params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
- [params.b_size as i32, params.c_out as i32, w_out, h_out],
+ [params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
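
The cudnn fix swaps height and width into their proper NCHW positions: with `CUDNN_TENSOR_NCHW`, descriptor dims must be [batch, channels, height, width]. The old ordering happened to coincide for square inputs and kernels (h == w), which is presumably why it went unnoticed. For example (sketch):

// Correct NCHW descriptor order for a 1x3x4x6 input:
let x_shape = [1i32, 3, 4 /* i_h */, 6 /* i_w */];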
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index adfc4a3c..c7a1567f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,15 +1,24 @@
+//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
+/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
+ /// Unsigned 8-bit integer.
U8,
+ /// Unsigned 32-bit integer.
U32,
+ /// Signed 64-bit integer.
I64,
+ /// Brain floating-point using half precision (16 bits).
BF16,
+ /// Floating-point using half precision (16 bits).
F16,
+ /// Floating-point using single precision (32 bits).
F32,
+ /// Floating-point using double precision (64 bits).
F64,
}
@@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
}
impl DType {
+ /// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
@@ -45,6 +55,7 @@ impl DType {
}
}
+ /// The size used by each element in bytes, e.g. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 6c896653..5cc9c6d8 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 1cf20a84..be8f7b07 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
- #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+ #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Wrapped(Box::new(err))
+ Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Msg(err.to_string())
+ Self::Msg(err.to_string()).bt()
}
pub fn bt(self) -> Self {
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 2b6d694b..7b84d316 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
+ TensorIndexer::IndexSelect(indexes) => {
+ if indexes.rank() != 1 {
+ crate::bail!("multi-dimensional tensor indexing is not supported")
+ }
+ let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+ current_dim += 1;
+ out
+ }
+ TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
+ /// Indexing via a 1d tensor
+ IndexSelect(Tensor),
+ Err(Error),
}
impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
+impl From<&[u32]> for TensorIndexer {
+ fn from(index: &[u32]) -> Self {
+ match Tensor::new(index, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+ fn from(index: Vec<u32>) -> Self {
+ let len = index.len();
+ match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<&Tensor> for TensorIndexer {
+ fn from(tensor: &Tensor) -> Self {
+ TensorIndexer::IndexSelect(tensor.clone())
+ }
+}
+
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
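
With these conversions, a 1-d tensor (or a `&[u32]`/`Vec<u32>` turned into one) used as an index performs an `index_select` along the corresponding dimension. Hypothetical usage via the crate's `IndexOp::i` entry point, which accepts anything convertible to `TensorIndexer`:

use candle_core::{Device, IndexOp, Tensor};
let t = Tensor::new(&[[1u32, 2], [3, 4], [5, 6]], &Device::Cpu)?;
// index_select along dim 0 -> [[5, 6], [1, 2]]
let rows = t.i(&[2u32, 0][..])?;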
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a0347416..52effdcf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -59,6 +59,7 @@ mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
+pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
@@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with a forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
- /// Change the module to use training mode vs eval mode.
- ///
- /// The default implementation does nothing as this is only used for a couple modules such as
- /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
@@ -124,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self(xs)
+ }
+}
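
Dropping `set_training` (and the `Debug` bound) shrinks `Module` to a single method, which is what makes the new blanket impl possible: any closure from `&Tensor` to `Result<Tensor>` is now a module. A sketch of what that enables:

use candle_core::{Module, Result, Tensor};
fn apply_twice<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    m.forward(&m.forward(xs)?)
}
// A plain closure now satisfies the `Module` bound:
// let double = |xs: &Tensor| xs.affine(2., 0.);
// apply_twice(&double, &xs)? scales by 4.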
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbfc9c1a..4882a205 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -58,6 +58,8 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
+ GeluErf,
+ Erf,
Relu,
Tanh,
}
@@ -116,6 +118,7 @@ pub enum Op {
stride: (usize, usize),
},
+ UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@@ -324,6 +327,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@@ -600,6 +605,92 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
+
+ #[cfg(feature = "accelerate")]
+ const F32_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+ crate::accelerate::vs_gelu(xs, ys)
+ }
+
+ #[cfg(feature = "accelerate")]
+ const F64_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+ crate::accelerate::vd_gelu(xs, ys)
+ }
+}
+
+impl UnaryOpT for Erf {
+ const NAME: &'static str = "erf";
+ const KERNEL: &'static str = "uerf";
+ const V: Self = Erf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ crate::cpu::erf::erf(v)
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
+}
+
+impl UnaryOpT for GeluErf {
+ const NAME: &'static str = "gelu_erf";
+ const KERNEL: &'static str = "ugelu_erf";
+ const V: Self = GeluErf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
}
impl UnaryOpT for Relu {
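
For reference, `Erf` is the error function itself while `GeluErf` is the exact GELU, `0.5 * x * (1 + erf(x / sqrt(2)))`, in contrast to the tanh approximation used by the existing `Gelu`. A small numeric check, assuming the `Tensor::erf`/`Tensor::gelu_erf` wrappers added further down in this diff:

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let x = Tensor::new(1f64, &Device::Cpu)?;
    let gelu_erf = x.gelu_erf()?.to_scalar::<f64>()?;
    // erf(1 / sqrt(2)) ≈ 0.6826894921 (the one-sigma probability mass).
    let expected = 0.5 * (1.0 + 0.682_689_492_1);
    assert!((gelu_erf - expected).abs() < 1e-6);
    Ok(())
}
```
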
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 65fd6a6e..a0fe455c 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
pub struct BlockQ8_1 {
pub(crate) d: f16,
pub(crate) s: f16,
- pub(crate) qs: [u8; QK8_1],
+ pub(crate) qs: [i8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
@@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 {
for j in 0..Self::BLCK_SIZE / 2 {
let v0 = xs[j] * id;
let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
- ys.qs[j] = f32::round(v0) as u8;
- ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8;
+ ys.qs[j] = f32::round(v0) as i8;
+ ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
}
ys.s = f16::from_f32(sum as f32) * ys.d;
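
The `u8` → `i8` switch in `BlockQ8_1` is a sign-correctness fix: the scaled values being rounded can be negative, and on modern Rust a negative `f32` cast to `u8` saturates to 0, corrupting the signed sum that feeds `ys.s`. A tiny illustration:

```rust
// With i8 the sign survives the cast; casting the same value to u8
// saturates to 0 and loses it.
let v = -3.4f32;
assert_eq!(f32::round(v) as i8, -3);
assert_eq!(f32::round(v) as u8, 0);
```
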
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 5c2bb2b2..f627f0f6 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -229,7 +229,7 @@ impl QTensor {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
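
Deriving `Clone` here is cheap by construction: `QMatMul` wraps an `Arc<QTensor>`, so a clone bumps a refcount rather than duplicating the quantized weights. A sketch, with `QMatMul::from_qtensor` assumed as the constructor:

```rust
use candle_core::quantized::{QMatMul, QTensor};

// Both returned handles share the same underlying QTensor storage.
fn share(qt: QTensor) -> (QMatMul, QMatMul) {
    let mm = QMatMul::from_qtensor(qt);
    let mm2 = mm.clone();
    (mm, mm2)
}
```
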
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index f37bb8ef..d588ea67 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
- pub fn save_safetensors<P: AsRef<std::path::Path>>(
- &self,
- name: &str,
- filename: P,
- ) -> Result<()> {
+ pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@@ -267,7 +263,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()
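
The shortened `save_safetensors` signature keeps the one-liner usage; a minimal sketch, assuming a writable working directory:

```rust
use candle_core::{DType, Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
    // Serializes a single named tensor to a .safetensors file.
    t.save_safetensors("weight", "weight.safetensors")?;
    Ok(())
}
```
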
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs
new file mode 100644
index 00000000..43e1f4c8
--- /dev/null
+++ b/candle-core/src/scalar.rs
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+ Tensor(Tensor),
+ Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+ fn to_tensor_scalar(self) -> Result<TensorScalar>;
+}
+
+impl TensorOrScalar for &Tensor {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ Ok(TensorScalar::Tensor(self.clone()))
+ }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+ Ok(TensorScalar::Scalar(scalar))
+ }
+}
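
`TensorOrScalar` is the glue used by the `tensor.rs` changes below: a scalar argument is lifted to a rank-0 CPU tensor, then converted to the right dtype/device and broadcast against `self`. Usage sketch:

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[-1.5f32, 0.0, 2.5], &Device::Cpu)?;
    // Scalar rhs, lifted through TensorOrScalar and broadcast to t's shape.
    let capped = t.minimum(1.0f32)?;
    assert_eq!(capped.to_vec1::<f32>()?, vec![-1.5, 0.0, 1.0]);
    // A tensor rhs goes through the same entry point unchanged.
    let relu = t.maximum(&t.zeros_like()?)?;
    assert_eq!(relu.to_vec1::<f32>()?, vec![0.0, 0.0, 2.5]);
    Ok(())
}
```
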
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index aea8b887..4d500e7f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -1,3 +1,4 @@
+//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+ fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+ Self(vec![
+ d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+ ])
+ }
+}
+
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@@ -119,6 +128,7 @@ impl Shape {
Self(dims.to_vec())
}
+    /// The rank is the number of dimensions: 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@@ -127,10 +137,12 @@ impl Shape {
self.0
}
+ /// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
+    /// The total number of elements; this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@@ -182,6 +194,8 @@ impl Shape {
true
}
+ /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+ /// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
@@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
}
}
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4])
+ }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ let d5 = self.5.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4, d5])
+ }
+}
+
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@@ -457,3 +494,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
+
+pub trait ShapeWithOneHole {
+ fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+ fn into_shape(self, _el_count: usize) -> Result<Shape> {
+ Ok(self.into())
+ }
+}
+
+impl ShapeWithOneHole for ((),) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ Ok(el_count.into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((el_count / d1, d1).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, ()) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((d1, el_count / d1).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, ()) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, ()) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, (), d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, d4, ()) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, d4, el_count / d).into())
+ }
+}
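
Every `ShapeWithOneHole` impl follows the same pattern: multiply the known dimensions, check divisibility, and let the single `()` hole absorb the remainder. Combined with the new 6-tuple `From` impl, `reshape` gains a PyTorch-style `-1`:

```rust
use candle_core::{DType, Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::zeros((1, 2, 3, 4, 5, 6), DType::F32, &Device::Cpu)?; // 720 elements
    // The hole resolves to 720 / (2 * 6) = 60.
    let r = t.reshape((2, (), 6))?;
    assert_eq!(r.dims(), &[2, 60, 6]);
    Ok(())
}
```
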
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 8bd14ea9..9bd1fed6 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -369,6 +369,19 @@ impl Storage {
}
}
+ pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e181f240..9dccf2b5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,8 +1,10 @@
+//! Tensors are N-dimensional arrays of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
+use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -103,6 +105,28 @@ macro_rules! binary_op {
};
}
+macro_rules! binary_op_scalar {
+ ($fn_name:ident, $op_name:ident) => {
+ pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+ let storage = self.storage().binary_impl::<crate::op::$op_name>(
+ &*rhs.storage(),
+ self.layout(),
+ rhs.layout(),
+ )?;
+ let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+ Ok(from_storage(storage, shape.clone(), op, false))
+ }
+ };
+}
+
macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@@ -445,8 +469,8 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
- binary_op!(maximum, Maximum);
- binary_op!(minimum, Minimum);
+ binary_op_scalar!(maximum, Maximum);
+ binary_op_scalar!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
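
Because `maximum`/`minimum` now go through `binary_op_scalar!`, a plain number works as the right-hand side; ReLU, for instance, becomes a one-liner:

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[-1f32, 2., -3.], &Device::Cpu)?;
    // The scalar rhs is broadcast internally to t's shape.
    let relu = t.maximum(0f32)?;
    assert_eq!(relu.to_vec1::<f32>()?, vec![0., 2., 0.]);
    Ok(())
}
```
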
@@ -465,6 +489,8 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
+ unary_op!(gelu_erf, GeluErf);
+ unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
@@ -642,7 +668,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
- let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+ let op = match op {
+ ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+ BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+ }
+ ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+ };
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
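
After this change only `Sum`/`Min`/`Max` reductions record a backprop op; `ArgMin`/`ArgMax` return index tensors detached from the graph, where a gradient would not be meaningful. A sketch, assuming the `argmax` helper that wraps this reduction:

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[1f32, 3., 2.], &Device::Cpu)?;
    // Index result: no Op::Reduce is attached, so backprop stops here.
    let idx = t.argmax(0)?;
    assert_eq!(idx.to_scalar::<u32>()?, 1);
    Ok(())
}
```
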
@@ -775,8 +806,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -785,45 +823,68 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
+ pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
- /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
+ /// Clamp the tensor values to be between `min` and `max`.
+ pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
+ self.maximum(min)?.minimum(max)
+ }
+
+    /// Interpolate the input tensor to size `target_size`, taking the value of the nearest element.
+ ///
+    /// The input tensor should have three dimensions, `(batch, channels, l)`; the returned
+    /// tensor also has three dimensions, `(batch, channels, target_size)`.
+ pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
+ let (n, c, _l) = self.dims3()?;
+ let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
+ let storage = self
+ .storage()
+ .upsample_nearest1d(self.layout(), target_size)?;
+ Ok(from_storage(storage, (n, c, target_size), op, false))
+ }
+
+ /// Alias for `interpolate1d`.
+ pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
+ self.interpolate1d(target_size)
+ }
+
+    /// Interpolate the input tensor to size `(target_h, target_w)`, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
- pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@@ -832,6 +893,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
+ /// Alias for `interpolate2d`.
+ pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ self.interpolate2d(target_h, target_w)
+ }
+
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
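
The scalar-aware comparisons, `clamp`, and the new 1d interpolation compose naturally; a short sketch exercising all three:

```rust
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[[[1f32, 2.]]], &Device::Cpu)?; // (batch, channels, l) = (1, 1, 2)
    // Nearest-neighbor upsampling along the last dimension.
    let up = t.interpolate1d(4)?;
    assert_eq!(up.dims(), &[1, 1, 4]); // values [1, 1, 2, 2]
    // Scalar rhs for cmp: u8 mask where up >= 1.5.
    let mask = up.ge(1.5f32)?;
    assert_eq!(mask.to_vec3::<u8>()?, vec![vec![vec![0, 0, 1, 1]]]);
    // Clamp with scalar bounds, built from maximum/minimum.
    let clamped = up.clamp(1.25f32, 1.75f32)?;
    assert_eq!(clamped.to_vec3::<f32>()?, vec![vec![vec![1.25, 1.25, 1.75, 1.75]]]);
    Ok(())
}
```
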
@@ -1684,12 +1750,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
- // TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
+    /// The shape can be specified using a tuple of `usize` values and at most one `()`, in which
+    /// case it behaves like `-1` in PyTorch: that dimension's size is adjusted so the shape
+    /// matches the number of elements in the tensor.
+ ///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1699,10 +1768,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
+ ///
+ /// let c = a.reshape((2, (), 1))?;
+ /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+ ///
/// # Ok::<(), candle_core::Error>(())
/// ```
- pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
- let shape = shape.into();
+ pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+ let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
@@ -1836,6 +1909,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg0.rank() != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: arg0.rank(),
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx != dim && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ }
if dim == 0 {
Self::cat0(args)
} else {