summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs65
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8b1adbde..f37ab5bb 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,7 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
+use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}
+
+fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
+ let dev = device();
+ let kernels = Kernels::new();
+ let command_queue = dev.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let buffer = dev.new_buffer(
+ (len * std::mem::size_of::<T>()) as u64,
+ MTLResourceOptions::StorageModePrivate,
+ );
+
+ call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
+
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec::<T>(&buffer, len)
+}
+
+#[test]
+fn const_fill() {
+ let fills = [
+ "fill_u8",
+ "fill_u32",
+ "fill_i64",
+ "fill_f16",
+ "fill_bf16",
+ "fill_f32",
+ ];
+
+ for name in fills {
+ let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
+ let value = rand::thread_rng().gen_range(1. ..19.);
+
+ match name {
+ "fill_u8" => {
+ let v = constant_fill::<u8>(name, len, value);
+ assert_eq!(v, vec![value as u8; len])
+ }
+ "fill_u32" => {
+ let v = constant_fill::<u32>(name, len, value);
+ assert_eq!(v, vec![value as u32; len])
+ }
+ "fill_i64" => {
+ let v = constant_fill::<i64>(name, len, value);
+ assert_eq!(v, vec![value as i64; len])
+ }
+ "fill_f16" => {
+ let v = constant_fill::<f16>(name, len, value);
+ assert_eq!(v, vec![f16::from_f32(value); len])
+ }
+ "fill_bf16" => {
+ let v = constant_fill::<bf16>(name, len, value);
+ assert_eq!(v, vec![bf16::from_f32(value); len])
+ }
+ "fill_f32" => {
+ let v = constant_fill::<f32>(name, len, value);
+ assert_eq!(v, vec![value; len])
+ }
+ _ => unimplemented!(),
+ };
+ }
+}