summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-21 21:28:59 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-21 21:28:59 +0100
commit7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (patch)
tree710596156a4c026d4dd2ba804fab79b6cdafae3b /src
parent983415125495e9d57e684b701fbf746ebb6f7a29 (diff)
downloadcandle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.gz
candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.tar.bz2
candle-7c46de9584fd4315b84d3bc4c28cf1b2bad7785d.zip
Check that the tensor is contiguous before applying the kernel.
Diffstat (limited to 'src')
-rw-r--r--src/cuda_backend.rs6
-rw-r--r--src/shape.rs14
-rw-r--r--src/tensor.rs9
3 files changed, 19 insertions, 10 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 8c7f23b3..d12db972 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -85,13 +85,15 @@ impl CudaStorage {
pub(crate) fn affine_impl(
&self,
shape: &Shape,
- _stride: &[usize],
+ stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Self::F32(arg) => {
- // TODO: Handle the stride.
+ if !shape.is_contiguous(stride) {
+ todo!("affine is only implemented for the contiguous case")
+ }
let dev = arg.device();
let module_name = "affine_f32";
if !dev.has_func(module_name, module_name) {
diff --git a/src/shape.rs b/src/shape.rs
index d626aee6..ebc497cf 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -128,6 +128,20 @@ impl Shape {
stride.reverse();
stride
}
+
+ pub fn is_contiguous(&self, stride: &[usize]) -> bool {
+ if self.0.len() != stride.len() {
+ return false;
+ }
+ let mut acc = 1;
+ for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
+ if stride != acc {
+ return false;
+ }
+ acc *= dim;
+ }
+ true
+ }
}
#[cfg(test)]
diff --git a/src/tensor.rs b/src/tensor.rs
index a1262334..02105573 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -310,14 +310,7 @@ impl Tensor {
}
pub fn is_contiguous(&self) -> bool {
- let mut acc = 1;
- for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
- if stride != acc {
- return false;
- }
- acc *= dim;
- }
- true
+ self.shape.is_contiguous(&self.stride)
}
/// Return all the nodes that lead to this value in a topologically sorted vec, the first