summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/benches/bench_main.rs2
-rw-r--r--candle-core/benches/benchmarks/affine.rs43
-rw-r--r--candle-core/benches/benchmarks/mod.rs1
-rw-r--r--candle-core/src/metal_backend.rs2
-rw-r--r--candle-metal-kernels/src/affine.metal14
5 files changed, 54 insertions, 8 deletions
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 92c33a86..4e508a39 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -1,4 +1,4 @@
mod benchmarks;
use criterion::criterion_main;
-criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches);
+criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::where_cond::benches); \ No newline at end of file
diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs
new file mode 100644
index 00000000..eded9f57
--- /dev/null
+++ b/candle-core/benches/benchmarks/affine.rs
@@ -0,0 +1,43 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+ a.affine(12.34, 56.78).unwrap();
+}
+
+fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+
+ let flops = b * m * k * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_affine_benchmark(c, &device, DType::F32, "affine_f32");
+ run_affine_benchmark(c, &device, DType::F16, "affine_f16");
+ run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 4e73ebb6..7dacff5e 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod where_cond;
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 38f909c8..5269a899 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -353,6 +353,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
+ DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
@@ -371,6 +372,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
+ DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine_strided(
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 3d8e7f0d..a4484998 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
using namespace metal;
-#define AFFINE(FN_NAME, TYPENAME) \
+#define AFFINE(FN_NAME, T) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
- device const TYPENAME *input, \
- device TYPENAME *output, \
+ device const T *input, \
+ device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
- output[id] = TYPENAME(float(input[id]) * mul + add); \
+ output[id] = T(fma(float(input[id]), mul, add)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
@@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
- device const TYPENAME *input, \
- device TYPENAME *output, \
+ device const T *input, \
+ device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
- output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
+ output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
}
#define POWF(FN_NAME, TYPENAME) \