diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-29 16:28:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-29 16:28:22 +0100 |
commit | 16c33383eb2beda515962b219728209b9edb2946 (patch) | |
tree | 64b849f0b6765fb937683016ae08cd478d229d2d /candle-core/src/variable.rs | |
parent | bedcef64dca29ed7f15c7ad80245f7c79976e340 (diff) | |
download | candle-16c33383eb2beda515962b219728209b9edb2946.tar.gz candle-16c33383eb2beda515962b219728209b9edb2946.tar.bz2 candle-16c33383eb2beda515962b219728209b9edb2946.zip |
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r-- | candle-core/src/variable.rs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs index 0cefee11..61800bf3 100644 --- a/candle-core/src/variable.rs +++ b/candle-core/src/variable.rs @@ -34,6 +34,33 @@ impl Var { Ok(Self(inner)) } + pub fn from_tensor(t: &Tensor) -> Result<Self> { + let inner = t.make_var()?; + Ok(Self(inner)) + } + + pub fn rand_f64<S: Into<Shape>>( + lo: f64, + up: f64, + s: S, + dtype: DType, + device: &Device, + ) -> Result<Self> { + let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?; + Ok(Self(inner)) + } + + pub fn randn_f64<S: Into<Shape>>( + mean: f64, + std: f64, + s: S, + dtype: DType, + device: &Device, + ) -> Result<Self> { + let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?; + Ok(Self(inner)) + } + pub fn rand<S: Into<Shape>, T: crate::FloatDType>( lo: T, up: T, |