author     Nicolas Patry <patry.nicolas@protonmail.com>  2023-11-11 01:02:15 +0100
committer  Nicolas Patry <patry.nicolas@protonmail.com>  2023-11-30 11:30:31 +0100
commit     4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a (patch)
tree       78a6b3533670a33f7bc2f75851fac24307a46fed /candle-metal-kernels/tmp
parent     7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269 (diff)
Starting to fix some tests.

Few fixes.
Going back on remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
All tests are panicking instead of random failure.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new()..to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e98814b65daa596696bda076a73303dd82.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests
Fixing Llama example + accumulate softmax in float.
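Two of the changes above are worth sketching. "cache compute_pipeline_state" refers to building each Metal compute pipeline once and reusing it across calls. A minimal sketch, assuming metal-rs types and a cache keyed by kernel name (the actual Kernels struct in candle-metal-kernels may key and lock differently):

    use std::collections::HashMap;
    use std::sync::RwLock;

    use metal::{ComputePipelineState, Device, Library};

    struct PipelineCache {
        pipelines: RwLock<HashMap<&'static str, ComputePipelineState>>,
    }

    impl PipelineCache {
        fn load(&self, device: &Device, library: &Library, name: &'static str) -> ComputePipelineState {
            if let Some(pipeline) = self.pipelines.read().unwrap().get(name) {
                return pipeline.clone();
            }
            // Building the pipeline state compiles the kernel, which is the
            // expensive step; do it once per kernel name, then reuse it.
            let function = library.get_function(name, None).unwrap();
            let pipeline = device
                .new_compute_pipeline_state_with_function(&function)
                .unwrap();
            self.pipelines.write().unwrap().insert(name, pipeline.clone());
            pipeline
        }
    }

"accumulate softmax in float" means the reductions run in f32 even when the tensor is f16 or bf16, avoiding overflow and precision loss in the sum. A CPU sketch of the same idea for f16:

    use half::f16;

    fn softmax_f16(xs: &[f16]) -> Vec<f16> {
        // Max and sum are accumulated in f32; only the storage is f16.
        let max = xs.iter().map(|x| x.to_f32()).fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = xs.iter().map(|x| (x.to_f32() - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.into_iter().map(|e| f16::from_f32(e / sum)).collect()
    }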
Diffstat (limited to 'candle-metal-kernels/tmp')

-rw-r--r--  candle-metal-kernels/tmp/affine.rs  |  76
-rw-r--r--  candle-metal-kernels/tmp/binary.rs  | 182
-rw-r--r--  candle-metal-kernels/tmp/cast.rs    |  84
-rw-r--r--  candle-metal-kernels/tmp/unary.rs   | 197

4 files changed, 539 insertions(+), 0 deletions(-)
diff --git a/candle-metal-kernels/tmp/affine.rs b/candle-metal-kernels/tmp/affine.rs
new file mode 100644
index 00000000..cd019056
--- /dev/null
+++ b/candle-metal-kernels/tmp/affine.rs
@@ -0,0 +1,76 @@
+use candle_metal_kernels::{call_affine, Kernels};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_affine_bench(&device, &kernels, &f32_1k);
+    run_affine_bench(&device, &kernels, &f32_10k);
+    run_affine_bench(&device, &kernels, &f32_100k);
+}
+
+fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 10000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    let mul: f32 = 1.2345;
+    let add: f32 = 2.3456;
+    let total_time = autoreleasepool(|| {
+        let command_buffer = command_queue.new_command_buffer();
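+        // All iterations are encoded into this single command buffer and
+        // committed once below, so the measured time covers encoding plus a
+        // single GPU submission rather than one round trip per call.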
+        let start = Instant::now();
+        for _ in 0..iterations {
+            call_affine(
+                &device,
+                command_buffer,
+                &kernels,
+                "affine_float",
+                v.len(),
+                &input,
+                &mut output,
+                mul,
+                add,
+            )
+            .unwrap();
+        }
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        start.elapsed()
+    });
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+        type_name::<T>().split("::").last().unwrap(),
+        "affine",
+        v.len(),
+        iterations,
+        total_time,
+        total_time / iterations
+    );
+}
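The bench above only times the kernel and never reads the result back. A minimal sketch of how the output could be checked against a CPU reference for the f32 case, assuming the same metal-rs API used above; check_affine is a hypothetical helper, not part of this commit (on StorageModeManaged buffers the GPU writes must be synchronized with a blit encoder before the CPU reads them):

    fn check_affine(queue: &metal::CommandQueue, output: &metal::Buffer, v: &[f32], mul: f32, add: f32) {
        // Make the GPU writes to the managed buffer visible to the CPU.
        let command_buffer = queue.new_command_buffer();
        let blit = command_buffer.new_blit_command_encoder();
        blit.synchronize_resource(output);
        blit.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // Reinterpret the buffer contents as f32 and compare element-wise.
        let gpu = unsafe { std::slice::from_raw_parts(output.contents() as *const f32, v.len()) };
        for (x, y) in v.iter().zip(gpu) {
            assert!((x * mul + add - y).abs() <= 1e-5);
        }
    }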
diff --git a/candle-metal-kernels/tmp/binary.rs b/candle-metal-kernels/tmp/binary.rs
new file mode 100644
index 00000000..af5a8bdc
--- /dev/null
+++ b/candle-metal-kernels/tmp/binary.rs
@@ -0,0 +1,182 @@
+use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
+use half::{bf16, f16};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+    let f16_1k = f16_map(&f32_1k);
+    let f16_10k = f16_map(&f32_10k);
+    let f16_100k = f16_map(&f32_100k);
+
+    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+    let bf16_1k = bf16_map(&f32_1k);
+    let bf16_10k = bf16_map(&f32_10k);
+    let bf16_100k = bf16_map(&f32_100k);
+
+    let f32_ckernels = [
+        binary::contiguous::add::FLOAT,
+        binary::contiguous::sub::FLOAT,
+        binary::contiguous::mul::FLOAT,
+        binary::contiguous::div::FLOAT,
+    ];
+    let f32_skernels = [
+        binary::strided::add::FLOAT,
+        binary::strided::sub::FLOAT,
+        binary::strided::mul::FLOAT,
+        binary::strided::div::FLOAT,
+    ];
+    let f16_ckernels = [
+        binary::contiguous::add::HALF,
+        binary::contiguous::sub::HALF,
+        binary::contiguous::mul::HALF,
+        binary::contiguous::div::HALF,
+    ];
+    let f16_skernels = [
+        binary::strided::add::HALF,
+        binary::strided::sub::HALF,
+        binary::strided::mul::HALF,
+        binary::strided::div::HALF,
+    ];
+    let bf16_ckernels = [
+        binary::contiguous::add::BFLOAT,
+        binary::contiguous::sub::BFLOAT,
+        binary::contiguous::mul::BFLOAT,
+        binary::contiguous::div::BFLOAT,
+    ];
+    let bf16_skernels = [
+        binary::strided::add::BFLOAT,
+        binary::strided::sub::BFLOAT,
+        binary::strided::mul::BFLOAT,
+        binary::strided::div::BFLOAT,
+    ];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
+    run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
+    run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
+
+    // f16
+    run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
+    run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
+    run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
+
+    // bf16
+    run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
+    run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
+    run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
+}
+
+fn run_binary_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: [binary::contiguous::Kernel; 4],
+    strided: [binary::strided::Kernel; 4],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 1000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_binary_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided
+    // Derive the shape from the input length so the strided reads stay in
+    // bounds (a fixed [2, 5_000] would overrun the buffer for the 1k inputs).
+    let shape = vec![2, v.len() / 2];
+    let strides = vec![2, 1];
+    let offset = 0;
+    for kernel_name in strided {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_binary_strided(
+                    device,
+                    command_buffer,
+                    &kernels,
+                    kernel_name,
+                    &shape,
+                    &input,
+                    &strides,
+                    offset,
+                    &input,
+                    &strides,
+                    offset,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+}
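For reference, the strided variants address element (i0, i1, ...) of a view at linear position offset + i0 * strides[0] + i1 * strides[1] + ... A minimal sketch of that mapping (standard strided-tensor addressing, not code from this commit):

    fn strided_index(idx: &[usize], strides: &[usize], offset: usize) -> usize {
        offset + idx.iter().zip(strides).map(|(i, s)| i * s).sum::<usize>()
    }

With strides [2, 1] the two rows of the view overlap in memory; that is harmless here since the bench only exercises the strided code path rather than a meaningful layout.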
diff --git a/candle-metal-kernels/tmp/cast.rs b/candle-metal-kernels/tmp/cast.rs
new file mode 100644
index 00000000..090f510d
--- /dev/null
+++ b/candle-metal-kernels/tmp/cast.rs
@@ -0,0 +1,84 @@
+use candle_metal_kernels::{call_cast_contiguous, Kernels};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
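+    // NB: the buffers hold random f32 bit patterns while cast_u32_f32 reads
+    // them as u32, which is fine for timing but means the output values are
+    // not meaningful.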
+    let contiguous_kernels = ["cast_u32_f32"];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
+    run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
+    run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
+}
+
+fn run_cast_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: &[&'static str],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 1000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_cast_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.to_string(),
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided?
+}
diff --git a/candle-metal-kernels/tmp/unary.rs b/candle-metal-kernels/tmp/unary.rs
new file mode 100644
index 00000000..66cf25c0
--- /dev/null
+++ b/candle-metal-kernels/tmp/unary.rs
@@ -0,0 +1,197 @@
+use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
+use half::{bf16, f16};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions};
+use rand;
+use std::any::type_name;
+use std::time::Instant;
+
+fn main() {
+    let device = Device::system_default().unwrap();
+    let kernels = Kernels::new();
+
+    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
+    let f32_10k = (0..10000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+    let f32_100k = (0..100000)
+        .map(|_| rand::random::<f32>())
+        .collect::<Vec<_>>();
+
+    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+    let f16_1k = f16_map(&f32_1k);
+    let f16_10k = f16_map(&f32_10k);
+    let f16_100k = f16_map(&f32_100k);
+
+    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+    let bf16_1k = bf16_map(&f32_1k);
+    let bf16_10k = bf16_map(&f32_10k);
+    let bf16_100k = bf16_map(&f32_100k);
+
+    let f32_ckernels = [
+        unary::contiguous::sin::FLOAT,
+        unary::contiguous::cos::FLOAT,
+        unary::contiguous::exp::FLOAT,
+        unary::contiguous::sqr::FLOAT,
+        unary::contiguous::sqrt::FLOAT,
+        unary::contiguous::neg::FLOAT,
+        unary::contiguous::copy::FLOAT,
+    ];
+    let f32_skernels = [
+        unary::strided::sin::FLOAT,
+        unary::strided::cos::FLOAT,
+        unary::strided::exp::FLOAT,
+        unary::strided::sqr::FLOAT,
+        unary::strided::sqrt::FLOAT,
+        unary::strided::neg::FLOAT,
+        unary::strided::copy::FLOAT,
+    ];
+    let f16_ckernels = [
+        unary::contiguous::sin::HALF,
+        unary::contiguous::cos::HALF,
+        unary::contiguous::exp::HALF,
+        unary::contiguous::sqr::HALF,
+        unary::contiguous::sqrt::HALF,
+        unary::contiguous::neg::HALF,
+        unary::contiguous::copy::HALF,
+    ];
+    let f16_skernels = [
+        unary::strided::sin::HALF,
+        unary::strided::cos::HALF,
+        unary::strided::exp::HALF,
+        unary::strided::sqr::HALF,
+        unary::strided::sqrt::HALF,
+        unary::strided::neg::HALF,
+        unary::strided::copy::HALF,
+    ];
+    let bf16_ckernels = [
+        unary::contiguous::sin::BFLOAT,
+        unary::contiguous::cos::BFLOAT,
+        unary::contiguous::exp::BFLOAT,
+        unary::contiguous::sqr::BFLOAT,
+        unary::contiguous::sqrt::BFLOAT,
+        unary::contiguous::neg::BFLOAT,
+        unary::contiguous::copy::BFLOAT,
+    ];
+    let bf16_skernels = [
+        unary::strided::sin::BFLOAT,
+        unary::strided::cos::BFLOAT,
+        unary::strided::exp::BFLOAT,
+        unary::strided::sqr::BFLOAT,
+        unary::strided::sqrt::BFLOAT,
+        unary::strided::neg::BFLOAT,
+        unary::strided::copy::BFLOAT,
+    ];
+
+    println!(
+        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    // f32
+    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
+    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
+    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
+
+    // f16
+    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
+    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
+    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
+
+    // bf16
+    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
+    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
+    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
+}
+
+fn run_unary_bench<T: Clone>(
+    device: &Device,
+    kernels: &Kernels,
+    v: &[T],
+    contiguous: [unary::contiguous::Kernel; 7],
+    strided: [unary::strided::Kernel; 7],
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 10000;
+    let input = device.new_buffer_with_data(
+        v.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(v) as u64,
+        options,
+    );
+    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
+
+    // Contiguous
+    for kernel_name in contiguous {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_unary_contiguous(
+                    device,
+                    &command_buffer,
+                    kernels,
+                    kernel_name,
+                    v.len(),
+                    &input,
+                    &mut output,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.0,
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+
+    // Strided
+    // Derive the shape from the input length so the strided reads stay in
+    // bounds (a fixed [2, 5_000] would overrun the buffer for the 1k inputs).
+    let shape = vec![2, v.len() / 2];
+    let strides = vec![2, 1];
+    let offset = 0;
+    for kernel_name in &strided {
+        let total_time = autoreleasepool(|| {
+            let command_buffer = command_queue.new_command_buffer();
+            let start = Instant::now();
+            for _ in 0..iterations {
+                call_unary_strided(
+                    device,
+                    command_buffer,
+                    &kernels,
+                    kernel_name,
+                    &shape,
+                    &input,
+                    &strides,
+                    offset,
+                    &mut output,
+                    0,
+                )
+                .unwrap();
+            }
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            start.elapsed()
+        });
+
+        println!(
+            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
+            type_name::<T>().split("::").last().unwrap(),
+            kernel_name.0,
+            v.len(),
+            iterations,
+            total_time,
+            total_time / iterations
+        );
+    }
+}
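Each table row reports the total and per-dispatch average wall time. A hypothetical helper, not part of this commit, to turn the average into an effective bandwidth figure for the unary case (one read of the input plus one write of the output per element):

    fn bandwidth_gb_s(elements: usize, elem_size: usize, avg: std::time::Duration) -> f64 {
        let bytes = (elements * elem_size * 2) as f64;
        bytes / avg.as_secs_f64() / 1e9
    }

For example, 100k f32 elements at an average of 100µs works out to 100_000 * 4 * 2 / 1e-4 / 1e9 = 8 GB/s.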