diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-11 08:52:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-11 08:52:29 +0100 |
commit | ae79c00e48089d889f900b4c05f90a1201e610c6 (patch) | |
tree | a42bc3334791a79b203397c581d4ed03998e4e3d /candle-core/src/tensor.rs | |
parent | b31a3bbdcbf1a75bbb18cdc2aa0fbff2ab931351 (diff) | |
download | candle-ae79c00e48089d889f900b4c05f90a1201e610c6.tar.gz candle-ae79c00e48089d889f900b4c05f90a1201e610c6.tar.bz2 candle-ae79c00e48089d889f900b4c05f90a1201e610c6.zip |
Allow for uniform initialization in a single step. (#136)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index aba7b91a..ecc018f9 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -226,19 +226,33 @@ impl Tensor { s: S, dtype: DType, device: &Device, + lo: f64, + up: f64, is_variable: bool, ) -> Result<Self> { let s = s.into(); - let storage = device.rand_uniform(&s, dtype)?; + let storage = device.rand_uniform(&s, dtype, lo, up)?; Ok(from_storage(storage, s, None, is_variable)) } - pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> { - Self::rand_uniform_impl(s, dtype, device, false) + pub fn rand_uniform<S: Into<Shape>>( + s: S, + dtype: DType, + device: &Device, + lo: f64, + up: f64, + ) -> Result<Self> { + Self::rand_uniform_impl(s, dtype, device, lo, up, false) } - pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> { - Self::rand_uniform_impl(s, dtype, device, true) + pub fn rand_uniform_var<S: Into<Shape>>( + s: S, + dtype: DType, + device: &Device, + lo: f64, + up: f64, + ) -> Result<Self> { + Self::rand_uniform_impl(s, dtype, device, lo, up, true) } fn rand_normal_impl<S: Into<Shape>>( |