diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backend.rs | 1 | ||||
-rw-r--r-- | candle-core/src/backprop.rs | 2 | ||||
-rw-r--r-- | candle-core/src/conv.rs | 7 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 116 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 6 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/error.rs | 14 | ||||
-rw-r--r-- | candle-core/src/npy.rs | 18 | ||||
-rw-r--r-- | candle-core/src/op.rs | 7 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 26 | ||||
-rw-r--r-- | candle-core/src/shape.rs | 27 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 18 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 29 |
13 files changed, 196 insertions, 79 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a8e5ac52..4c31ca6f 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -46,6 +46,7 @@ pub trait BackendStorage: Sized { ) -> Result<Self>; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 0eab508e..2a60fe30 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -88,6 +88,7 @@ impl Tensor { Op::Reshape(node) | Op::UpsampleNearest2D(node) | Op::AvgPool2D { arg: node, .. } + | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -172,6 +173,7 @@ impl Tensor { Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?, Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 30799459..e3fea861 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,6 +1,6 @@ #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { - pub(crate) b_size: Option<usize>, + pub(crate) b_size: usize, // Maybe we should have a version without l_in as this bit depends on the input and not only on // the weights. pub(crate) l_in: usize, @@ -19,10 +19,7 @@ impl ParamsConv1D { pub(crate) fn out_dims(&self) -> Vec<usize> { let l_out = self.l_out(); - match self.b_size { - None => vec![self.c_out, l_out], - Some(n) => vec![n, self.c_out, l_out], - } + vec![self.b_size, self.c_out, l_out] } } diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 0ec19559..d4f5fcdc 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -660,6 +660,8 @@ impl Map1 for AvgPool2D { let mut sum = T::zero(); for m in 0..k_h { for n in 0..k_w { + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; sum += src[src_index + m * stride_h + n * stride_w] } } @@ -672,6 +674,48 @@ impl Map1 for AvgPool2D { } } +struct MaxPool2D((usize, usize), (usize, usize)); + +impl Map1 for MaxPool2D { + fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html + let (k_h, k_w) = self.0; + let (s_h, s_w) = self.1; + let (b_sz, c, h, w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let h_out = (h - k_h) / s_h + 1; + let w_out = (w - k_w) / s_w + 1; + let src_index = layout.start_offset(); + let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * h_out * w_out..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * h_out * w_out..]; + let src_index = src_index + c_idx * stride[1]; + for h_idx in 0..h_out { + for w_idx in 0..w_out { + let mut largest = + src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w]; + for m in 0..k_h { + for n in 0..k_w { + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; + if largest < src[src_index + m * stride_h + n * stride_w] { + largest = src[src_index + m * stride_h + n * stride_w] + } + } + } + dst[h_idx * w_out + w_idx] = largest; + } + } + } + } + Ok(dst) + } +} + struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { @@ -990,19 +1034,14 @@ impl<'a> Map2 for Conv1D<'a> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; let k = &k[k_l.start_offset()..]; - let inp_stride = inp_l.stride(); - let (inp_stride0, inp_stride) = if inp_stride.len() == 3 { - (inp_stride[0], &inp_stride[1..]) - } else { - (0, inp_stride) // This value never gets used anyway - }; - let k_stride = k_l.stride(); + let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; + let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let l_out = p.l_out(); - let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); + let dst_elems = p.c_out * l_out * p.b_size; let mut dst = vec![T::zero(); dst_elems]; // The output shape is [b_size, c_out, l_out] - for b_idx in 0..p.b_size.unwrap_or(1) { - let inp_idx = b_idx * inp_stride0; + for b_idx in 0..p.b_size { + let inp_idx = b_idx * inp_s0; let dst_idx = b_idx * p.c_out * l_out; for dst_c_idx in 0..p.c_out { let dst_idx = dst_idx + dst_c_idx * l_out; @@ -1014,11 +1053,8 @@ impl<'a> Map2 for Conv1D<'a> { .saturating_sub(p.padding) .min(p.l_in - 1); for src_c_idx in 0..p.c_in { - let inp_idx = - inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset * k_stride[2]; + let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2; + let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2; d += inp[inp_idx] * k[k_idx] } } @@ -1043,14 +1079,14 @@ impl<'a> Map2 for Conv2D<'a> { ) -> Result<Vec<T>> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; - let inp_stride = inp_l.stride(); + let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; let k = &k[k_l.start_offset()..]; - let k_stride = k_l.stride(); + let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; let (out_h, out_w) = (p.out_h(), p.out_w()); let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; for b_idx in 0..p.b_size { - let inp_idx = b_idx * inp_stride[0]; + let inp_idx = b_idx * inp_s0; let dst_idx = b_idx * p.c_out * out_h * out_w; for dst_c_idx in 0..p.c_out { let dst_idx = dst_idx + dst_c_idx * out_h * out_w; @@ -1069,13 +1105,13 @@ impl<'a> Map2 for Conv2D<'a> { .min(p.i_w - 1); for src_c_idx in 0..p.c_in { let inp_idx = inp_idx - + src_c_idx * inp_stride[1] - + src_h * inp_stride[2] - + src_w * inp_stride[3]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset_h * k_stride[2] - + offset_w * k_stride[3]; + + src_c_idx * inp_s1 + + src_h * inp_s2 + + src_w * inp_s3; + let k_idx = dst_c_idx * k_s0 + + src_c_idx * k_s1 + + offset_h * k_s2 + + offset_w * k_s3; d += inp[inp_idx] * k[k_idx] } } @@ -1670,6 +1706,15 @@ impl BackendStorage for CpuStorage { AvgPool2D(kernel_size, stride).map(self, layout) } + fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result<Self> { + MaxPool2D(kernel_size, stride).map(self, layout) + } + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { UpsampleNearest2D(h, w).map(self, layout) } @@ -2025,35 +2070,36 @@ impl BackendDevice for CpuDevice { DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()), DType::BF16 => { let mut data = Vec::with_capacity(elem_count); - let std = bf16::from_f64(std); - let mean = bf16::from_f64(mean); + let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std)) + .map_err(Error::wrap)?; for _i in 0..elem_count { - data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean) + data.push(normal.sample(&mut rng)) } Ok(CpuStorage::BF16(data)) } DType::F16 => { let mut data = Vec::with_capacity(elem_count); - let std = f16::from_f64(std); - let mean = f16::from_f64(mean); + let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std)) + .map_err(Error::wrap)?; for _i in 0..elem_count { - data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean) + data.push(normal.sample(&mut rng)) } Ok(CpuStorage::F16(data)) } DType::F32 => { let mut data = Vec::with_capacity(elem_count); - let std = std as f32; - let mean = mean as f32; + let normal = + rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?; for _i in 0..elem_count { - data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean) + data.push(normal.sample(&mut rng)) } Ok(CpuStorage::F32(data)) } DType::F64 => { let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?; for _i in 0..elem_count { - data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean) + data.push(normal.sample(&mut rng)) } Ok(CpuStorage::F64(data)) } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 727ea073..a7f63353 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> { let dims = shape.dims(); let el = shape.elem_count(); let l_out = p.l_out(); - let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1); + let dst_el = p.c_out * l_out * p.b_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?; // SAFETY: Set later by running the kernel. @@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage { todo!() } + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { + todo!() + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { todo!() } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index ae4dd09f..870a87cd 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 35a33032..c18b43c6 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -185,6 +185,13 @@ pub enum Error { #[error(transparent)] Wrapped(Box<dyn std::error::Error + Send + Sync>), + /// Adding path information to an error. + #[error("path: {path:?} {inner}")] + WithPath { + inner: Box<Self>, + path: std::path::PathBuf, + }, + #[error("{inner}\n{backtrace}")] WithBacktrace { inner: Box<Self>, @@ -214,6 +221,13 @@ impl Error { }, } } + + pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self { + Self::WithPath { + inner: Box::new(self), + path: p.as_ref().to_path_buf(), + } + } } #[macro_export] diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 6302cf71..e17ba02a 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -307,39 +307,39 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let elem_count = self.elem_count(); + let vs = self.flatten_all()?; match self.dtype() { DType::BF16 => { - let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?; + let vs = vs.to_vec1::<bf16>()?; for &v in vs.reinterpret_cast() { f.write_u16::<LittleEndian>(v)? } } DType::F16 => { - let vs = self.reshape(elem_count)?.to_vec1::<f16>()?; + let vs = vs.to_vec1::<f16>()?; for &v in vs.reinterpret_cast() { f.write_u16::<LittleEndian>(v)? } } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. - for v in self.reshape(elem_count)?.to_vec1::<f32>()? { + for v in vs.to_vec1::<f32>()? { f.write_f32::<LittleEndian>(v)? } } DType::F64 => { - for v in self.reshape(elem_count)?.to_vec1::<f64>()? { + for v in vs.to_vec1::<f64>()? { f.write_f64::<LittleEndian>(v)? } } DType::U32 => { - for v in self.reshape(elem_count)?.to_vec1::<u32>()? { + for v in vs.to_vec1::<u32>()? { f.write_u32::<LittleEndian>(v)? } } DType::U8 => { - let data = self.reshape(elem_count)?.to_vec1::<u8>()?; - f.write_all(&data)?; + let vs = vs.to_vec1::<u8>()?; + f.write_all(&vs)?; } } Ok(()) @@ -373,7 +373,7 @@ pub struct NpzTensors { index_per_name: HashMap<String, usize>, path: std::path::PathBuf, // We do not store a zip reader as it needs mutable access to extract data. Instead we - // re-create a zip reader each time. + // re-create a zip reader for each tensor. } impl NpzTensors { diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index aea8b733..f99d8adc 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -93,6 +93,13 @@ pub enum Op { kernel_size: (usize, usize), stride: (usize, usize), }, + + MaxPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), Cat(Vec<Tensor>, usize), diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 1880a041..914e5101 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> { pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> { let data = std::fs::read(filename.as_ref())?; - let st = safetensors::SafeTensors::deserialize(&data)?; + load_buffer(&data[..], device) +} + +pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> { + let st = safetensors::SafeTensors::deserialize(data)?; st.tensors() .into_iter() .map(|(name, view)| Ok((name, view.load(device)?))) @@ -253,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) } -pub struct MmapedFile(memmap2::Mmap); +pub struct MmapedFile { + path: std::path::PathBuf, + inner: memmap2::Mmap, +} impl MmapedFile { /// Creates a wrapper around a memory mapped file from which you can retrieve @@ -263,13 +270,20 @@ impl MmapedFile { /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { - let file = std::fs::File::open(p)?; - let mmap = memmap2::MmapOptions::new().map(&file)?; - Ok(Self(mmap)) + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let inner = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + Ok(Self { + inner, + path: p.to_path_buf(), + }) } pub fn deserialize(&self) -> Result<SafeTensors<'_>> { - let st = safetensors::SafeTensors::deserialize(&self.0)?; + let st = safetensors::SafeTensors::deserialize(&self.inner) + .map_err(|e| Error::from(e).with_path(&self.path))?; Ok(st) } } diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index a5e21aad..83d11c09 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape { macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + pub fn $fn_name(dims: &[usize]) -> Result<$out_type> { + if dims.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: dims.len(), + shape: Shape::from(dims), + } + .bt()) + } else { + Ok($dims(dims)) + } + } + impl Shape { pub fn $fn_name(&self) -> Result<$out_type> { - if self.0.len() != $cnt { - Err(Error::UnexpectedNumberOfDims { - expected: $cnt, - got: self.0.len(), - shape: self.clone(), - } - .bt()) - } else { - Ok($dims(&self.0)) - } + $fn_name(self.0.as_slice()) } } + impl crate::Tensor { pub fn $fn_name(&self) -> Result<$out_type> { self.shape().$fn_name() @@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) { } } -extract_dims!(dims0, 0, |_: &Vec<usize>| (), ()); +extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!( diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 3ed38e6a..791b65dd 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -311,6 +311,24 @@ impl Storage { } } + pub(crate) fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { match self { Storage::Cpu(storage) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index adba7376..c14a4e39 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -773,18 +773,7 @@ impl Tensor { /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = match *self.dims() { - [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), - [c_in, l_in] => (None, c_in, l_in), - _ => Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "input rank is not 2 or 3", - } - .bt())?, - }; + let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { Err(Error::Conv1dInvalidArgs { inp_shape: self.shape().clone(), @@ -872,6 +861,22 @@ impl Tensor { Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) } + pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .max_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments |