summary | refs | log | tree | commit | diff
path: root/candle-core/src/metal_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/metal_backend.rs')
-rw-r--r-- candle-core/src/metal_backend.rs 120
1 files changed, 51 insertions, 69 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 597c2f01..27475efe 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -54,10 +54,6 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
- // pub fn metal_device(&self) -> &metal::DeviceRef {
- // self.device.as_ref()
- // }
-
pub fn id(&self) -> NSUInteger {
self.registry_id()
}
@@ -76,7 +72,6 @@ impl MetalDevice {
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
- // debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
@@ -105,28 +100,22 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ let length = self.buffer.length() as usize;
+ let size = self.dtype.size_in_bytes();
+ if length % size != 0 {
+ crate::bail!(
+ "The Metal buffer length is not aligned with dtype {:?}",
+ self.dtype
+ );
+ }
match self.dtype {
- DType::U8 => Ok(CpuStorage::U8(
- self.buffer.read_to_vec(self.buffer.length() as usize / 1),
- )),
- DType::U32 => Ok(CpuStorage::U32(
- self.buffer.read_to_vec(self.buffer.length() as usize / 4),
- )),
- DType::I64 => Ok(CpuStorage::I64(
- self.buffer.read_to_vec(self.buffer.length() as usize / 8),
- )),
- DType::F16 => Ok(CpuStorage::F16(
- self.buffer.read_to_vec(self.buffer.length() as usize / 2),
- )),
- DType::BF16 => Ok(CpuStorage::BF16(
- self.buffer.read_to_vec(self.buffer.length() as usize / 2),
- )),
- DType::F32 => Ok(CpuStorage::F32(
- self.buffer.read_to_vec(self.buffer.length() as usize / 4),
- )),
- DType::F64 => Ok(CpuStorage::F64(
- self.buffer.read_to_vec(self.buffer.length() as usize / 8),
- )),
+ DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
+ DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
+ DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
+ DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
+ DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
+ DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
+ DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
}
}
@@ -137,9 +126,9 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = self.dtype;
- assert!(layout.is_contiguous());
- assert!(layout.start_offset() == 0);
- assert_eq!(dtype, DType::F32);
+ if !layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
+ crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+ }
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
@@ -153,7 +142,7 @@ impl BackendStorage for MetalStorage {
mul as f32,
add as f32,
)
- .unwrap();
+ .map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
return Ok(Self {
@@ -164,18 +153,18 @@ impl BackendStorage for MetalStorage {
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
- todo!()
+ crate::bail!("powf metal")
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
- todo!()
+ crate::bail!("elu metal")
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
- assert!(sum_dims.len() == 1);
- assert!(sum_dims[0] == layout.shape().rank() - 1);
- assert!(layout.is_contiguous());
- assert!(layout.start_offset() == 0);
+
+ if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){
+ crate::bail!("Non contiguous reduce op not supported yet");
+ }
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@@ -204,7 +193,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
- _ => todo!("Reduce op for non float"),
+ _ => crate::bail!("Reduce op for non float"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -234,7 +223,7 @@ impl BackendStorage for MetalStorage {
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
- todo!()
+ crate::bail!("cmp metal")
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
@@ -246,7 +235,7 @@ impl BackendStorage for MetalStorage {
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
- (left, right) => todo!("to dtype {left:?} - {right:?}"),
+ (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
@@ -259,7 +248,7 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
- todo!(
+ crate::bail!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
@@ -293,7 +282,7 @@ impl BackendStorage for MetalStorage {
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
- (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
@@ -306,7 +295,7 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
- todo!("TODO Implement the kernel calling {}", B::KERNEL);
+ crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
}
command_buffer.commit();
command_buffer.wait_until_completed();
@@ -344,7 +333,7 @@ impl BackendStorage for MetalStorage {
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
- (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
@@ -365,7 +354,7 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
- (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
@@ -442,7 +431,7 @@ impl BackendStorage for MetalStorage {
_kernel_l: &Layout,
_params: &ParamsConv1D,
) -> Result<Self> {
- todo!()
+ crate::bail!("conv1d metal")
}
fn conv_transpose1d(
@@ -452,7 +441,7 @@ impl BackendStorage for MetalStorage {
_kernel_l: &Layout,
_params: &ParamsConvTranspose1D,
) -> Result<Self> {
- todo!()
+ crate::bail!("conv_transpose1d metal")
}
fn conv2d(
@@ -462,7 +451,7 @@ impl BackendStorage for MetalStorage {
_kernel_l: &Layout,
_params: &ParamsConv2D,
) -> Result<Self> {
- todo!()
+ crate::bail!("conv2d metal")
}
fn conv_transpose2d(
@@ -472,27 +461,27 @@ impl BackendStorage for MetalStorage {
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
) -> Result<Self> {
- todo!()
+ crate::bail!("conv_transpose2d metal")
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
- todo!()
+ crate::bail!("avg_pool2d metal")
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
- todo!()
+ crate::bail!("max_pool2d metal")
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
- todo!()
+ crate::bail!("upsample_nearest1d metal")
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
- todo!()
+ crate::bail!("upsample_nearest2d metal")
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
- todo!()
+ crate::bail!("gather metal")
}
fn scatter_add(
@@ -504,14 +493,13 @@ impl BackendStorage for MetalStorage {
_: &Layout,
_: usize,
) -> Result<Self> {
- todo!()
+ crate::bail!("scatter_add metal")
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
- assert!(src_l.is_contiguous());
- assert!(src_l.start_offset() == 0);
- assert!(ids_l.is_contiguous());
- assert!(ids_l.start_offset() == 0);
+ if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){
+ crate::bail!("Non contiguous index select not implemented");
+ }
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
@@ -519,10 +507,10 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
- let out = self.to_cpu_storage().unwrap();
+ let out = self.to_cpu_storage()?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
- (left, right) => todo!("index select metal {left:?} {right:?}"),
+ (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_index_select(
@@ -556,7 +544,7 @@ impl BackendStorage for MetalStorage {
_: &Layout,
_: usize,
) -> Result<Self> {
- todo!()
+ crate::bail!("index_add metal")
}
fn matmul(
@@ -666,11 +654,6 @@ impl BackendStorage for MetalStorage {
command_buffer.commit();
command_buffer.wait_until_completed();
- // let left = self.buffer.read_to_vec::<f32>(10);
- // let right = rhs.buffer.read_to_vec::<f32>(10);
- // let out = out_buffer.read_to_vec::<f32>(40);
- // todo!("Out {left:?} {right:?} {out:?}");
-
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
@@ -681,7 +664,6 @@ impl BackendStorage for MetalStorage {
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
- // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
if el_count == 0 {
return Ok(());
}
@@ -690,7 +672,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
- dtype => todo!("copy_strided not implemented for {dtype:?}"),
+ dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
@@ -741,7 +723,7 @@ impl BackendDevice for MetalDevice {
}
fn set_seed(&self, _seed: u64) -> Result<()> {
- todo!("set_seed")
+ crate::bail!("set_seed")
}
fn location(&self) -> crate::DeviceLocation {