diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-21 21:28:59 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-21 21:28:59 +0100 |
commit | 7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (patch) | |
tree | 710596156a4c026d4dd2ba804fab79b6cdafae3b /src/cuda_backend.rs | |
parent | 983415125495e9d57e684b701fbf746ebb6f7a29 (diff) | |
download | candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.gz candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.bz2 candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.zip |
Check that the tensor is contiguous before applying the kernel.
Diffstat (limited to 'src/cuda_backend.rs')
-rw-r--r-- | src/cuda_backend.rs | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 8c7f23b3..d12db972 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -85,13 +85,15 @@ impl CudaStorage { pub(crate) fn affine_impl( &self, shape: &Shape, - _stride: &[usize], + stride: &[usize], mul: f64, add: f64, ) -> Result<Self> { match self { Self::F32(arg) => { - // TODO: Handle the stride. + if !shape.is_contiguous(stride) { + todo!("affine is only implemented for the contiguous case") + } let dev = arg.device(); let module_name = "affine_f32"; if !dev.has_func(module_name, module_name) { |