author    Laurent Mazare <laurent.mazare@gmail.com>  2023-07-29 16:28:22 +0100
committer GitHub <noreply@github.com>  2023-07-29 16:28:22 +0100
commit    16c33383eb2beda515962b219728209b9edb2946 (patch)
tree      64b849f0b6765fb937683016ae08cd478d229d2d /candle-core/src/variable.rs
parent    bedcef64dca29ed7f15c7ad80245f7c79976e340 (diff)
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
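The diff below adds three constructors to Var: from_tensor, which promotes an existing tensor into a trainable variable, plus rand_f64 and randn_f64, which build uniformly and normally initialized variables from f64 bounds or statistics independently of the target dtype. The following is a rough usage sketch rather than code from this commit; it assumes the crate is consumed as candle_core and that Var dereferences to Tensor as elsewhere in variable.rs.

// Usage sketch (not part of the commit): exercising the new Var constructors.
use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Promote an existing tensor to a trainable variable.
    let t = Tensor::zeros((2, 3), DType::F32, &device)?;
    let v = Var::from_tensor(&t)?;

    // Uniform initialization; the bounds are passed as f64 even though
    // the resulting variable is stored as F32.
    let w = Var::rand_f64(-0.1, 0.1, (784, 100), DType::F32, &device)?;

    // Gaussian initialization with mean 0.0 and standard deviation 0.02.
    let b = Var::randn_f64(0.0, 0.02, 100, DType::F32, &device)?;

    // Shape accessors resolve through the Deref impl to Tensor.
    println!("{:?} {:?} {:?}", v.shape(), w.shape(), b.shape());
    Ok(())
}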
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r--  candle-core/src/variable.rs  27
1 file changed, 27 insertions(+), 0 deletions(-)
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 0cefee11..61800bf3 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,6 +34,33 @@ impl Var {
         Ok(Self(inner))
     }
 
+    pub fn from_tensor(t: &Tensor) -> Result<Self> {
+        let inner = t.make_var()?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand_f64<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn_f64<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
         up: T,