summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/examples/conv1d_benchmark.rs24
-rw-r--r--candle-core/src/backend.rs1
-rw-r--r--candle-core/src/backprop.rs2
-rw-r--r--candle-core/src/conv.rs7
-rw-r--r--candle-core/src/cpu_backend.rs116
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
-rw-r--r--candle-core/src/error.rs14
-rw-r--r--candle-core/src/npy.rs18
-rw-r--r--candle-core/src/op.rs7
-rw-r--r--candle-core/src/safetensors.rs26
-rw-r--r--candle-core/src/shape.rs27
-rw-r--r--candle-core/src/storage.rs18
-rw-r--r--candle-core/src/tensor.rs29
-rw-r--r--candle-core/tests/pool_tests.rs61
-rw-r--r--candle-core/tests/tensor_tests.rs11
17 files changed, 293 insertions, 79 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index af77a0e0..7411592e 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -22,6 +22,7 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
+rand_distr = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
diff --git a/candle-core/examples/conv1d_benchmark.rs b/candle-core/examples/conv1d_benchmark.rs
new file mode 100644
index 00000000..52fae5e8
--- /dev/null
+++ b/candle-core/examples/conv1d_benchmark.rs
@@ -0,0 +1,24 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle_core::{Device, Tensor};
+
+pub const N_ITERS: usize = 5;
+
+fn main() -> Result<()> {
+ let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
+ let res = inp.conv1d(&w, 0, 1);
+ println!("{res:?}");
+ let start = std::time::Instant::now();
+ for i in 0..N_ITERS {
+ let res = inp.conv1d(&w, 0, 1);
+ println!("{i} {res:?}");
+ }
+ println!("{:?}", start.elapsed() / N_ITERS as u32);
+ Ok(())
+}
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index a8e5ac52..4c31ca6f 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -46,6 +46,7 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 0eab508e..2a60fe30 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -88,6 +88,7 @@ impl Tensor {
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
+ | Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -172,6 +173,7 @@ impl Tensor {
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 30799459..e3fea861 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
- pub(crate) b_size: Option<usize>,
+ pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@@ -19,10 +19,7 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
- match self.b_size {
- None => vec![self.c_out, l_out],
- Some(n) => vec![n, self.c_out, l_out],
- }
+ vec![self.b_size, self.c_out, l_out]
}
}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0ec19559..d4f5fcdc 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
@@ -672,6 +674,48 @@ impl Map1 for AvgPool2D {
}
}
+struct MaxPool2D((usize, usize), (usize, usize));
+
+impl Map1 for MaxPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut largest =
+ src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
+ if largest < src[src_index + m * stride_h + n * stride_w] {
+ largest = src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ }
+ dst[h_idx * w_out + w_idx] = largest;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -990,19 +1034,14 @@ impl<'a> Map2 for Conv1D<'a> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
- let inp_stride = inp_l.stride();
- let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
- (inp_stride[0], &inp_stride[1..])
- } else {
- (0, inp_stride) // This value never gets used anyway
- };
- let k_stride = k_l.stride();
+ let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+ let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
- let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
- for b_idx in 0..p.b_size.unwrap_or(1) {
- let inp_idx = b_idx * inp_stride0;
+ for b_idx in 0..p.b_size {
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
@@ -1014,11 +1053,8 @@ impl<'a> Map2 for Conv1D<'a> {
.saturating_sub(p.padding)
.min(p.l_in - 1);
for src_c_idx in 0..p.c_in {
- let inp_idx =
- inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset * k_stride[2];
+ let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
+ let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1043,14 +1079,14 @@ impl<'a> Map2 for Conv2D<'a> {
) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
- let inp_stride = inp_l.stride();
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
- let k_stride = k_l.stride();
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_stride[0];
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * out_h * out_w;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
@@ -1069,13 +1105,13 @@ impl<'a> Map2 for Conv2D<'a> {
.min(p.i_w - 1);
for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx
- + src_c_idx * inp_stride[1]
- + src_h * inp_stride[2]
- + src_w * inp_stride[3];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset_h * k_stride[2]
- + offset_w * k_stride[3];
+ + src_c_idx * inp_s1
+ + src_h * inp_s2
+ + src_w * inp_s3;
+ let k_idx = dst_c_idx * k_s0
+ + src_c_idx * k_s1
+ + offset_h * k_s2
+ + offset_w * k_s3;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1670,6 +1706,15 @@ impl BackendStorage for CpuStorage {
AvgPool2D(kernel_size, stride).map(self, layout)
}
+ fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ MaxPool2D(kernel_size, stride).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@@ -2025,35 +2070,36 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = bf16::from_f64(std);
- let mean = bf16::from_f64(mean);
+ let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = f16::from_f64(std);
- let mean = f16::from_f64(mean);
+ let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let std = std as f32;
- let mean = mean as f32;
+ let normal =
+ rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
+ let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 727ea073..a7f63353 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
- let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
@@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage {
todo!()
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index ae4dd09f..870a87cd 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ /// Adding path information to an error.
+ #[error("path: {path:?} {inner}")]
+ WithPath {
+ inner: Box<Self>,
+ path: std::path::PathBuf,
+ },
+
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
},
}
}
+
+ pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+ Self::WithPath {
+ inner: Box::new(self),
+ path: p.as_ref().to_path_buf(),
+ }
+ }
}
#[macro_export]
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let elem_count = self.elem_count();
+ let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+ let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+ let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
- for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+ for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
- for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+ for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
- for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+ for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
- let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
- f.write_all(&data)?;
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
}
}
Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
- // re-create a zip reader each time.
+ // re-create a zip reader for each tensor.
}
impl NpzTensors {
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index aea8b733..f99d8adc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -93,6 +93,13 @@ pub enum Op {
kernel_size: (usize, usize),
stride: (usize, usize),
},
+
+ MaxPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 1880a041..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
- let st = safetensors::SafeTensors::deserialize(&data)?;
+ load_buffer(&data[..], device)
+}
+
+pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
+ let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
@@ -253,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+ path: std::path::PathBuf,
+ inner: memmap2::Mmap,
+}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -263,13 +270,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
- let file = std::fs::File::open(p)?;
- let mmap = memmap2::MmapOptions::new().map(&file)?;
- Ok(Self(mmap))
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let inner = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok(Self {
+ inner,
+ path: p.to_path_buf(),
+ })
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
- let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ let st = safetensors::SafeTensors::deserialize(&self.inner)
+ .map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index a5e21aad..83d11c09 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
+ pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
+ if dims.len() != $cnt {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: $cnt,
+ got: dims.len(),
+ shape: Shape::from(dims),
+ }
+ .bt())
+ } else {
+ Ok($dims(dims))
+ }
+ }
+
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
- if self.0.len() != $cnt {
- Err(Error::UnexpectedNumberOfDims {
- expected: $cnt,
- got: self.0.len(),
- shape: self.clone(),
- }
- .bt())
- } else {
- Ok($dims(&self.0))
- }
+ $fn_name(self.0.as_slice())
}
}
+
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 3ed38e6a..791b65dd 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -311,6 +311,24 @@ impl Storage {
}
}
+ pub(crate) fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index adba7376..c14a4e39 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -773,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = match *self.dims() {
- [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
- [c_in, l_in] => (None, c_in, l_in),
- _ => Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "input rank is not 2 or 3",
- }
- .bt())?,
- };
+ let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
@@ -872,6 +861,22 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
+ pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let (n, c, h, w) = self.dims4()?;
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
+ let h_out = (h - kernel_size.0) / stride.0 + 1;
+ let w_out = (w - kernel_size.1) / stride.1 + 1;
+ let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ });
+ let storage = self
+ .storage()
+ .max_pool2d(self.layout(), kernel_size, stride)?;
+ Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+ }
+
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
new file mode 100644
index 00000000..c8ddef97
--- /dev/null
+++ b/candle-core/tests/pool_tests.rs
@@ -0,0 +1,61 @@
+mod test_utils;
+use candle_core::{Device, Tensor};
+
+// https://github.com/huggingface/candle/issues/364
+#[test]
+fn avg_pool2d() -> anyhow::Result<()> {
+ let data: Vec<f32> = vec![
+ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
+ Ok(())
+}
+
+#[test]
+fn max_pool2d() -> anyhow::Result<()> {
+ let data: Vec<f32> = vec![
+ 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+
+ let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
+ Ok(())
+}
+
+/* This test corresponds to the following PyTorch script.
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 2, 4, 4))
+print(t.flatten())
+res = torch.nn.functional.avg_pool2d(t, 2)
+print(res)
+*/
+#[test]
+fn avg_pool2d_pytorch() -> anyhow::Result<()> {
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+ 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+ 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+ 0.2477, 1.3127,
+ ],
+ &Device::Cpu,
+ )?
+ .reshape((1, 2, 4, 4))?;
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
+ assert_eq!(
+ test_utils::to_vec3_round(pool, 4)?,
+ [
+ [[-1.1926, -0.0395], [0.2688, 0.1871]],
+ [[0.1835, -0.1606], [0.6249, 0.3217]]
+ ]
+ );
+ let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
+ assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
+ Ok(())
+}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 599c2665..0b77f1a5 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -869,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+
+// There was originally a bug on the CPU implementation for randn
+// https://github.com/huggingface/candle/issues/381
+#[test]
+fn randn_hasneg() -> Result<()> {
+ let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
+ if t.iter().all(|&v| v >= 0.) {
+ candle_core::bail!("all values in tensors are non-negative")
+ }
+ Ok(())
+}