diff options
Diffstat (limited to 'candle-core/src/metal_backend.rs')
-rw-r--r-- | candle-core/src/metal_backend.rs | 120 |
1 files changed, 51 insertions, 69 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 597c2f01..27475efe 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -54,10 +54,6 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - // pub fn metal_device(&self) -> &metal::DeviceRef { - // self.device.as_ref() - // } - pub fn id(&self) -> NSUInteger { self.registry_id() } @@ -76,7 +72,6 @@ impl MetalDevice { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - // debug!("Allocate 1 - buffer size {size}"); self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -105,28 +100,22 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result<CpuStorage> { + let length = self.buffer.length() as usize; + let size = self.dtype.size_in_bytes(); + if length % size != 0 { + crate::bail!( + "The Metal buffer length is not aligned with dtype {:?}", + self.dtype + ); + } match self.dtype { - DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize / 1), - )), - DType::U32 => Ok(CpuStorage::U32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::I64 => Ok(CpuStorage::I64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), - DType::F16 => Ok(CpuStorage::F16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::BF16 => Ok(CpuStorage::BF16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::F32 => Ok(CpuStorage::F32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::F64 => Ok(CpuStorage::F64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), + DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), + DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), + DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / 
size))), + DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), + DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), + DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), + DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), } } @@ -137,9 +126,9 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - assert!(layout.is_contiguous()); - assert!(layout.start_offset() == 0); - assert_eq!(dtype, DType::F32); + if !layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 { + crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + } let mut buffer = device.new_buffer(el, self.dtype); let command_buffer = self.device.command_queue.new_command_buffer(); @@ -153,7 +142,7 @@ impl BackendStorage for MetalStorage { mul as f32, add as f32, ) - .unwrap(); + .map_err(MetalError::from)?; command_buffer.commit(); command_buffer.wait_until_completed(); return Ok(Self { @@ -164,18 +153,18 @@ impl BackendStorage for MetalStorage { } fn powf(&self, _: &Layout, _: f64) -> Result<Self> { - todo!() + crate::bail!("powf metal") } fn elu(&self, _: &Layout, _: f64) -> Result<Self> { - todo!() + crate::bail!("elu metal") } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { - assert!(sum_dims.len() == 1); - assert!(sum_dims[0] == layout.shape().rank() - 1); - assert!(layout.is_contiguous()); - assert!(layout.start_offset() == 0); + + if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){ + crate::bail!("Non contiguous reduce op not supported yet"); + } let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); @@ -204,7 +193,7 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), (ReduceOp::ArgMin, DType::F32) => 
("fast_argmin_float", true, true), (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), - _ => todo!("Reduce op for non float"), + _ => crate::bail!("Reduce op for non float"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? @@ -234,7 +223,7 @@ impl BackendStorage for MetalStorage { } fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { - todo!() + crate::bail!("cmp metal") } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { @@ -246,7 +235,7 @@ impl BackendStorage for MetalStorage { if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", - (left, right) => todo!("to dtype {left:?} - {right:?}"), + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), }; candle_metal_kernels::call_cast_contiguous( &device.device, @@ -259,7 +248,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } else { - todo!( + crate::bail!( "TODO Implement the kernel calling cast {:?}-{:?}", self.dtype, dtype @@ -293,7 +282,7 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, - (name, dtype) => todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( &device.device, @@ -306,7 +295,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } else { - todo!("TODO Implement the kernel calling {}", B::KERNEL); + crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); } command_buffer.commit(); command_buffer.wait_until_completed(); @@ -344,7 +333,7 @@ impl BackendStorage for MetalStorage { ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, ("bdiv", DType::F32) => contiguous::div::FLOAT, - (name, dtype) 
=> todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( &device.device, @@ -365,7 +354,7 @@ impl BackendStorage for MetalStorage { ("bsub", DType::F32) => strided::sub::FLOAT, ("bmul", DType::F32) => strided::mul::FLOAT, ("bdiv", DType::F32) => strided::div::FLOAT, - (name, dtype) => todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_strided( &device.device, @@ -442,7 +431,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConv1D, ) -> Result<Self> { - todo!() + crate::bail!("conv1d metal") } fn conv_transpose1d( @@ -452,7 +441,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose1D, ) -> Result<Self> { - todo!() + crate::bail!("conv_transpose1d metal") } fn conv2d( @@ -462,7 +451,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConv2D, ) -> Result<Self> { - todo!() + crate::bail!("conv2d metal") } fn conv_transpose2d( @@ -472,27 +461,27 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose2D, ) -> Result<Self> { - todo!() + crate::bail!("conv_tranpose2d metal") } fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { - todo!() + crate::bail!("avg_pool2d metal") } fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { - todo!() + crate::bail!("max_pool2d metal") } fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> { - todo!() + crate::bail!("upsample_nearest1d metal") } fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { - todo!() + crate::bail!("upsample_nearest2d metal") } fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> { - todo!() + crate::bail!("gather metal") } fn scatter_add( @@ -504,14 +493,13 @@ impl 
BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result<Self> { - todo!() + crate::bail!("scatter_add metal") } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { - assert!(src_l.is_contiguous()); - assert!(src_l.start_offset() == 0); - assert!(ids_l.is_contiguous()); - assert!(ids_l.start_offset() == 0); + if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){ + crate::bail!("Non contiguous index select not implemented"); + } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.shape().elem_count(); @@ -519,10 +507,10 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let device = self.device(); let mut buffer = device.new_buffer(dst_el, dtype); - let out = self.to_cpu_storage().unwrap(); + let out = self.to_cpu_storage()?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", - (left, right) => todo!("index select metal {left:?} {right:?}"), + (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; let command_buffer = self.device.command_queue.new_command_buffer(); candle_metal_kernels::call_index_select( @@ -556,7 +544,7 @@ impl BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result<Self> { - todo!() + crate::bail!("index_add metal") } fn matmul( @@ -666,11 +654,6 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); command_buffer.wait_until_completed(); - // let left = self.buffer.read_to_vec::<f32>(10); - // let right = rhs.buffer.read_to_vec::<f32>(10); - // let out = out_buffer.read_to_vec::<f32>(40); - // todo!("Out {left:?} {right:?} {out:?}"); - Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -681,7 +664,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let 
src_shape = src_l.shape(); let el_count = src_shape.elem_count(); - // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}"); if el_count == 0 { return Ok(()); } @@ -690,7 +672,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => todo!("copy_strided not implemented for {dtype:?}"), + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), }; candle_metal_kernels::call_unary_strided( &self.device.device, @@ -741,7 +723,7 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, _seed: u64) -> Result<()> { - todo!("set_seed") + crate::bail!("set_seed") } fn location(&self) -> crate::DeviceLocation { |