summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-10-02 10:22:31 +0200
committer: GitHub <noreply@github.com> 2024-10-02 10:22:31 +0200
commit: fd08d3d0a40872f207284b008de23ef875d54f74 (patch)
tree: dfe536a82f70591ff6cf1a5b3d6abcc2caf6aeef /candle-metal-kernels
parent: a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (diff)
download: candle-fd08d3d0a40872f207284b008de23ef875d54f74.tar.gz
candle-fd08d3d0a40872f207284b008de23ef875d54f74.tar.bz2
candle-fd08d3d0a40872f207284b008de23ef875d54f74.zip
Tweak some metal tests. (#2528)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 5
-rw-r--r-- candle-metal-kernels/src/tests.rs 80
2 files changed, 23 insertions, 62 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a270bb28..be616009 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -2372,16 +2372,11 @@ pub fn call_const_fill(
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
-
encoder.set_compute_pipeline_state(&pipeline);
-
set_params!(encoder, (output, v, length));
-
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
-
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-
Ok(())
}
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index f37ab5bb..637bf2e2 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() {
assert_eq!(results, expected);
}
-fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
- let dev = device();
- let kernels = Kernels::new();
- let command_queue = dev.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
-
- let buffer = dev.new_buffer(
- (len * std::mem::size_of::<T>()) as u64,
- MTLResourceOptions::StorageModePrivate,
- );
-
- call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- read_to_vec::<T>(&buffer, len)
-}
-
#[test]
fn const_fill() {
- let fills = [
- "fill_u8",
- "fill_u32",
- "fill_i64",
- "fill_f16",
- "fill_bf16",
- "fill_f32",
- ];
-
- for name in fills {
+ fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
+ let dev = device();
+ let kernels = Kernels::new();
+ let command_queue = dev.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let buffer = dev.new_buffer(
+ (len * std::mem::size_of::<T>()) as u64,
+ MTLResourceOptions::StorageModePrivate,
+ );
+ call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ read_to_vec::<T>(&buffer, len)
+ }
+ fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
-
- match name {
- "fill_u8" => {
- let v = constant_fill::<u8>(name, len, value);
- assert_eq!(v, vec![value as u8; len])
- }
- "fill_u32" => {
- let v = constant_fill::<u32>(name, len, value);
- assert_eq!(v, vec![value as u32; len])
- }
- "fill_i64" => {
- let v = constant_fill::<i64>(name, len, value);
- assert_eq!(v, vec![value as i64; len])
- }
- "fill_f16" => {
- let v = constant_fill::<f16>(name, len, value);
- assert_eq!(v, vec![f16::from_f32(value); len])
- }
- "fill_bf16" => {
- let v = constant_fill::<bf16>(name, len, value);
- assert_eq!(v, vec![bf16::from_f32(value); len])
- }
- "fill_f32" => {
- let v = constant_fill::<f32>(name, len, value);
- assert_eq!(v, vec![value; len])
- }
- _ => unimplemented!(),
- };
+ let v = constant_fill::<T>(name, len, value);
+ assert_eq!(v, vec![f(value); len])
}
+ test::<u8, _>("fill_u8", |v| v as u8);
+ test::<u32, _>("fill_u32", |v| v as u32);
+ test::<i64, _>("fill_i64", |v| v as i64);
+ test::<f16, _>("fill_f16", f16::from_f32);
+ test::<bf16, _>("fill_bf16", bf16::from_f32);
+ test::<f32, _>("fill_f32", |v| v);
}