diff options
author | Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> | 2024-10-01 22:41:59 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 19:11:59 +0200 |
commit | a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (patch) | |
tree | 79d10359ccc57c6ad31f05f4b0cd8a3513af04ef /candle-core/src/metal_backend | |
parent | def4c6cdeef78e437846efcb46a23006f539dee4 (diff) | |
download | candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.gz candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.bz2 candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.zip |
Efficient implementation of `Tensor::ones()` for `metal` (#2512)
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added check elem count check in kerner, using for call strategy
* rustfmt ran
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 69edd2d1..6f560c02 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice { )) } - fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { + let name = match dtype { + DType::U8 => "fill_u8", + DType::U32 => "fill_u32", + DType::I64 => "fill_i64", + DType::F16 => "fill_f16", + DType::BF16 => "fill_bf16", + DType::F32 => "fill_f32", + DType::F64 => { + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + return self.storage_from_cpu_storage(&cpu_storage); + } + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; + let command_buffer = self.command_buffer()?; + candle_metal_kernels::call_const_fill( + &self.device, + &command_buffer, + &self.kernels, + name, + shape.elem_count(), + &buffer, + 1., + ) + .map_err(MetalError::from)?; + + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) } fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> { |