summary | refs | log | tree | commit | diff
path: root/candle-core/src/metal_backend
diff options
context:
space:
mode:
author Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> 2024-10-01 22:41:59 +0530
committer GitHub <noreply@github.com> 2024-10-01 19:11:59 +0200
commit a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (patch)
tree 79d10359ccc57c6ad31f05f4b0cd8a3513af04ef /candle-core/src/metal_backend
parent def4c6cdeef78e437846efcb46a23006f539dee4 (diff)
download candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.gz
candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.bz2
candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.zip
Efficient implementation of `Tensor::ones()` for `metal` (#2512)
* WIP: hopefully better const impl * with GPU * More tests on * Reverting primitive for * Incorporating review changes - added elem count check in kernel, using for call strategy * rustfmt ran
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r-- candle-core/src/metal_backend/mod.rs 36
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 69edd2d1..6f560c02 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}
- fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+ let name = match dtype {
+ DType::U8 => "fill_u8",
+ DType::U32 => "fill_u32",
+ DType::I64 => "fill_i64",
+ DType::F16 => "fill_f16",
+ DType::BF16 => "fill_bf16",
+ DType::F32 => "fill_f32",
+ DType::F64 => {
+ let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+ return self.storage_from_cpu_storage(&cpu_storage);
+ }
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_const_fill(
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ shape.elem_count(),
+ &buffer,
+ 1.,
+ )
+ .map_err(MetalError::from)?;
+
+ Ok(MetalStorage::new(
+ buffer,
+ self.clone(),
+ shape.elem_count(),
+ dtype,
+ ))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {