summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-20 10:40:31 +0200
committerGitHub <noreply@github.com>2023-07-20 09:40:31 +0100
commite9c052bf94521b418852a1c5231c12ddce99a78f (patch)
tree7adbd7bc5fa01415c4285a0d141d297f17525293 /candle-core/src/cuda_backend.rs
parentdc416243a32236785a21c1184a21ac21ed06fcc4 (diff)
downloadcandle-e9c052bf94521b418852a1c5231c12ddce99a78f.tar.gz
candle-e9c052bf94521b418852a1c5231c12ddce99a78f.tar.bz2
candle-e9c052bf94521b418852a1c5231c12ddce99a78f.zip
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs20
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b74137f3..9e47c133 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{CmpOp, ReduceOp};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> {
}
}
-struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1 for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
@@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> {
.w()?;
let src = &src.slice(layout.start_offset()..);
let name = match self.1 {
- crate::op::ReduceOp::Sum => "fast_sum",
- crate::op::ReduceOp::Min => "fast_min",
- crate::op::ReduceOp::Max => "fast_max",
+ ReduceOp::Sum => "fast_sum",
+ ReduceOp::Min => "fast_min",
+ ReduceOp::Max => "fast_max",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
@@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn reduce_op(
- &self,
- op: crate::op::ReduceOp,
- layout: &Layout,
- sum_dims: &[usize],
- ) -> Result<Self> {
+ fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device().clone();
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+ Err(CudaError::InternalError("TODO: implement cmp").into())
+ }
+
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}