diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 7106d4d7..9fc4ceca 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -153,7 +153,13 @@ impl CudaDevice { }) } - pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> { + pub(crate) fn rand_uniform( + &self, + shape: &Shape, + dtype: DType, + lo: f64, + up: f64, + ) -> Result<CudaStorage> { let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); let slice = match dtype { @@ -174,6 +180,10 @@ impl CudaDevice { CudaStorageSlice::F64(data) } }; + if lo != 0.0 || up != 1.0 { + let layout = Layout::contiguous(shape); + Affine(up - lo, lo).map(&slice, self, &layout)?; + } Ok(CudaStorage { slice, device: self.clone(), |