summary | refs | log | tree | commit | diff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r-- candle-core/src/metal_backend/mod.rs | 36
-rw-r--r-- candle-core/tests/tensor_tests.rs | 30
2 files changed, 62 insertions, 4 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 69edd2d1..6f560c02 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}
- fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { // Returns a storage of `shape.elem_count()` elements of `dtype`, filled with the constant 1.
+ let name = match dtype { // Select the fill routine name handed to `call_const_fill` below, one per dtype.
+ DType::U8 => "fill_u8",
+ DType::U32 => "fill_u32",
+ DType::I64 => "fill_i64",
+ DType::F16 => "fill_f16",
+ DType::BF16 => "fill_bf16",
+ DType::F32 => "fill_f32",
+ DType::F64 => { // No fill name is mapped for F64 here; presumably there is no f64 fill kernel -- TODO confirm.
+ let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; // Build the ones on the CPU backend instead...
+ return self.storage_from_cpu_storage(&cpu_storage); // ...and convert that CPU storage into this device's storage (early return).
+ }
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; // Request a buffer for `elem_count` elements of `dtype`; "alloc-ones" is the label passed to `new_buffer`.
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_const_fill( // Fill `buffer` with the constant value given as the last argument.
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ shape.elem_count(),
+ &buffer,
+ 1., // The fill value: every element becomes one.
+ )
+ .map_err(MetalError::from)?; // Convert the kernel-call error into `MetalError` before propagating with `?`.
+
+ Ok(MetalStorage::new( // Wrap the filled buffer together with a clone of this device, the element count and the dtype.
+ buffer,
+ self.clone(),
+ shape.elem_count(),
+ dtype,
+ ))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 4a76035c..e0cea15c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
+ [
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ],
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ]
+ ],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
+ [
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ],
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ]
+ ],
+ );
Ok(())
}