summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-03-21 13:08:45 -0400
committerGitHub <noreply@github.com>2024-03-21 18:08:45 +0100
commit9563a5fee42f8fef754c238e28ca79725813cea1 (patch)
tree15f8e7bdc192b04da1e4ac7d32a85cf7c912cabb /candle-core
parentec97c98e81707c8f66db6be22d2df7c8791c55b8 (diff)
downloadcandle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.gz
candle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.bz2
candle-9563a5fee42f8fef754c238e28ca79725813cea1.zip
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types * update bench calculation * enable testing all conv operations on metal
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/benches/bench_main.rs3
-rw-r--r--candle-core/benches/benchmarks/conv_transpose2d.rs59
-rw-r--r--candle-core/benches/benchmarks/mod.rs1
-rw-r--r--candle-core/src/metal_backend.rs66
-rw-r--r--candle-core/tests/conv_tests.rs124
5 files changed, 177 insertions, 76 deletions
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 162e3f2b..9f94b252 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -5,5 +5,6 @@ criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
- benchmarks::where_cond::benches
+ benchmarks::where_cond::benches,
+ benchmarks::conv_transpose2d::benches,
);
diff --git a/candle-core/benches/benchmarks/conv_transpose2d.rs b/candle-core/benches/benchmarks/conv_transpose2d.rs
new file mode 100644
index 00000000..7b252ec6
--- /dev/null
+++ b/candle-core/benches/benchmarks/conv_transpose2d.rs
@@ -0,0 +1,59 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(
+ x: &Tensor,
+ k: &Tensor,
+ padding: usize,
+ output_padding: usize,
+ stride: usize,
+ dilation: usize,
+) {
+ x.conv_transpose2d(k, padding, output_padding, stride, dilation)
+ .unwrap();
+}
+
+fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let t = Tensor::arange(0.0f32, 10000.0, device)
+ .unwrap()
+ .reshape((1, 4, 50, 50))
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+
+ let kernel = Tensor::arange(0.0f32, 100.0, device)
+ .unwrap()
+ .reshape((4, 1, 5, 5))
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+
+ let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
+ run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
+ run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index c45effee..a0ffa3eb 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -1,4 +1,5 @@
pub(crate) mod affine;
+pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index c4245652..4f4162e2 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -2,8 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
-use candle_metal_kernels;
use candle_metal_kernels::Kernels;
+use candle_metal_kernels::{self, CallConvTranspose2dCfg};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
@@ -1074,12 +1074,66 @@ impl BackendStorage for MetalStorage {
fn conv_transpose2d(
&self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &ParamsConvTranspose2D,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &ParamsConvTranspose2D,
) -> Result<Self> {
- crate::bail!("Metal conv_tranpose2d not implemented")
+ // Kernel shape: (c_in_k, c_out, h_k, w_k)
+ // Input shape: (b_size, c_in, h_in, w_in)
+ let (out_w, out_h) = (params.out_w(), params.out_h());
+ let dst_el = params.c_out * out_w * out_h * params.b_size;
+
+ let dims = l.dims();
+ if dims.len() != 4 {
+ crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
+ }
+
+ let k_dims = kernel_l.dims();
+ if k_dims.len() != 4 {
+ crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
+ }
+
+ let buffer = self
+ .device
+ .new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
+
+ let command_buffer = self.device.command_buffer()?;
+
+ let name = match self.dtype {
+ DType::F32 => "conv_transpose2d_f32",
+ DType::F16 => "conv_transpose2d_f16",
+ DType::BF16 => "conv_transpose2d_bf16",
+ dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
+ };
+
+ candle_metal_kernels::call_conv_transpose2d(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ CallConvTranspose2dCfg {
+ dilation: params.dilation,
+ stride: params.stride,
+ padding: params.padding,
+ output_padding: params.output_padding,
+ c_out: params.c_out,
+ out_h: out_h,
+ out_w: out_w,
+ b_size: params.b_size,
+ input_dims: l.dims(),
+ input_stride: l.stride(),
+ kernel_dims: kernel_l.dims(),
+ kernel_stride: kernel_l.stride(),
+ input_offset: l.start_offset() * self.dtype.size_in_bytes(),
+ kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
+ },
+ &self.buffer,
+ &kernel.buffer,
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn avg_pool2d(
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 71bf65be..6cc48ec7 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -163,33 +163,34 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
- if !dev.is_metal() {
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
- assert_eq!(res.dims(), [1, 2, 7, 7]);
- assert_eq!(
- test_utils::to_vec3_round(&res.i(0)?, 4)?,
+
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+
+ assert_eq!(res.dims(), [1, 2, 7, 7]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ [
+ [
+ [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+ [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+ [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+ [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+ [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+ [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+ [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+ ],
[
- [
- [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
- [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
- [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
- [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
- [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
- [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
- [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
- ],
- [
- [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
- [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
- [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
- [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
- [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
- [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
- [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
- ]
+ [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+ [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+ [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+ [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+ [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+ [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+ [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
- );
- }
+ ]
+ );
+
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
@@ -198,44 +199,37 @@ fn conv2d(dev: &Device) -> Result<()> {
[2.45, -2.3504],
);
- if !dev.is_metal() {
- // Transpose and dilations.
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
- assert_eq!(res.dims(), [1, 2, 9, 9]);
- assert_eq!(
- test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ // Transpose and dilations.
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
+ assert_eq!(res.dims(), [1, 2, 9, 9]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ [
+ [
+ [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
+ [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
+ [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
+ [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
+ [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
+ [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
+ [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
+ [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
+ [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
+ ],
[
- [
- [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
- [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
- [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
- [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
- [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
- [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
- [
- -2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
- -3.5024
- ],
- [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
- [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
- ],
- [
- [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
- [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
- [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
- [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
- [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
- [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
- [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
- [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
- [
- -5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
- 1.0171
- ]
- ]
+ [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
+ [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
+ [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
+ [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
+ [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
+ [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
+ [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
+ [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
+ [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
]
- );
- }
+ ]
+ );
+
Ok(())
}
@@ -290,11 +284,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
]
);
- // conv-transposes are not implemented for metal
- if dev.is_metal() {
- return Ok(());
- }
-
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
@@ -397,9 +386,6 @@ print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
// conv-transposes are not implemented for metal
- if dev.is_metal() {
- return Ok(());
- }
use candle_core::Var;
let t = Var::from_slice(
&[