diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 8b1adbde..f37ab5bb 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,7 @@ use super::*; use half::{bf16, f16}; use metal::MTLResourceOptions; +use rand::Rng; fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let ptr = buffer.contents() as *const T; @@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() { let expected = vec![1, 4, 10, 20, 25, 24, 16]; assert_eq!(results, expected); } + +fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> { + let dev = device(); + let kernels = Kernels::new(); + let command_queue = dev.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let buffer = dev.new_buffer( + (len * std::mem::size_of::<T>()) as u64, + MTLResourceOptions::StorageModePrivate, + ); + + call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec::<T>(&buffer, len) +} + +#[test] +fn const_fill() { + let fills = [ + "fill_u8", + "fill_u32", + "fill_i64", + "fill_f16", + "fill_bf16", + "fill_f32", + ]; + + for name in fills { + let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); + let value = rand::thread_rng().gen_range(1. ..19.); + + match name { + "fill_u8" => { + let v = constant_fill::<u8>(name, len, value); + assert_eq!(v, vec![value as u8; len]) + } + "fill_u32" => { + let v = constant_fill::<u32>(name, len, value); + assert_eq!(v, vec![value as u32; len]) + } + "fill_i64" => { + let v = constant_fill::<i64>(name, len, value); + assert_eq!(v, vec![value as i64; len]) + } + "fill_f16" => { + let v = constant_fill::<f16>(name, len, value); + assert_eq!(v, vec![f16::from_f32(value); len]) + } + "fill_bf16" => { + let v = constant_fill::<bf16>(name, len, value); + assert_eq!(v, vec![bf16::from_f32(value); len]) + } + "fill_f32" => { + let v = constant_fill::<f32>(name, len, value); + assert_eq!(v, vec![value; len]) + } + _ => unimplemented!(), + }; + } +} |