diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index b74137f3..9e47c133 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,4 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; +use crate::op::{CmpOp, ReduceOp}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; @@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> { } } -struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp); +struct FastReduce<'a>(&'a [usize], ReduceOp); impl<'a> Map1 for FastReduce<'a> { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, @@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> { .w()?; let src = &src.slice(layout.start_offset()..); let name = match self.1 { - crate::op::ReduceOp::Sum => "fast_sum", - crate::op::ReduceOp::Min => "fast_min", - crate::op::ReduceOp::Max => "fast_max", + ReduceOp::Sum => "fast_sum", + ReduceOp::Min => "fast_min", + ReduceOp::Max => "fast_max", }; let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?; // SAFETY: filled in by the follow up kernel. @@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn reduce_op( - &self, - op: crate::op::ReduceOp, - layout: &Layout, - sum_dims: &[usize], - ) -> Result<Self> { + fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { let device = self.device().clone(); let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { + Err(CudaError::InternalError("TODO: implement cmp").into()) + } + fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> { Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into()) } |