Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 66
-rw-r--r-- | candle-core/src/cpu_kernels.rs | 28
-rw-r--r-- | candle-core/src/dtype.rs       |  8
-rw-r--r-- | candle-core/src/error.rs       | 14
-rw-r--r-- | candle-core/src/lib.rs         |  1
-rw-r--r-- | candle-core/src/npy.rs         | 18
-rw-r--r-- | candle-core/src/safetensors.rs | 20
7 files changed, 109 insertions, 46 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 238a9a69..250e2721 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
 impl<'a> Map2 for Conv1D<'a> {
     const OP: &'static str = "conv1d";
 
-    fn f<T: 'static + num_traits::NumAssign + Copy>(
-        &self,
-        inp: &[T],
-        inp_l: &Layout,
-        k: &[T],
-        k_l: &Layout,
-    ) -> Result<Vec<T>> {
-        // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
         let p = self.0;
         let inp = &inp[inp_l.start_offset()..];
         let k = &k[k_l.start_offset()..];
@@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
         let dst_elems = p.c_out * l_out * p.b_size;
         let mut dst = vec![T::zero(); dst_elems];
         // The output shape is [b_size, c_out, l_out]
+        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
         for b_idx in 0..p.b_size {
-            let inp_idx = b_idx * inp_s0;
-            let dst_idx = b_idx * p.c_out * l_out;
+            for src_l in 0..p.l_in {
+                for src_c_idx in 0..p.c_in {
+                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
+                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
+                }
+            }
+        }
+        for offset in 0..p.k_size {
             for dst_c_idx in 0..p.c_out {
-                let dst_idx = dst_idx + dst_c_idx * l_out;
-                for dst_l in 0..l_out {
-                    let dst_idx = dst_idx + dst_l;
-                    let mut d = T::zero();
-                    for offset in 0..p.k_size {
+                let dst_idx = dst_c_idx * l_out;
+                let k_cont = (0..p.c_in)
+                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
+                    .collect::<Vec<_>>();
+                for b_idx in 0..p.b_size {
+                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
+                    for dst_l in 0..l_out {
+                        let dst_idx = dst_idx + dst_l;
                         let src_l = (p.stride * dst_l + offset)
                             .saturating_sub(p.padding)
                             .min(p.l_in - 1);
-                        for src_c_idx in 0..p.c_in {
-                            let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
-                            let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
-                            d += inp[inp_idx] * k[k_idx]
-                        }
+                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
+                        assert!(inp_cont.len() >= p.c_in);
+                        assert!(k_cont.len() >= p.c_in);
+                        let mut d = T::zero();
+                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
+                        dst[dst_idx] += d
                     }
-                    dst[dst_idx] = d
                 }
             }
         }
@@ -2070,35 +2073,36 @@ impl BackendDevice for CpuDevice {
             DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = bf16::from_f64(std);
-                let mean = bf16::from_f64(mean);
+                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng))
                 }
                 Ok(CpuStorage::BF16(data))
             }
             DType::F16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = f16::from_f64(std);
-                let mean = f16::from_f64(mean);
+                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng))
                 }
                 Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = std as f32;
-                let mean = mean as f32;
+                let normal =
+                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng))
                 }
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
                 let mut data = Vec::with_capacity(elem_count);
+                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng))
                 }
                 Ok(CpuStorage::F64(data))
             }
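Two notes on the cpu_backend.rs hunks above. The conv1d rewrite first gathers the strided input into a contiguous [b_size, l_in, c_in] buffer so the per-position channel reduction runs over adjacent memory and can be dispatched through `T::vec_dot`. The rand_normal change is a correctness fix: `rand::distributions::Standard` samples floats uniformly from [0, 1), so `sample * std + mean` never produced Gaussian noise; `rand_distr::Normal` does. A minimal standalone sketch of the new sampling path, assuming the rand 0.8 / rand_distr 0.4 APIs:

```rust
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, Normal};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    // Normal::new validates its parameters and returns an Err for invalid
    // ones; the diff surfaces that error through Error::wrap.
    let normal = Normal::new(0.0f32, 1.0).unwrap();
    // Each call draws one Gaussian sample, unlike Standard, which is uniform.
    let xs: Vec<f32> = (0..4).map(|_| normal.sample(&mut rng)).collect();
    println!("{xs:?}");
}
```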
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
new file mode 100644
index 00000000..187dc16b
--- /dev/null
+++ b/candle-core/src/cpu_kernels.rs
@@ -0,0 +1,28 @@
+pub trait VecDot: num_traits::NumAssign + Copy {
+    /// Dot-product of two vectors.
+    ///
+    /// # Safety
+    ///
+    /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+        *res = Self::zero();
+        for i in 0..len {
+            *res += *lhs.add(i) * *rhs.add(i)
+        }
+    }
+}
+
+impl VecDot for f32 {
+    #[inline(always)]
+    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+        ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
+    }
+}
+
+impl VecDot for f64 {}
+impl VecDot for half::bf16 {}
+impl VecDot for half::f16 {}
+impl VecDot for u8 {}
+impl VecDot for u32 {}
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 92929748..5d24b08f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -54,7 +54,13 @@ impl DType {
 }
 
 pub trait WithDType:
-    Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
+    Sized
+    + Copy
+    + num_traits::NumAssign
+    + std::cmp::PartialOrd
+    + std::fmt::Display
+    + 'static
+    + crate::cpu_kernels::VecDot
 {
     const DTYPE: DType;
 
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
 
+    /// Adding path information to an error.
+    #[error("path: {path:?} {inner}")]
+    WithPath {
+        inner: Box<Self>,
+        path: std::path::PathBuf,
+    },
+
     #[error("{inner}\n{backtrace}")]
     WithBacktrace {
         inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
             },
         }
     }
+
+    pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+        Self::WithPath {
+            inner: Box::new(self),
+            path: p.as_ref().to_path_buf(),
+        }
+    }
 }
 
 #[macro_export]
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 016d3806..aba88135 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -40,6 +40,7 @@ pub mod backprop;
 mod conv;
 mod convert;
 pub mod cpu_backend;
+pub mod cpu_kernels;
 #[cfg(feature = "cuda")]
 pub mod cuda_backend;
 mod device;
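The new `VecDot` trait ships a scalar default implementation plus a SIMD-backed override for `f32` via `ggblas::ggml::vec_dot_f32`; adding it as a supertrait of `WithDType` is what lets the generic `Conv1D::f` above call `T::vec_dot` for every dtype. A safe wrapper over the unsafe entry point could look like the sketch below; `dot` is a hypothetical helper, not part of the diff (assumes `candle-core` and `num-traits` as dependencies):

```rust
use candle_core::cpu_kernels::VecDot;
use num_traits::Zero;

// Hypothetical safe wrapper: clamping to the shorter slice establishes the
// documented contract (both pointers valid for `len` reads).
fn dot<T: VecDot>(lhs: &[T], rhs: &[T]) -> T {
    let len = lhs.len().min(rhs.len());
    let mut res = T::zero();
    // SAFETY: `lhs` and `rhs` are valid for `len` elements and `res` points
    // to a valid element, as the trait's safety section requires.
    unsafe { T::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut res, len) };
    res
}

fn main() {
    // f32 routes through the ggblas kernel; the other dtypes fall back to
    // the scalar loop from the trait's default implementation.
    assert_eq!(dot(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]), 32.0);
}
```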
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
         header.push('\n');
         f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
         f.write_all(header.as_bytes())?;
-        let elem_count = self.elem_count();
+        let vs = self.flatten_all()?;
         match self.dtype() {
             DType::BF16 => {
-                let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+                let vs = vs.to_vec1::<bf16>()?;
                 for &v in vs.reinterpret_cast() {
                     f.write_u16::<LittleEndian>(v)?
                 }
             }
             DType::F16 => {
-                let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+                let vs = vs.to_vec1::<f16>()?;
                 for &v in vs.reinterpret_cast() {
                     f.write_u16::<LittleEndian>(v)?
                 }
             }
             DType::F32 => {
                 // TODO: Avoid using a buffer when data is already on the CPU.
-                for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+                for v in vs.to_vec1::<f32>()? {
                     f.write_f32::<LittleEndian>(v)?
                 }
             }
             DType::F64 => {
-                for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+                for v in vs.to_vec1::<f64>()? {
                     f.write_f64::<LittleEndian>(v)?
                 }
             }
             DType::U32 => {
-                for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+                for v in vs.to_vec1::<u32>()? {
                     f.write_u32::<LittleEndian>(v)?
                 }
             }
             DType::U8 => {
-                let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
-                f.write_all(&data)?;
+                let vs = vs.to_vec1::<u8>()?;
+                f.write_all(&vs)?;
             }
         }
         Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
     index_per_name: HashMap<String, usize>,
     path: std::path::PathBuf,
     // We do not store a zip reader as it needs mutable access to extract data. Instead we
-    // re-create a zip reader each time.
+    // re-create a zip reader for each tensor.
 }
 
 impl NpzTensors {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 132fb914..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
     Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
 }
 
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+    path: std::path::PathBuf,
+    inner: memmap2::Mmap,
+}
 
 impl MmapedFile {
     /// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -267,13 +270,20 @@ impl MmapedFile {
     ///
     /// The unsafe is inherited from [`memmap2::MmapOptions`].
     pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
-        let file = std::fs::File::open(p)?;
-        let mmap = memmap2::MmapOptions::new().map(&file)?;
-        Ok(Self(mmap))
+        let p = p.as_ref();
+        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+        let inner = memmap2::MmapOptions::new()
+            .map(&file)
+            .map_err(|e| Error::from(e).with_path(p))?;
+        Ok(Self {
+            inner,
+            path: p.to_path_buf(),
+        })
     }
 
     pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
-        let st = safetensors::SafeTensors::deserialize(&self.0)?;
+        let st = safetensors::SafeTensors::deserialize(&self.inner)
+            .map_err(|e| Error::from(e).with_path(&self.path))?;
         Ok(st)
     }
 }
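The npy.rs and safetensors.rs hunks lean on the additions above: saving flattens once via `flatten_all` instead of reshaping per dtype, and `MmapedFile` now keeps its path so that open, map, and deserialize failures all report the file involved through `Error::with_path`. The same pattern applies to any fallible file operation; in this sketch, `open_with_context` is an illustrative helper, not part of the diff:

```rust
use candle_core::Error;

// Illustrative helper mirroring what MmapedFile::new does in the diff:
// wrap the underlying I/O error with the offending path.
fn open_with_context(p: &std::path::Path) -> Result<std::fs::File, Error> {
    std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))
}

fn main() {
    // On failure, the Display output is prefixed with `path: "..."` as per
    // the #[error("path: {path:?} {inner}")] attribute on Error::WithPath.
    if let Err(e) = open_with_context(std::path::Path::new("/no/such/file")) {
        eprintln!("{e}");
    }
}
```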