summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs12
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7106d4d7..9fc4ceca 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -153,7 +153,13 @@ impl CudaDevice {
})
}
- pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ pub(crate) fn rand_uniform(
+ &self,
+ shape: &Shape,
+ dtype: DType,
+ lo: f64,
+ up: f64,
+ ) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
@@ -174,6 +180,10 @@ impl CudaDevice {
CudaStorageSlice::F64(data)
}
};
+ if lo != 0.0 || up != 1.0 {
+ let layout = Layout::contiguous(shape);
+ Affine(up - lo, lo).map(&slice, self, &layout)?;
+ }
Ok(CudaStorage {
slice,
device: self.clone(),