Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/cpu_backend.rs   66
-rw-r--r--  candle-core/src/cpu_kernels.rs   28
-rw-r--r--  candle-core/src/dtype.rs          8
-rw-r--r--  candle-core/src/error.rs         14
-rw-r--r--  candle-core/src/lib.rs            1
-rw-r--r--  candle-core/src/npy.rs           18
-rw-r--r--  candle-core/src/safetensors.rs   20
7 files changed, 109 insertions, 46 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 238a9a69..250e2721 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
- fn f<T: 'static + num_traits::NumAssign + Copy>(
- &self,
- inp: &[T],
- inp_l: &Layout,
- k: &[T],
- k_l: &Layout,
- ) -> Result<Vec<T>> {
- // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
@@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_s0;
- let dst_idx = b_idx * p.c_out * l_out;
+ for src_l in 0..p.l_in {
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
+ inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
+ }
+ }
+ }
+ for offset in 0..p.k_size {
for dst_c_idx in 0..p.c_out {
- let dst_idx = dst_idx + dst_c_idx * l_out;
- for dst_l in 0..l_out {
- let dst_idx = dst_idx + dst_l;
- let mut d = T::zero();
- for offset in 0..p.k_size {
+ let dst_idx = dst_c_idx * l_out;
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ let dst_idx = dst_idx + b_idx * p.c_out * l_out;
+ for dst_l in 0..l_out {
+ let dst_idx = dst_idx + dst_l;
let src_l = (p.stride * dst_l + offset)
.saturating_sub(p.padding)
.min(p.l_in - 1);
- for src_c_idx in 0..p.c_in {
- let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
- let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
- d += inp[inp_idx] * k[k_idx]
- }
+ let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
+ assert!(inp_cont.len() >= p.c_in);
+ assert!(k_cont.len() >= p.c_in);
+ let mut d = T::zero();
+ unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
+ dst[dst_idx] += d
}
- dst[dst_idx] = d
}
}
}
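
The rewrite above first copies the strided input into a contiguous [b_size, l_in, c_in] buffer so the per-channel reduction runs over adjacent memory and can be handed to vec_dot. A minimal sketch of that layout trick in isolation (names are illustrative, not the crate's API):

    fn make_channels_innermost(
        inp: &[f32],
        (b_size, c_in, l_in): (usize, usize, usize),
        (s0, s1, s2): (usize, usize, usize), // input strides for (batch, channel, length)
    ) -> Vec<f32> {
        let mut cont = vec![0f32; b_size * l_in * c_in];
        for b in 0..b_size {
            for l in 0..l_in {
                for c in 0..c_in {
                    // Channels end up innermost, matching the kernel slice
                    // gathered per (c_out, offset) pair in the loop above.
                    cont[b * l_in * c_in + l * c_in + c] = inp[b * s0 + c * s1 + l * s2];
                }
            }
        }
        cont
    }
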
@@ -2070,35 +2073,36 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = bf16::from_f64(std);
- let mean = bf16::from_f64(mean);
+ let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = f16::from_f64(std);
- let mean = f16::from_f64(mean);
+ let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let std = std as f32;
- let mean = mean as f32;
+ let normal =
+ rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
+ let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}
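
Replacing the hand-rolled `sample(Standard) * std + mean` with rand_distr::Normal also turns invalid parameters into a proper error instead of silently producing garbage. A standalone sketch, assuming rand 0.8 / rand_distr 0.4 as the calls above suggest:

    use rand_distr::{Distribution, Normal};

    fn sample_normal(mean: f64, std: f64, n: usize) -> Result<Vec<f64>, rand_distr::NormalError> {
        // Normal::new rejects a negative or NaN standard deviation.
        let normal = Normal::new(mean, std)?;
        let mut rng = rand::thread_rng();
        Ok((0..n).map(|_| normal.sample(&mut rng)).collect())
    }
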
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
new file mode 100644
index 00000000..187dc16b
--- /dev/null
+++ b/candle-core/src/cpu_kernels.rs
@@ -0,0 +1,28 @@
+pub trait VecDot: num_traits::NumAssign + Copy {
+ /// Dot-product of two vectors.
+ ///
+ /// # Safety
+ ///
+ /// The lengths of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ *res = Self::zero();
+ for i in 0..len {
+ *res += *lhs.add(i) * *rhs.add(i)
+ }
+ }
+}
+
+impl VecDot for f32 {
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
+ }
+}
+
+impl VecDot for f64 {}
+impl VecDot for half::bf16 {}
+impl VecDot for half::f16 {}
+impl VecDot for u8 {}
+impl VecDot for u32 {}
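
`vec_dot` has a scalar default body, so only types with a fast path (here f32, routed to ggblas) need to override it; the remaining impls stay empty. A hedged sketch of a safe wrapper that upholds the documented length contract (the module path is taken from the lib.rs change below):

    // Illustrative helper, not part of the diff: checks the preconditions
    // that the unsafe contract of `vec_dot` requires before calling it.
    fn dot<T: candle_core::cpu_kernels::VecDot>(lhs: &[T], rhs: &[T]) -> T {
        assert_eq!(lhs.len(), rhs.len());
        let mut res = T::zero();
        // Safe: both slices are exactly `len` elements long and `res` is valid.
        unsafe { T::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut res, lhs.len()) };
        res
    }
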
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 92929748..5d24b08f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -54,7 +54,13 @@ impl DType {
}
pub trait WithDType:
- Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
+ Sized
+ + Copy
+ + num_traits::NumAssign
+ + std::cmp::PartialOrd
+ + std::fmt::Display
+ + 'static
+ + crate::cpu_kernels::VecDot
{
const DTYPE: DType;
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ /// Adding path information to an error.
+ #[error("path: {path:?} {inner}")]
+ WithPath {
+ inner: Box<Self>,
+ path: std::path::PathBuf,
+ },
+
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
},
}
}
+
+ pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+ Self::WithPath {
+ inner: Box::new(self),
+ path: p.as_ref().to_path_buf(),
+ }
+ }
}
#[macro_export]
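
The new `with_path` helper wraps any `Error` in the `WithPath` variant so I/O failures report which file they came from; the safetensors changes below use exactly this pattern. A hypothetical call site:

    // Hypothetical helper mirroring the map_err pattern in safetensors.rs below.
    fn read_weights(p: &std::path::Path) -> Result<Vec<u8>, candle_core::Error> {
        std::fs::read(p).map_err(|e| candle_core::Error::from(e).with_path(p))
    }
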
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 016d3806..aba88135 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -40,6 +40,7 @@ pub mod backprop;
mod conv;
mod convert;
pub mod cpu_backend;
+pub mod cpu_kernels;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod device;
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let elem_count = self.elem_count();
+ let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+ let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+ let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
- for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+ for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
- for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+ for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
- for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+ for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
- let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
- f.write_all(&data)?;
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
}
}
Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
- // re-create a zip reader each time.
+ // re-create a zip reader for each tensor.
}
impl NpzTensors {
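
`flatten_all` returns a rank-1 tensor whatever the original shape, which is what the repeated `reshape(elem_count)` spelled out by hand. A hedged equivalence sketch:

    use candle_core::{DType, Device, Tensor};

    fn demo() -> candle_core::Result<()> {
        let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
        let flat = t.flatten_all()?; // shape [6]
        let manual = t.reshape(t.elem_count())?; // the pattern the diff replaces
        assert_eq!(flat.dims(), manual.dims());
        Ok(())
    }
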
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 132fb914..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+ path: std::path::PathBuf,
+ inner: memmap2::Mmap,
+}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -267,13 +270,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
- let file = std::fs::File::open(p)?;
- let mmap = memmap2::MmapOptions::new().map(&file)?;
- Ok(Self(mmap))
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let inner = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok(Self {
+ inner,
+ path: p.to_path_buf(),
+ })
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
- let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ let st = safetensors::SafeTensors::deserialize(&self.inner)
+ .map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}
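
With the stored path, both the mapping step and `deserialize` can annotate their errors. A hedged usage sketch, assuming the `candle_core::safetensors` module path and the upstream `SafeTensors::names` accessor (`new` stays unsafe because the file must not be mutated while mapped):

    fn list_tensors(p: &std::path::Path) -> candle_core::Result<()> {
        // Safety: the caller must guarantee the file is not modified while mapped.
        let mmap = unsafe { candle_core::safetensors::MmapedFile::new(p)? };
        let st = mmap.deserialize()?; // errors now carry the file path
        for name in st.names() {
            println!("{name}");
        }
        Ok(())
    }
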