summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend/mod.rs36
-rw-r--r--candle-core/tests/tensor_tests.rs30
-rw-r--r--candle-metal-kernels/src/fill.metal39
-rw-r--r--candle-metal-kernels/src/lib.rs28
-rw-r--r--candle-metal-kernels/src/tests.rs65
5 files changed, 194 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 69edd2d1..6f560c02 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}
- fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+ let name = match dtype {
+ DType::U8 => "fill_u8",
+ DType::U32 => "fill_u32",
+ DType::I64 => "fill_i64",
+ DType::F16 => "fill_f16",
+ DType::BF16 => "fill_bf16",
+ DType::F32 => "fill_f32",
+ DType::F64 => {
+ let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+ return self.storage_from_cpu_storage(&cpu_storage);
+ }
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_const_fill(
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ shape.elem_count(),
+ &buffer,
+ 1.,
+ )
+ .map_err(MetalError::from)?;
+
+ Ok(MetalStorage::new(
+ buffer,
+ self.clone(),
+ shape.elem_count(),
+ dtype,
+ ))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 4a76035c..e0cea15c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
+ [
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ],
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ]
+ ],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
+ [
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ],
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ]
+ ],
+ );
Ok(())
}
diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal
new file mode 100644
index 00000000..35c3fe7a
--- /dev/null
+++ b/candle-metal-kernels/src/fill.metal
@@ -0,0 +1,39 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T> METAL_FUNC void fill_with(
+ device T *out,
+ constant float &value,
+ constant size_t &numel,
+ uint tid [[thread_position_in_grid]]
+) {
+ if (tid >= numel) {
+ return;
+ }
+ out[tid] = static_cast<T>(value);
+}
+
+#define FILL_OP(NAME, T) \
+kernel void fill_##NAME( \
+ device T *out, \
+ constant float &value, \
+ constant size_t &numel, \
+ uint tid [[thread_position_in_grid]] \
+) { \
+ fill_with<T>(out, value, numel, tid); \
+} \
+
+
+#define FILL_OPS(NAME, T) \
+FILL_OP(NAME, T) \
+
+FILL_OPS(u8, uchar)
+FILL_OPS(u32, uint)
+FILL_OPS(i64, long)
+FILL_OPS(f16, half)
+FILL_OPS(f32, float)
+
+#if __METAL_VERSION__ >= 310
+FILL_OPS(bf16, bfloat)
+#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a595b2bd..a270bb28 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
+const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@@ -31,6 +32,7 @@ pub enum Source {
Binary,
Cast,
Conv,
+ Fill,
Gemm,
Indexing,
Mfa,
@@ -196,6 +198,7 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
+ Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
@@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
Ok(())
}
+pub fn call_const_fill(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ name: &'static str,
+ length: usize,
+ output: &Buffer,
+ v: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (output, v, length));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+
+ Ok(())
+}
+
#[cfg(test)]
mod tests;
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8b1adbde..f37ab5bb 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,7 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
+use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}
+
+fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
+ let dev = device();
+ let kernels = Kernels::new();
+ let command_queue = dev.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let buffer = dev.new_buffer(
+ (len * std::mem::size_of::<T>()) as u64,
+ MTLResourceOptions::StorageModePrivate,
+ );
+
+ call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
+
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec::<T>(&buffer, len)
+}
+
+#[test]
+fn const_fill() {
+ let fills = [
+ "fill_u8",
+ "fill_u32",
+ "fill_i64",
+ "fill_f16",
+ "fill_bf16",
+ "fill_f32",
+ ];
+
+ for name in fills {
+ let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
+ let value = rand::thread_rng().gen_range(1. ..19.);
+
+ match name {
+ "fill_u8" => {
+ let v = constant_fill::<u8>(name, len, value);
+ assert_eq!(v, vec![value as u8; len])
+ }
+ "fill_u32" => {
+ let v = constant_fill::<u32>(name, len, value);
+ assert_eq!(v, vec![value as u32; len])
+ }
+ "fill_i64" => {
+ let v = constant_fill::<i64>(name, len, value);
+ assert_eq!(v, vec![value as i64; len])
+ }
+ "fill_f16" => {
+ let v = constant_fill::<f16>(name, len, value);
+ assert_eq!(v, vec![f16::from_f32(value); len])
+ }
+ "fill_bf16" => {
+ let v = constant_fill::<bf16>(name, len, value);
+ assert_eq!(v, vec![bf16::from_f32(value); len])
+ }
+ "fill_f32" => {
+ let v = constant_fill::<f32>(name, len, value);
+ assert_eq!(v, vec![value; len])
+ }
+ _ => unimplemented!(),
+ };
+ }
+}