diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/accelerate.rs | 111 | ||||
-rw-r--r-- | candle-core/src/backend.rs | 11 | ||||
-rw-r--r-- | candle-core/src/backprop.rs | 17 | ||||
-rw-r--r-- | candle-core/src/conv.rs | 29 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 288 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 18 | ||||
-rw-r--r-- | candle-core/src/device.rs | 7 | ||||
-rw-r--r-- | candle-core/src/dtype.rs | 6 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 18 | ||||
-rw-r--r-- | candle-core/src/ggml.rs | 582 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 3 | ||||
-rw-r--r-- | candle-core/src/op.rs | 18 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 58 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 123 | ||||
-rw-r--r-- | candle-core/src/utils.rs | 14 |
15 files changed, 1279 insertions, 24 deletions
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs new file mode 100644 index 00000000..8b0df5c1 --- /dev/null +++ b/candle-core/src/accelerate.rs @@ -0,0 +1,111 @@ +#![allow(dead_code)] +use libc::{c_char, c_double, c_float, c_int}; + +mod ffi { + use super::*; + extern "C" { + // It would be nice to be able to switch to the NEWLAPACK version of the function but this + // seems to trigger some link error. Available function names can be seen here: + // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd + #[link_name = "sgemm_"] + pub fn sgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_float, + a: *const c_float, + lda: *const c_int, + b: *const c_float, + ldb: *const c_int, + beta: *const c_float, + c: *mut c_float, + ldc: *const c_int, + ); + #[link_name = "dgemm_"] + pub fn dgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_double, + a: *const c_double, + lda: *const c_int, + b: *const c_double, + ldb: *const c_int, + beta: *const c_double, + c: *mut c_double, + ldc: *const c_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn sgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f32, + a: &[f32], + lda: i32, + b: &[f32], + ldb: i32, + beta: f32, + c: &mut [f32], + ldc: i32, +) { + ffi::sgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn dgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f64, + a: &[f64], + lda: i32, + b: &[f64], + ldb: i32, + beta: f64, + c: &mut [f64], + ldc: i32, +) { + ffi::dgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 345db0e5..a8e5ac52 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -37,6 +37,17 @@ pub trait BackendStorage: Sized { _params: &crate::conv::ParamsConv1D, ) -> Result<Self>; + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv2D, + ) -> Result<Self>; + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>; fn scatter_add( &self, diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f5cc8191..0eab508e 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -55,6 +55,11 @@ impl Tensor { kernel: rhs, .. } + | Op::Conv2D { + arg: lhs, + kernel: rhs, + .. + } | Op::CustomOp2(lhs, rhs, _) | Op::Binary(lhs, rhs, _) | Op::Gather(lhs, rhs, _) @@ -81,6 +86,8 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest2D(node) + | Op::AvgPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -163,6 +170,11 @@ impl Tensor { *f_sum_grad = f_sum_grad.add(&f_grad)?; } Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, + Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, + Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest2d", + })?, Op::Gather(arg, indexes, dim) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; @@ -291,6 +303,11 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.sub(&grad)? } + Op::Unary(arg, UnaryOp::Recip) => { + let sum_grad = grads.or_insert(arg)?; + let grad = (grad / arg.sqr()?)?; + *sum_grad = sum_grad.sub(&grad)? + } &Op::Narrow(ref arg, dim, start_idx, len) => { let arg_dims = arg.dims(); let left_pad = if start_idx == 0 { diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 4cf9d0ad..30799459 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,3 +25,32 @@ impl ParamsConv1D { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConv2D { + pub(crate) b_size: usize, + pub(crate) i_h: usize, + pub(crate) i_w: usize, + pub(crate) k_h: usize, + pub(crate) k_w: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv2D { + pub(crate) fn out_h(&self) -> usize { + let dilation = 1; + (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_w(&self) -> usize { + let dilation = 1; + (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec<usize> { + vec![self.b_size, self.c_out, self.out_h(), self.out_w()] + } +} diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 8563721c..10c6cc4a 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -633,6 +633,84 @@ impl Map1 for Affine { } } +struct AvgPool2D((usize, usize), (usize, usize)); + +impl Map1 for AvgPool2D { + fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html + let (k_h, k_w) = self.0; + let (s_h, s_w) = self.1; + let (b_sz, c, h, w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let h_out = (h - k_h) / s_h + 1; + let w_out = (w - k_w) / s_w + 1; + let src_index = layout.start_offset(); + let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + let scale = 1f64 / (k_h * k_w) as f64; + let scale = T::from_f64(scale); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * h_out * w_out..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * h_out * w_out..]; + let src_index = src_index + c_idx * stride[1]; + for h_idx in 0..h_out { + for w_idx in 0..w_out { + let mut sum = T::zero(); + for m in 0..k_h { + for n in 0..k_w { + let m = k_h * h_idx + m; + let n = k_w * w_idx + n; + sum += src[src_index + m * stride_h + n * stride_w] + } + } + dst[h_idx * w_out + w_idx] = sum * scale; + } + } + } + } + Ok(dst) + } +} + +struct UpsampleNearest2D(usize, usize); + +impl Map1 for UpsampleNearest2D { + fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { + // TODO: Specialized implementation for the case 2*h, 2*w? + let (dst_h, dst_w) = (self.0, self.1); + let (b_sz, c, src_h, src_w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let src_index = layout.start_offset(); + let scale_h = src_h as f64 / dst_h as f64; + let scale_w = src_w as f64 / dst_w as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w]; + let src_h_idxs = (0..src_h) + .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize)) + .collect::<Vec<_>>(); + let src_w_idxs = (0..src_w) + .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize)) + .collect::<Vec<_>>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_h * dst_w..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_h * dst_w..]; + let src_index = src_index + c_idx * stride[1]; + for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() { + for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() { + let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w; + dst[h_idx * dst_w + w_idx] = src[src_index] + } + } + } + } + Ok(dst) + } +} + struct Gather<'a, I: IntDType> { ids: &'a [I], ids_l: &'a Layout, @@ -921,7 +999,6 @@ impl<'a> Map2 for Conv1D<'a> { (0, inp_stride) // This value never gets used anyway }; let k_stride = k_l.stride(); - let k_over_2 = p.k_size / 2; let l_out = p.l_out(); let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); let mut dst = vec![T::zero(); dst_elems]; @@ -935,18 +1012,16 @@ impl<'a> Map2 for Conv1D<'a> { let dst_idx = dst_idx + dst_l; let mut d = T::zero(); for offset in 0..p.k_size { - let src_l_plus = p.stride * dst_l + offset; - // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] - if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in { - let src_l = src_l_plus - k_over_2; - for src_c_idx in 0..p.c_in { - let inp_idx = - inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset * k_stride[2]; - d += inp[inp_idx] * k[k_idx] - } + let src_l = (p.stride * dst_l + offset) + .saturating_sub(p.padding) + .min(p.l_in - 1); + for src_c_idx in 0..p.c_in { + let inp_idx = + inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset * k_stride[2]; + d += inp[inp_idx] * k[k_idx] } } dst[dst_idx] = d @@ -957,6 +1032,65 @@ impl<'a> Map2 for Conv1D<'a> { } } +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); + +impl<'a> Map2 for Conv2D<'a> { + const OP: &'static str = "conv2d"; + fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>( + &self, + inp: &[T], + inp_l: &Layout, + k: &[T], + k_l: &Layout, + ) -> Result<Vec<T>> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let inp_stride = inp_l.stride(); + let k = &k[k_l.start_offset()..]; + let k_stride = k_l.stride(); + let (out_h, out_w) = (p.out_h(), p.out_w()); + + let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + for b_idx in 0..p.b_size { + let inp_idx = b_idx * inp_stride[0]; + let dst_idx = b_idx * p.c_out * out_h * out_w; + for dst_c_idx in 0..p.c_out { + let dst_idx = dst_idx + dst_c_idx * out_h * out_w; + for dst_h in 0..out_h { + let dst_idx = dst_idx + dst_h * out_w; + for dst_w in 0..out_w { + let dst_idx = dst_idx + dst_w; + let mut d = T::zero(); + for offset_h in 0..p.k_h { + let src_h = (p.stride * dst_h + offset_h) + .saturating_sub(p.padding) + .min(p.i_h - 1); + for offset_w in 0..p.k_w { + let src_w = (p.stride * dst_w + offset_w) + .saturating_sub(p.padding) + .min(p.i_w - 1); + for src_c_idx in 0..p.c_in { + let inp_idx = inp_idx + + src_c_idx * inp_stride[1] + + src_h * inp_stride[2] + + src_w * inp_stride[3]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset_h * k_stride[2] + + offset_w * k_stride[3]; + d += inp[inp_idx] * k[k_idx] + } + } + } + dst[dst_idx] = d + } + } + } + } + Ok(dst) + } +} + struct MatMul((usize, usize, usize, usize)); impl MatMul { @@ -974,7 +1108,7 @@ impl MatMul { impl Map2 for MatMul { const OP: &'static str = "mat_mul"; - #[cfg(not(feature = "mkl"))] + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] fn f<T: 'static + WithDType + num_traits::Num + Copy>( &self, lhs: &[T], @@ -1053,6 +1187,109 @@ impl Map2 for MatMul { Ok(dst) } + #[cfg(feature = "accelerate")] + fn f<T: 'static + WithDType + num_traits::Num + Copy>( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + ) -> Result<Vec<T>> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + + let a_skip: usize = match lhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [stride] => stride, + [] => m * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, + }; + let b_skip: usize = match rhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [stride] => stride, + [] => n * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, + }; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + #[cfg(feature = "mkl")] fn f<T: 'static + WithDType + num_traits::Num + Copy>( &self, @@ -1426,6 +1663,19 @@ impl BackendStorage for CpuStorage { Affine(mul, add).map(self, layout) } + fn avg_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result<Self> { + AvgPool2D(kernel_size, stride).map(self, layout) + } + + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { + UpsampleNearest2D(h, w).map(self, layout) + } + fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { // TODO: Have some generic map for functions that apply on num_traits::Float elements. match self { @@ -1612,6 +1862,16 @@ impl BackendStorage for CpuStorage { Conv1D(params).map(self, l, kernel, kernel_l) } + fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result<Self> { + Conv2D(params).map(self, l, kernel, kernel_l) + } + fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 7b4b358d..727ea073 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1381,6 +1381,24 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv2D, + ) -> Result<Self> { + todo!() + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { + todo!() + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { + todo!() + } + fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { let device = self.device().clone(); let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?; diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 563d892b..65232839 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -101,6 +101,13 @@ impl Device { } } + pub fn is_cpu(&self) -> bool { + match self { + Self::Cpu => true, + Self::Cuda(_) => false, + } + } + pub fn is_cuda(&self) -> bool { match self { Self::Cpu => false, diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 0e906119..92929748 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -43,7 +43,7 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { - Self::U8 => 4, + Self::U8 => 1, Self::U32 => 4, Self::BF16 => 2, Self::F16 => 2, @@ -53,7 +53,9 @@ impl DType { } } -pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static { +pub trait WithDType: + Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static +{ const DTYPE: DType; fn from_f64(v: f64) -> Self; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 17d4a22e..ae4dd09f 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } @@ -119,6 +129,14 @@ impl crate::backend::BackendStorage for CudaStorage { fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } } impl crate::backend::BackendDevice for CudaDevice { diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs new file mode 100644 index 00000000..4a5d4fa0 --- /dev/null +++ b/candle-core/src/ggml.rs @@ -0,0 +1,582 @@ +//! Support for the GGML file format. + +use crate::{DType, Device, Result, Tensor}; +use byteorder::{LittleEndian, ReadBytesExt}; +use half::f16; + +// Default to QK_K 256 rather than 64. +pub const QK_K: usize = 256; +pub const K_SCALE_SIZE: usize = 12; + +pub const QK4_0: usize = 32; +pub const QK4_1: usize = 32; +pub const QK5_0: usize = 32; +pub const QK5_1: usize = 32; +pub const QK8_0: usize = 32; +pub const QK8_1: usize = 32; + +#[repr(C)] +struct BlockQ4_0 { + d: f16, + qs: [u8; QK4_0 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18); + +#[repr(C)] +struct BlockQ4_1 { + d: f16, + m: f16, + qs: [u8; QK4_1 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20); + +#[repr(C)] +struct BlockQ5_0 { + d: f16, + qh: [u8; 4], + qs: [u8; QK5_0 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22); + +#[repr(C)] +struct BlockQ5_1 { + d: f16, + m: f16, + qh: [u8; 4], + qs: [u8; QK5_1 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24); + +#[repr(C)] +struct BlockQ8_0 { + d: f16, + qs: [u8; QK8_0], +} +const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34); + +#[repr(C)] +struct BlockQ8_1 { + d: f16, + s: f16, + qs: [u8; QK8_1], +} +const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36); + +#[repr(C)] +struct BlockQ2K { + scales: [u8; QK_K / 16], + qs: [u8; QK_K / 4], + d: f16, + dmin: f16, +} +const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>()); + +#[repr(C)] +struct BlockQ3K { + hmask: [u8; QK_K / 8], + qs: [u8; QK_K / 4], + scales: [u8; 12], + d: f16, +} +const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>()); + +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82 +#[repr(C)] +struct BlockQ4K { + d: f16, + dmin: f16, + scales: [u8; K_SCALE_SIZE], + qs: [u8; QK_K / 2], +} +const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>()); + +#[repr(C)] +struct BlockQ5K { + d: f16, + dmin: f16, + scales: [u8; K_SCALE_SIZE], + qh: [u8; QK_K / 8], + qs: [u8; QK_K / 2], +} +const _: () = + assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>()); + +#[repr(C)] +struct BlockQ6K { + ql: [u8; QK_K / 2], + qh: [u8; QK_K / 4], + scales: [i8; QK_K / 16], + d: f16, +} +const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>()); + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 +fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let q = &x.qs; + + let mut is = 0; + for n in (0..QK_K).step_by(128) { + // Step by 32 over q. + let q = &q[n / 4..]; + let mut shift = 0; + for _j in 0..4 { + let sc = x.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &q[..16] { + let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; + ys[ys_index] = y; + ys_index += 1; + } + + let sc = x.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &q[16..32] { + let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; + ys[ys_index] = y; + ys_index += 1; + } + + shift += 2; + } + } + } + Ok(()) +} + +fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) { + if j < 4 { + let d = q[j] & 63; + let m = q[j + 4] & 63; + (d, m) + } else { + let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + (d, m) + } +} +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 +fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs.iter() { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let q = &x.qs; + let mut is = 0; + for j in (0..QK_K).step_by(64) { + let q = &q[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &x.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &x.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for q in q { + let y = d1 * (q & 0xF) as f32 - m1; + ys[ys_index] = y; + ys_index += 1; + } + for q in q { + let y = d2 * (q >> 4) as f32 - m2; + ys[ys_index] = y; + ys_index += 1; + } + is += 2; + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 +fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> { + todo!() +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 +fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs.iter() { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let ql = &x.qs; + let qh = &x.qh; + let mut is = 0; + let mut u1 = 1; + let mut u2 = 2; + for j in (0..QK_K).step_by(64) { + let ql = &ql[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &x.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &x.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u1 != 0 { 16 } else { 1 }; + let y = d1 * ((ql & 0xF) + to_add) as f32 - m1; + ys[ys_index] = y; + ys_index += 1; + } + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u2 != 0 { 16 } else { 1 }; + let y = d2 * ((ql >> 4) + to_add) as f32 - m2; + ys[ys_index] = y; + ys_index += 1; + } + is += 2; + u1 <<= 2; + u2 <<= 2; + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067 +fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}") + } + for x in xs.iter() { + let d = x.d.to_f32(); + let ql = &x.ql; + let qh = &x.qh; + let sc = &x.scales; + for n in (0..QK_K).step_by(128) { + let idx = n / 128; + let ys = &mut ys[n..]; + let sc = &sc[8 * idx..]; + let ql = &ql[64 * idx..]; + let qh = &qh[32 * idx..]; + for l in 0..32 { + let is = l / 16; + let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32; + let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32; + let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32; + let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32; + ys[l] = d * sc[is] as f32 * q1 as f32; + ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32; + ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32; + ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32; + } + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Ggjt, + Ggla, + Ggmf, + Ggml, + Ggsn, +} + +impl TryFrom<u32> for Magic { + type Error = crate::Error; + fn try_from(value: u32) -> Result<Self> { + let magic = match value { + 0x67676a74 => Self::Ggjt, + 0x67676c61 => Self::Ggla, + 0x67676d66 => Self::Ggmf, + 0x67676d6c => Self::Ggml, + 0x6767736e => Self::Ggsn, + _ => crate::bail!("unknown magic {value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgmlUnversioned, + GgmfV1, + GgjtV1, + GgjtV2, + GgjtV3, +} + +impl VersionedMagic { + fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> { + let magic = reader.read_u32::<LittleEndian>()?; + let magic = Magic::try_from(magic)?; + if magic == Magic::Ggml { + return Ok(Self::GgmlUnversioned); + } + let version = reader.read_u32::<LittleEndian>()?; + let versioned_magic = match (magic, version) { + (Magic::Ggmf, 1) => Self::GgmfV1, + (Magic::Ggjt, 1) => Self::GgjtV1, + (Magic::Ggjt, 2) => Self::GgjtV2, + (Magic::Ggjt, 3) => Self::GgjtV3, + _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } + + fn align32(&self) -> bool { + match self { + Self::GgmlUnversioned | Self::GgmfV1 => false, + Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HParams { + pub n_vocab: u32, + pub n_embd: u32, + pub n_mult: u32, + pub n_head: u32, + pub n_layer: u32, + pub n_rot: u32, + pub ftype: u32, +} + +impl HParams { + fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> { + let n_vocab = reader.read_u32::<LittleEndian>()?; + let n_embd = reader.read_u32::<LittleEndian>()?; + let n_mult = reader.read_u32::<LittleEndian>()?; + let n_head = reader.read_u32::<LittleEndian>()?; + let n_layer = reader.read_u32::<LittleEndian>()?; + let n_rot = reader.read_u32::<LittleEndian>()?; + let ftype = reader.read_u32::<LittleEndian>()?; + Ok(Self { + n_vocab, + n_embd, + n_mult, + n_head, + n_layer, + n_rot, + ftype, + }) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Vocab { + pub token_score_pairs: Vec<(Vec<u8>, f32)>, +} + +impl Vocab { + fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556 + let mut token_score_pairs = Vec::with_capacity(n_vocab); + for _index in 0..n_vocab { + let len = reader.read_u32::<LittleEndian>()? as usize; + let mut word = vec![0u8; len]; + reader.read_exact(&mut word)?; + let score = reader.read_f32::<LittleEndian>()?; + token_score_pairs.push((word, score)) + } + Ok(Self { token_score_pairs }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GgmlDType { + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, +} + +impl GgmlDType { + fn from_u32(u: u32) -> Result<Self> { + let dtype = match u { + 0 => Self::F32, + 1 => Self::F16, + 2 => Self::Q4_0, + 3 => Self::Q4_1, + 6 => Self::Q5_0, + 7 => Self::Q5_1, + 8 => Self::Q8_0, + 9 => Self::Q8_1, + 10 => Self::Q2K, + 11 => Self::Q3K, + 12 => Self::Q4K, + 13 => Self::Q5K, + 14 => Self::Q6K, + _ => crate::bail!("unknown dtype for tensor {u}"), + }; + Ok(dtype) + } + + fn type_size(&self) -> usize { + match self { + Self::F32 => 4, + Self::F16 => 2, + Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(), + Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(), + Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(), + Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(), + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932 + Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(), + Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(), + Self::Q2K => std::mem::size_of::<BlockQ2K>(), + Self::Q3K => std::mem::size_of::<BlockQ3K>(), + Self::Q4K => std::mem::size_of::<BlockQ4K>(), + Self::Q5K => std::mem::size_of::<BlockQ5K>(), + Self::Q6K => std::mem::size_of::<BlockQ6K>(), + } + } + + fn blck_size(&self) -> usize { + match self { + Self::F32 => 1, + Self::F16 => 1, + Self::Q4_0 => QK4_0, + Self::Q4_1 => QK4_1, + Self::Q5_0 => QK5_0, + Self::Q5_1 => QK5_1, + Self::Q8_0 => QK8_0, + Self::Q8_1 => QK8_1, + Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K, + } + } +} + +#[derive(Debug)] +pub struct Content { + pub magic: VersionedMagic, + pub hparams: HParams, + pub vocab: Vocab, + pub tensors: Vec<(String, Tensor)>, +} + +fn read_one_tensor<R: std::io::Seek + std::io::Read>( + reader: &mut R, + magic: VersionedMagic, + device: &Device, +) -> Result<(String, Tensor)> { + let n_dims = reader.read_u32::<LittleEndian>()?; + let name_len = reader.read_u32::<LittleEndian>()?; + let dtype = reader.read_u32::<LittleEndian>()?; + let dtype = GgmlDType::from_u32(dtype)?; + let mut dims = vec![0u32; n_dims as usize]; + reader.read_u32_into::<LittleEndian>(&mut dims)?; + let mut name = vec![0u8; name_len as usize]; + reader.read_exact(&mut name)?; + let name = String::from_utf8_lossy(&name).into_owned(); + + if magic.align32() { + let pos = reader.stream_position()?; + reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?; + } + let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>(); + let tensor_elems = dims.iter().product::<usize>(); + let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size(); + println!("{name} {dtype:?} {dims:?}"); + // TODO: Mmap version to avoid copying the data around? + let mut raw_data = vec![0u8; size_in_bytes]; + reader.read_exact(&mut raw_data)?; + let tensor = match dtype { + GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?, + GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?, + GgmlDType::Q2K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) }; + dequantize_row_q2k(raw_data, &mut f32_data)?; + // Maybe we should use bf16 instead? + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q3K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) }; + dequantize_row_q3k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q4K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) }; + dequantize_row_q4k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q5K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) }; + dequantize_row_q5k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q6K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) }; + dequantize_row_q6k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"), + }; + Ok((name, tensor)) +} + +impl Content { + pub fn read<R: std::io::Seek + std::io::Read>( + reader: &mut R, + device: &Device, + ) -> Result<Content> { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 + let last_position = reader.seek(std::io::SeekFrom::End(0))?; + reader.seek(std::io::SeekFrom::Start(0))?; + let magic = VersionedMagic::read(reader)?; + let hparams = HParams::read(reader)?; + let vocab = Vocab::read(reader, hparams.n_vocab as usize)?; + let mut tensors = vec![]; + + while reader.stream_position()? != last_position { + let (name, tensor) = read_one_tensor(reader, magic, device)?; + tensors.push((name, tensor)) + } + Ok(Self { + magic, + hparams, + vocab, + tensors, + }) + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index c374d245..016d3806 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -33,6 +33,8 @@ //! //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers) +#[cfg(feature = "accelerate")] +mod accelerate; pub mod backend; pub mod backprop; mod conv; @@ -45,6 +47,7 @@ pub mod display; mod dtype; mod dummy_cuda_backend; pub mod error; +pub mod ggml; mod indexer; pub mod layout; #[cfg(feature = "mkl")] diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index ba8d2fb4..aea8b733 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -51,6 +51,7 @@ pub enum UnaryOp { Cos, Abs, Neg, + Recip, Sqr, Sqrt, Gelu, @@ -79,6 +80,21 @@ pub enum Op { stride: usize, }, + #[allow(dead_code)] + Conv2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + }, + + AvgPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), + Cat(Vec<Tensor>, usize), #[allow(dead_code)] // add is currently unused. @@ -264,6 +280,7 @@ pub(crate) struct Sin; pub(crate) struct Cos; pub(crate) struct Abs; pub(crate) struct Neg; +pub(crate) struct Recip; pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; @@ -410,6 +427,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); +unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 1e1ef305..3ed38e6a 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -266,6 +266,64 @@ impl Storage { } } + pub(crate) fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result<Self> { + self.same_device(kernel, "conv2d")?; + self.same_dtype(kernel, "conv2d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv2d", + } + .bt()), + } + } + + pub(crate) fn avg_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cuda(storage)) + } + } + } + + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn where_cond( &self, layout: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b958e06d..adba7376 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -269,6 +269,10 @@ impl Tensor { Self::rand_impl(lo, up, s, device, false) } + pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> { + Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false) + } + pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>( mean: T, std: T, @@ -296,6 +300,17 @@ impl Tensor { Ok(from_storage(storage, s, none, is_variable)) } + pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> { + Tensor::randn_f64_impl( + mean, + stdev, + self.shape(), + self.dtype(), + self.device(), + false, + ) + } + /// Creates a new tensor initialized with values sampled from a normal distribution with the /// specified `mean` and standard deviation `std`. pub fn randn<S: Into<Shape>, T: crate::FloatDType>( @@ -474,6 +489,7 @@ impl Tensor { broadcast_binary_op!(broadcast_sub, sub); broadcast_binary_op!(broadcast_div, div); + unary_op!(recip, Recip); unary_op!(neg, Neg); unary_op!(exp, Exp); unary_op!(log, Log); @@ -548,6 +564,32 @@ impl Tensor { } } + /// Split a tensor into the specified number of chunks, this may return less chunks than + /// specificed. + pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> { + let dim = dim.to_index(self.shape(), "chunk")?; + let size = self.dim(dim)?; + if size < chunks { + (0..size).map(|i| self.narrow(dim, i, 1)).collect() + } else { + let chunk_size = size / chunks; + let cnt_additional = size % chunks; + let mut tensors = vec![]; + let mut sum_chunk_size = 0; + for i in 0..chunks { + let chunk_size = if i < cnt_additional { + chunk_size + 1 + } else { + chunk_size + }; + let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?; + tensors.push(tensor); + sum_chunk_size += chunk_size + } + Ok(tensors) + } + } + /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + len`. pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> { @@ -775,6 +817,61 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } + pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k { + crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") + } + let params = crate::conv::ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding, + stride, + }); + let out_dims = params.out_dims(); + Ok(from_storage(storage, out_dims, op, false)) + } + + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + let (n, c, _h, _w) = self.dims4()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest2D); + let storage = self + .storage() + .upsample_nearest2d(self.layout(), target_h, target_w)?; + Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) + } + + pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::AvgPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .avg_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments @@ -1717,6 +1814,32 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[self, &right], dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self], dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self, &right], dim) + } + } + fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index 895c97e1..d3f5b50e 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -11,16 +11,14 @@ pub fn get_num_threads() -> usize { } } +pub fn has_accelerate() -> bool { + cfg!(feature = "accelerate") +} + pub fn has_mkl() -> bool { - #[cfg(feature = "mkl")] - return true; - #[cfg(not(feature = "mkl"))] - return false; + cfg!(feature = "mkl") } pub fn cuda_is_available() -> bool { - #[cfg(feature = "cuda")] - return true; - #[cfg(not(feature = "cuda"))] - return false; + cfg!(feature = "cuda") } |