summaryrefslogtreecommitdiff
path: root/src/cuda_backend.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-21 21:28:59 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-21 21:28:59 +0100
commit7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (patch)
tree710596156a4c026d4dd2ba804fab79b6cdafae3b /src/cuda_backend.rs
parent983415125495e9d57e684b701fbf746ebb6f7a29 (diff)
downloadcandle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.gz
candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.bz2
candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.zip
Check that the tensor is contiguous before applying the kernel.
Diffstat (limited to 'src/cuda_backend.rs')
-rw-r--r--src/cuda_backend.rs6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 8c7f23b3..d12db972 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -85,13 +85,15 @@ impl CudaStorage {
pub(crate) fn affine_impl(
&self,
shape: &Shape,
- _stride: &[usize],
+ stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Self::F32(arg) => {
- // TODO: Handle the stride.
+ if !shape.is_contiguous(stride) {
+ todo!("affine is only implemented for the contiguous case")
+ }
let dev = arg.device();
let module_name = "affine_f32";
if !dev.has_func(module_name, module_name) {