summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs11
-rw-r--r--candle-core/tests/tensor_tests.rs9
2 files changed, 18 insertions, 2 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index cb00441f..7cc85489 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
+ // curand can only generate an odd number of values.
+ // https://github.com/huggingface/candle/issues/734
+ let elem_count_round = if elem_count % 2 == 1 {
+ elem_count + 1
+ } else {
+ elem_count
+ };
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6af43196..cd68908f 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
+fn randn(device: &Device) -> Result<()> {
+ let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ Ok(())
+}
+
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381