summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs84
-rw-r--r--candle-core/tests/pool_tests.rs28
-rw-r--r--candle-kernels/src/conv.cu160
3 files changed, 253 insertions, 19 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 6129e100..90d3ee6d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -960,6 +960,64 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+enum PoolOp {
+ Max,
+ Avg,
+}
+
+struct Pool2D {
+ w_k: usize,
+ h_k: usize,
+ w_stride: usize,
+ h_stride: usize,
+ op: PoolOp,
+}
+
+impl Map1 for Pool2D {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ dev: &CudaDevice,
+ inp_l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ // Kernel shape: (c_out, c_in_k, w_k, h_k)
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let ds = if dims.len() == 4 {
+ [dims, inp_l.stride()].concat()
+ } else {
+ panic!("unexpected input shape for conv1d {dims:?}")
+ };
+ let el = shape.elem_count();
+ let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
+ let out_h = (dims[3] - self.h_k) / self.h_stride + 1;
+ let dst_el = out_w * out_h * dims[0] * dims[1];
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let kname = match self.op {
+ PoolOp::Max => "max_pool2d",
+ PoolOp::Avg => "avg_pool2d",
+ };
+ let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let ds = dev.htod_copy(ds).w()?;
+ let params = (
+ el,
+ self.w_k,
+ self.h_k,
+ self.w_stride,
+ self.h_stride,
+ &ds,
+ inp,
+ &out,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1429,12 +1487,30 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
- todo!()
+ fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let device = self.device().clone();
+ let slice = Pool2D {
+ w_k: k.0,
+ h_k: k.1,
+ w_stride: stride.0,
+ h_stride: stride.1,
+ op: PoolOp::Avg,
+ }
+ .map(&self.slice, &device, l)?;
+ Ok(Self { slice, device })
}
- fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
- todo!()
+ fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let device = self.device().clone();
+ let slice = Pool2D {
+ w_k: k.0,
+ h_k: k.1,
+ w_stride: stride.0,
+ h_stride: stride.1,
+ op: PoolOp::Max,
+ }
+ .map(&self.slice, &device, l)?;
+ Ok(Self { slice, device })
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index 73bf7434..009564fa 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -1,25 +1,22 @@
mod test_utils;
-use candle_core::{Device, IndexOp, Tensor};
+use candle_core::{Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
-#[test]
-fn avg_pool2d() -> anyhow::Result<()> {
+fn avg_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
- let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
-
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
Ok(())
}
-#[test]
-fn max_pool2d() -> anyhow::Result<()> {
+fn max_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
- let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
@@ -35,8 +32,7 @@ print(t.flatten())
res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
-#[test]
-fn avg_pool2d_pytorch() -> anyhow::Result<()> {
+fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@@ -44,7 +40,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127,
],
- &Device::Cpu,
+ dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
@@ -61,7 +57,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
}
#[test]
-fn upsample_nearest2d() -> anyhow::Result<()> {
+fn upsample_nearest2d() -> Result<()> {
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
@@ -79,3 +75,11 @@ fn upsample_nearest2d() -> anyhow::Result<()> {
);
Ok(())
}
+
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(
+ avg_pool2d_pytorch,
+ avg_pool2d_pytorch_cpu,
+ avg_pool2d_pytorch_gpu
+);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 722ca11e..2da4d401 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -24,6 +24,9 @@ __device__ void conv1d(
const size_t c_out = k_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
+ if (dst_i >= src_dims[0] * c_out * l_out) {
+ return;
+ }
// TODO
const size_t b_idx = dst_i / (l_out * c_out);
@@ -61,9 +64,6 @@ __device__ void conv2d(
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
- if (dst_i >= src_numel) {
- return;
- }
// src: (b_size, c_in, w_in, h_in)
// k: (c_out, c_in, w_k, h_k)
const size_t *src_dims = info;
@@ -76,6 +76,9 @@ __device__ void conv2d(
const size_t c_in = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
+ if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
+ return;
+ }
// TODO
const size_t b_idx = dst_i / (w_out * h_out * c_out);
@@ -107,6 +110,116 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
+template <typename T, typename A>
+__device__ void avg_pool2d(
+ const size_t src_numel,
+ const size_t w_k,
+ const size_t h_k,
+ const size_t w_stride,
+ const size_t h_stride,
+ const size_t *info,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, w_in, h_in)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+
+ const size_t c = src_dims[1];
+ const size_t w_in = src_dims[2];
+ const size_t h_in = src_dims[3];
+
+ const size_t w_out = (w_in - w_k) / w_stride + 1;
+ const size_t h_out = (h_in - h_k) / h_stride + 1;
+ if (dst_i >= src_dims[0] * c * w_out * h_out) {
+ return;
+ }
+
+ // TODO: Improve this.
+ const size_t b_idx = dst_i / (w_out * h_out * c);
+ const size_t c_idx = (dst_i / (w_out * h_out)) % c;
+ const size_t dst_w = (dst_i / h_out) % w_out;
+ const size_t dst_h = dst_i % h_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ const float scale = 1.0 / (w_k * h_k);
+ A d = 0;
+ for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
+ size_t src_w = w_stride * dst_w + w_offset;
+ if (src_w >= w_in) {
+ continue;
+ }
+ for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
+ size_t src_h = h_stride * dst_h + h_offset;
+ if (src_h >= h_in) {
+ continue;
+ }
+ const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
+ d += static_cast<A>(src[src_idx]);
+ }
+ }
+ dst[dst_i] = static_cast<T>(d * scale);
+}
+
+template <typename T>
+__device__ void max_pool2d(
+ const size_t src_numel,
+ const size_t w_k,
+ const size_t h_k,
+ const size_t w_stride,
+ const size_t h_stride,
+ const size_t *info,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, w_in, h_in)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+
+ const size_t c = src_dims[1];
+ const size_t w_in = src_dims[2];
+ const size_t h_in = src_dims[3];
+
+ const size_t w_out = (w_in - w_k) / w_stride + 1;
+ const size_t h_out = (h_in - h_k) / h_stride + 1;
+ if (dst_i >= src_dims[0] * c * w_out * h_out) {
+ return;
+ }
+
+ // TODO: Improve this.
+ const size_t b_idx = dst_i / (w_out * h_out * c);
+ const size_t c_idx = (dst_i / (w_out * h_out)) % c;
+ const size_t dst_w = (dst_i / h_out) % w_out;
+ const size_t dst_h = dst_i % h_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ T d = 0;
+ bool set = false;
+ for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
+ size_t src_w = w_stride * dst_w + w_offset;
+ if (src_w >= w_in) {
+ continue;
+ }
+ for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
+ size_t src_h = h_stride * dst_h + h_offset;
+ if (src_h >= h_in) {
+ continue;
+ }
+ const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
+ if (set) {
+ d = maxg(d, src[src_idx]);
+ }
+ else {
+ d = src[src_idx];
+ set = true;
+ }
+ }
+ }
+ dst[dst_i] = d;
+}
+
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -137,14 +250,46 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
} \
+#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t w_k, \
+ const size_t h_k, \
+ const size_t w_stride, \
+ const size_t h_stride, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ avg_pool2d<TYPENAME, TYPEACC>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
+} \
+
+#define MAX_POOL2D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t w_k, \
+ const size_t h_k, \
+ const size_t w_stride, \
+ const size_t h_stride, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
+} \
+
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
+AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
+MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
+AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
+MAX_POOL2D_OP(__half, max_pool2d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -157,3 +302,12 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
+AVG_POOL2D_OP(float, float, avg_pool2d_f32)
+AVG_POOL2D_OP(double, double, avg_pool2d_f64)
+AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
+AVG_POOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
+
+MAX_POOL2D_OP(float, max_pool2d_f32)
+MAX_POOL2D_OP(double, max_pool2d_f64)
+MAX_POOL2D_OP(uint8_t, max_pool2d_u8)
+MAX_POOL2D_OP(uint32_t, max_pool2d_u32)