diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/device.rs | 44 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 38 | ||||
-rw-r--r-- | candle-core/src/variable.rs | 27 |
3 files changed, 95 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 89df8f84..563d892b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -116,46 +116,62 @@ impl Device { } } - pub(crate) fn rand_uniform<T: crate::FloatDType>( + pub(crate) fn rand_uniform_f64( &self, - lo: T, - up: T, + lo: f64, + up: f64, shape: &Shape, + dtype: DType, ) -> Result<Storage> { - let lo = lo.to_f64(); - let up = up.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?; + let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?; + let storage = device.rand_uniform(shape, dtype, lo, up)?; Ok(Storage::Cuda(storage)) } } } - pub(crate) fn rand_normal<T: crate::FloatDType>( + pub(crate) fn rand_uniform<T: crate::FloatDType>( &self, - mean: T, - std: T, + lo: T, + up: T, + shape: &Shape, + ) -> Result<Storage> { + self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE) + } + + pub(crate) fn rand_normal_f64( + &self, + mean: f64, + std: f64, shape: &Shape, + dtype: DType, ) -> Result<Storage> { - let mean = mean.to_f64(); - let std = std.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?; + let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, T::DTYPE, mean, std)?; + let storage = device.rand_normal(shape, dtype, mean, std)?; Ok(Storage::Cuda(storage)) } } } + pub(crate) fn rand_normal<T: crate::FloatDType>( + &self, + mean: T, + std: T, + shape: &Shape, + ) -> Result<Storage> { + self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE) + } + pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> { match self { Device::Cpu => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 8ae92c2e..060e8792 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -245,6 +245,20 @@ impl Tensor { Ok(from_storage(storage, s, none, is_variable)) } + pub(crate) fn rand_f64_impl<S: Into<Shape>>( + lo: f64, + up: f64, + s: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result<Self> { + let s = s.into(); + let storage = device.rand_uniform_f64(lo, up, &s, dtype)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`. pub fn rand<S: Into<Shape>, T: crate::FloatDType>( lo: T, @@ -268,6 +282,20 @@ impl Tensor { Ok(from_storage(storage, s, none, is_variable)) } + pub(crate) fn randn_f64_impl<S: Into<Shape>>( + mean: f64, + std: f64, + s: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result<Self> { + let s = s.into(); + let storage = device.rand_normal_f64(mean, std, &s, dtype)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + /// Creates a new tensor initialized with values sampled from a normal distribution with the /// specified `mean` and standard deviation `std`. pub fn randn<S: Into<Shape>, T: crate::FloatDType>( @@ -1448,6 +1476,16 @@ impl Tensor { } } + /// Create a variable based on the values currently stored in a tensor. The storage is always + /// copied. + pub(crate) fn make_var(&self) -> Result<Tensor> { + let shape = self.shape().clone(); + let mut storage = self.device().zeros(&shape, self.dtype())?; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + Ok(from_storage(storage, shape, BackpropOp::none(), true)) + } + // TODO: Do we want to allow target shape using -1 on some dimensions? /// Reshape returns a tensor with the target shape provided that the number of elements of the /// original tensor is the same. diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs index 0cefee11..61800bf3 100644 --- a/candle-core/src/variable.rs +++ b/candle-core/src/variable.rs @@ -34,6 +34,33 @@ impl Var { Ok(Self(inner)) } + pub fn from_tensor(t: &Tensor) -> Result<Self> { + let inner = t.make_var()?; + Ok(Self(inner)) + } + + pub fn rand_f64<S: Into<Shape>>( + lo: f64, + up: f64, + s: S, + dtype: DType, + device: &Device, + ) -> Result<Self> { + let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?; + Ok(Self(inner)) + } + + pub fn randn_f64<S: Into<Shape>>( + mean: f64, + std: f64, + s: S, + dtype: DType, + device: &Device, + ) -> Result<Self> { + let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?; + Ok(Self(inner)) + } + pub fn rand<S: Into<Shape>, T: crate::FloatDType>( lo: T, up: T, |