summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-27 11:31:04 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-27 11:31:04 +0100
commit380d61e99092c55532ce06d4aff711b95c18209d (patch)
tree0897cb236ead5559d7ac27976e618bdc7ae125a3
parentd7f729fb8f1d4b224f18ca3d7ae1163afe57a094 (diff)
downloadcandle-380d61e99092c55532ce06d4aff711b95c18209d.tar.gz
candle-380d61e99092c55532ce06d4aff711b95c18209d.tar.bz2
candle-380d61e99092c55532ce06d4aff711b95c18209d.zip
Fix two cuda bugs (matmul and where_cond).
-rw-r--r--Makefile (renamed from candle-core/Makefile)4
-rw-r--r--candle-core/src/cuda_backend.rs4
-rw-r--r--candle-kernels/src/ternary.cu2
3 files changed, 5 insertions, 5 deletions
diff --git a/candle-core/Makefile b/Makefile
index 97923e96..cb472d80 100644
--- a/candle-core/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
clean-ptx:
find target -name "*.ptx" -type f -delete
- echo "" > kernels/src/lib.rs
- touch kernels/build.rs
+ echo "" > candle-kernels/src/lib.rs
+ touch candle-kernels/build.rs
clean:
cargo clean
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index caaa64b8..57ea9b3e 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -301,8 +301,8 @@ fn gemm_config<T>(
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
- stride_a: (m * k) as i64,
- stride_b: (n * k) as i64,
+ stride_a: (n * k) as i64,
+ stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
})
}
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index 8f51526b..2a20fbec 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -14,7 +14,7 @@ extern "C" __global__ void FN_NAME( \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
const size_t *strides_t = info + 2*num_dims; \
- const size_t *strides_f = info + 2*num_dims; \
+ const size_t *strides_f = info + 3*num_dims; \
if (is_contiguous(num_dims, dims, strides) \
&& is_contiguous(num_dims, dims, strides_f) \
&& is_contiguous(num_dims, dims, strides_t)) { \