summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-09 18:44:21 +0100
committerGitHub <noreply@github.com>2023-09-09 18:44:21 +0100
commit258ac32c3868d4103e90df19af99a3e13c805c4e (patch)
tree2f54d1123d40980e5d6b1f3cbda231aa46af835e /candle-core/src/cuda_backend.rs
parent31936c08fe26ad0bc401fb1ea5b4eac869491637 (diff)
downloadcandle-258ac32c3868d4103e90df19af99a3e13c805c4e.tar.gz
candle-258ac32c3868d4103e90df19af99a3e13c805c4e.tar.bz2
candle-258ac32c3868d4103e90df19af99a3e13c805c4e.zip
Fix cuda randn when generating an odd number of values. (#793)
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs11
1 files changed, 9 insertions, 2 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index cb00441f..7cc85489 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
+ // curand can only generate an even number of values, so odd counts are rounded up.
+ // https://github.com/huggingface/candle/issues/734
+ let elem_count_round = if elem_count % 2 == 1 {
+ elem_count + 1
+ } else {
+ elem_count
+ };
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}