summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/device.rs44
-rw-r--r--candle-core/src/tensor.rs38
-rw-r--r--candle-core/src/variable.rs27
-rw-r--r--candle-examples/examples/simple-training/main.rs88
-rw-r--r--candle-nn/src/init.rs40
-rw-r--r--candle-nn/src/lib.rs1
6 files changed, 196 insertions, 42 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 89df8f84..563d892b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -116,46 +116,62 @@ impl Device {
}
}
- pub(crate) fn rand_uniform<T: crate::FloatDType>(
+ pub(crate) fn rand_uniform_f64(
&self,
- lo: T,
- up: T,
+ lo: f64,
+ up: f64,
shape: &Shape,
+ dtype: DType,
) -> Result<Storage> {
- let lo = lo.to_f64();
- let up = up.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
+ let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
+ let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
- pub(crate) fn rand_normal<T: crate::FloatDType>(
+ pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
- mean: T,
- std: T,
+ lo: T,
+ up: T,
+ shape: &Shape,
+ ) -> Result<Storage> {
+ self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+ }
+
+ pub(crate) fn rand_normal_f64(
+ &self,
+ mean: f64,
+ std: f64,
shape: &Shape,
+ dtype: DType,
) -> Result<Storage> {
- let mean = mean.to_f64();
- let std = std.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
+ let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
+ let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
+ pub(crate) fn rand_normal<T: crate::FloatDType>(
+ &self,
+ mean: T,
+ std: T,
+ shape: &Shape,
+ ) -> Result<Storage> {
+ self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
+ }
+
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ae92c2e..060e8792 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -245,6 +245,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
+ pub(crate) fn rand_f64_impl<S: Into<Shape>>(
+ lo: f64,
+ up: f64,
+ s: S,
+ dtype: DType,
+ device: &Device,
+ is_variable: bool,
+ ) -> Result<Self> {
+ let s = s.into();
+ let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
+ let none = BackpropOp::none();
+ Ok(from_storage(storage, s, none, is_variable))
+ }
+
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
@@ -268,6 +282,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
+ pub(crate) fn randn_f64_impl<S: Into<Shape>>(
+ mean: f64,
+ std: f64,
+ s: S,
+ dtype: DType,
+ device: &Device,
+ is_variable: bool,
+ ) -> Result<Self> {
+ let s = s.into();
+ let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
+ let none = BackpropOp::none();
+ Ok(from_storage(storage, s, none, is_variable))
+ }
+
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@@ -1448,6 +1476,16 @@ impl Tensor {
}
}
+ /// Create a variable based on the values currently stored in a tensor. The storage is always
+ /// copied.
+ pub(crate) fn make_var(&self) -> Result<Tensor> {
+ let shape = self.shape().clone();
+ let mut storage = self.device().zeros(&shape, self.dtype())?;
+ self.storage()
+ .copy_strided_src(&mut storage, 0, self.layout())?;
+ Ok(from_storage(storage, shape, BackpropOp::none(), true))
+ }
+
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 0cefee11..61800bf3 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,6 +34,33 @@ impl Var {
Ok(Self(inner))
}
+ pub fn from_tensor(t: &Tensor) -> Result<Self> {
+ let inner = t.make_var()?;
+ Ok(Self(inner))
+ }
+
+ pub fn rand_f64<S: Into<Shape>>(
+ lo: f64,
+ up: f64,
+ s: S,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Self> {
+ let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+ Ok(Self(inner))
+ }
+
+ pub fn randn_f64<S: Into<Shape>>(
+ mean: f64,
+ std: f64,
+ s: S,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Self> {
+ let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+ Ok(Self(inner))
+ }
+
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 35b938e8..f15aa60c 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -2,8 +2,10 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+use clap::{Parser, ValueEnum};
+
use candle::{DType, Device, Result, Shape, Tensor, Var, D};
-use candle_nn::{loss, ops, Linear};
+use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};
const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
}
}
- fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+ fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
}
return Ok(tensor.as_tensor().clone());
}
- // TODO: Proper initialization using the `Init` enum.
- let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+ let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var);
Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
}
}
-fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
- let ws = vs.get((dim2, dim1), "weight")?;
- let bs = vs.get(dim2, "bias")?;
+fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+ let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+ let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
+ Ok(Linear::new(ws, Some(bs)))
+}
+
+fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
+ let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
+ let bound = 1. / (in_dim as f64).sqrt();
+ let init_bs = Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let bs = vs.get(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
-#[allow(unused)]
+trait Model: Sized {
+ fn new(vs: VarStore) -> Result<Self>;
+ fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+}
+
struct LinearModel {
linear: Linear,
}
-#[allow(unused)]
-impl LinearModel {
+impl Model for LinearModel {
fn new(vs: VarStore) -> Result<Self> {
- let linear = linear(IMAGE_DIM, LABELS, vs)?;
+ let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}
@@ -100,14 +116,12 @@ impl LinearModel {
}
}
-#[allow(unused)]
struct Mlp {
ln1: Linear,
ln2: Linear,
}
-#[allow(unused)]
-impl Mlp {
+impl Model for Mlp {
fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
}
}
-pub fn main() -> anyhow::Result<()> {
+fn training_loop<M: Model>(
+ m: candle_nn::vision::Dataset,
+ learning_rate: f64,
+) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
- // Load the dataset
- let m = candle_nn::vision::mnist::load_dir("data")?;
- println!("train-images: {:?}", m.train_images.shape());
- println!("train-labels: {:?}", m.train_labels.shape());
- println!("test-images: {:?}", m.test_images.shape());
- println!("test-labels: {:?}", m.test_labels.shape());
let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
let vs = VarStore::new(DType::F32, dev);
- let model = LinearModel::new(vs.clone())?;
- // let model = Mlp::new(vs)?;
+ let model = M::new(vs.clone())?;
let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
- let sgd = candle_nn::SGD::new(&all_vars, 1.0);
+ let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
@@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
}
Ok(())
}
+
+#[derive(ValueEnum, Clone)]
+enum WhichModel {
+ Linear,
+ Mlp,
+}
+
+#[derive(Parser)]
+struct Args {
+ #[clap(value_enum, default_value_t = WhichModel::Linear)]
+ model: WhichModel,
+
+ #[arg(long)]
+ learning_rate: Option<f64>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+ // Load the dataset
+ let m = candle_nn::vision::mnist::load_dir("data")?;
+ println!("train-images: {:?}", m.train_images.shape());
+ println!("train-labels: {:?}", m.train_labels.shape());
+ println!("test-images: {:?}", m.test_images.shape());
+ println!("test-labels: {:?}", m.test_labels.shape());
+
+ match args.model {
+ WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
+ WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+ }
+}
diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs
index 762f0ef1..25702d52 100644
--- a/candle-nn/src/init.rs
+++ b/candle-nn/src/init.rs
@@ -1,7 +1,7 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
-use candle::Shape;
+use candle::{DType, Device, Result, Shape, Tensor, Var};
/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
fan: FanInOut,
non_linearity: NonLinearity,
},
-
- /// Orthogonal initialization
- Orthogonal { gain: f64 },
}
+pub const ZERO: Init = Init::Const(0.);
+pub const ONE: Init = Init::Const(1.);
+
pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
dist: NormalOrUniform::Uniform,
fan: FanInOut::FanIn,
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
fan: FanInOut::FanIn,
non_linearity: NonLinearity::ReLU,
};
+
+impl Init {
+ /// Creates a new tensor with the specified shape, device, and initialization.
+ pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
+ match self {
+ Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
+ Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
+ Self::Const(cst) => {
+ Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
+ }
+ Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
+ Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
+ Self::Kaiming {
+ dist,
+ fan,
+ non_linearity,
+ } => {
+ let s = s.into();
+ let fan = fan.for_shape(&s);
+ let gain = non_linearity.gain();
+ let std = gain / (fan as f64).sqrt();
+ match dist {
+ NormalOrUniform::Uniform => {
+ let bound = 3f64.sqrt() * std;
+ Var::rand_f64(-bound, bound, s, dtype, device)
+ }
+ NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
+ }
+ }
+ }
+ }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index db01b067..d0b62dbb 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,6 +15,7 @@ pub mod vision;
pub use activation::Activation;
pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding;
+pub use init::Init;
pub use layer_norm::LayerNorm;
pub use linear::Linear;
pub use optim::SGD;