diff options
Diffstat (limited to 'candle-core/src/metal_backend/mod.rs')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 69edd2d1..6f560c02 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice { )) } - fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { + let name = match dtype { + DType::U8 => "fill_u8", + DType::U32 => "fill_u32", + DType::I64 => "fill_i64", + DType::F16 => "fill_f16", + DType::BF16 => "fill_bf16", + DType::F32 => "fill_f32", + DType::F64 => { + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + return self.storage_from_cpu_storage(&cpu_storage); + } + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; + let command_buffer = self.command_buffer()?; + candle_metal_kernels::call_const_fill( + &self.device, + &command_buffer, + &self.kernels, + name, + shape.elem_count(), + &buffer, + 1., + ) + .map_err(MetalError::from)?; + + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) } fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> { |