summary refs log tree commit diff
path: root/candle-core/src/metal_backend/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/metal_backend/mod.rs')
-rw-r--r-- candle-core/src/metal_backend/mod.rs 36
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 69edd2d1..6f560c02 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}
- fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+ let name = match dtype {
+ DType::U8 => "fill_u8",
+ DType::U32 => "fill_u32",
+ DType::I64 => "fill_i64",
+ DType::F16 => "fill_f16",
+ DType::BF16 => "fill_bf16",
+ DType::F32 => "fill_f32",
+ DType::F64 => {
+ let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+ return self.storage_from_cpu_storage(&cpu_storage);
+ }
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_const_fill(
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ shape.elem_count(),
+ &buffer,
+ 1.,
+ )
+ .map_err(MetalError::from)?;
+
+ Ok(MetalStorage::new(
+ buffer,
+ self.clone(),
+ shape.elem_count(),
+ dtype,
+ ))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {