Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/device.rs    44
-rw-r--r--  candle-core/src/tensor.rs    38
-rw-r--r--  candle-core/src/variable.rs  27
3 files changed, 95 insertions(+), 14 deletions(-)
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 89df8f84..563d892b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -116,46 +116,62 @@ impl Device {
         }
     }

-    pub(crate) fn rand_uniform<T: crate::FloatDType>(
+    pub(crate) fn rand_uniform_f64(
         &self,
-        lo: T,
-        up: T,
+        lo: f64,
+        up: f64,
         shape: &Shape,
+        dtype: DType,
     ) -> Result<Storage> {
-        let lo = lo.to_f64();
-        let up = up.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cuda(storage))
             }
         }
     }

-    pub(crate) fn rand_normal<T: crate::FloatDType>(
+    pub(crate) fn rand_uniform<T: crate::FloatDType>(
         &self,
-        mean: T,
-        std: T,
+        lo: T,
+        up: T,
+        shape: &Shape,
+    ) -> Result<Storage> {
+        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+    }
+
+    pub(crate) fn rand_normal_f64(
+        &self,
+        mean: f64,
+        std: f64,
         shape: &Shape,
+        dtype: DType,
     ) -> Result<Storage> {
-        let mean = mean.to_f64();
-        let std = std.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
+                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
         }
     }

+    pub(crate) fn rand_normal<T: crate::FloatDType>(
+        &self,
+        mean: T,
+        std: T,
+        shape: &Shape,
+    ) -> Result<Storage> {
+        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
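
The device.rs change splits each sampler in two: a monomorphic rand_uniform_f64 / rand_normal_f64 core that takes the DType as a runtime value, and the old generic entry points, which now just convert their arguments with to_f64() and forward T::DTYPE. Below is a minimal standalone sketch of the same wrapper-over-core pattern; all names in it (MyDType, FloatLike, sample_uniform_f64) are illustrative stand-ins, not candle's API.

    #[derive(Clone, Copy, Debug)]
    enum MyDType {
        F32,
        F64,
    }

    trait FloatLike: Copy {
        const DTYPE: MyDType;
        fn to_f64(self) -> f64;
    }

    impl FloatLike for f32 {
        const DTYPE: MyDType = MyDType::F32;
        fn to_f64(self) -> f64 {
            self as f64
        }
    }

    impl FloatLike for f64 {
        const DTYPE: MyDType = MyDType::F64;
        fn to_f64(self) -> f64 {
            self
        }
    }

    // Runtime-typed core: the dtype is an ordinary value, so a caller that
    // only learns the dtype at runtime (config file, CLI flag) can still
    // reach it. A real core would dispatch on `dtype`; this stand-in just
    // produces evenly spaced points to stay dependency-free.
    fn sample_uniform_f64(lo: f64, up: f64, n: usize, dtype: MyDType) -> Vec<f64> {
        let _ = dtype;
        (0..n)
            .map(|i| lo + (up - lo) * (i as f64 + 0.5) / n as f64)
            .collect()
    }

    // Compile-time-typed wrapper kept for existing call sites: it fixes the
    // dtype from the Rust float type of the bounds and forwards everything.
    fn sample_uniform<T: FloatLike>(lo: T, up: T, n: usize) -> Vec<f64> {
        sample_uniform_f64(lo.to_f64(), up.to_f64(), n, T::DTYPE)
    }

    fn main() {
        let xs = sample_uniform(-1f32, 1f32, 4); // dtype inferred from f32
        let ys = sample_uniform_f64(-1.0, 1.0, 4, MyDType::F64); // dtype picked at runtime
        assert_eq!(xs.len(), ys.len());
    }

The payoff of the split is that the element dtype no longer has to be fixed at monomorphization time, while every existing generic call site keeps compiling unchanged.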
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ae92c2e..060e8792 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -245,6 +245,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }

+    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
@@ -268,6 +282,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }

+    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
     pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@@ -1448,6 +1476,16 @@ impl Tensor {
         }
     }

+    /// Create a variable based on the values currently stored in a tensor. The storage is always
+    /// copied.
+    pub(crate) fn make_var(&self) -> Result<Tensor> {
+        let shape = self.shape().clone();
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), true))
+    }
+
     // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
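
In tensor.rs, rand_f64_impl and randn_f64_impl mirror the device-level split one layer up, while the public rand / randn keep inferring the dtype from the Rust float type of their arguments. make_var allocates fresh storage with zeros, copies the source through copy_strided_src, and attaches an empty BackpropOp with is_variable set, so the result is a trainable leaf that does not alias the original buffer. A usage sketch of the public constructors sitting above these impls (assuming candle_core at roughly this revision; the shape argument is anything implementing Into<Shape>, such as a tuple):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Dtype is inferred from the argument type: f32 bounds -> an F32 tensor.
        // Internally these route through rand_f64_impl / randn_f64_impl.
        let u = Tensor::rand(-1f32, 1f32, (2, 3), &dev)?;
        let n = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
        assert_eq!(u.dims(), n.dims());
        Ok(())
    }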
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 0cefee11..61800bf3 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,6 +34,33 @@ impl Var {
         Ok(Self(inner))
     }

+    pub fn from_tensor(t: &Tensor) -> Result<Self> {
+        let inner = t.make_var()?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand_f64<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn_f64<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
         up: T,
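
variable.rs exposes all of this publicly: from_tensor promotes an existing tensor to a trainable variable through make_var, and rand_f64 / randn_f64 let callers pick the dtype at runtime instead of through a generic parameter. A sketch of how these constructors might be called (assuming candle_core's public API at this revision, including Var dereferencing to Tensor for dims()):

    use candle_core::{DType, Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let dev = Device::Cpu;

        // Runtime dtype: f64 bounds plus an explicit DType, so the element
        // type no longer has to match the Rust type of the bounds.
        let w = Var::rand_f64(-0.1, 0.1, (4, 4), DType::F32, &dev)?;
        let b = Var::randn_f64(0.0, 1.0, 4, DType::F32, &dev)?;

        // Promote an existing tensor; the storage is copied, so the Var and
        // `init` do not share a buffer.
        let init = Tensor::zeros((4, 4), DType::F32, &dev)?;
        let w2 = Var::from_tensor(&init)?;
        assert_eq!(w2.dims(), init.dims());

        let _ = (w, b);
        Ok(())
    }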