// Variables are wrappers around tensors that can be modified; they are typically used to hold
// weights that get updated by gradient descent.
// We do not expose a public way to create variables as this would break the invariant that the
// tensor within a variable actually has `is_variable` set to `true`.
use crate::{DType, Device, Error, Result, Shape, Tensor};

/// A variable is a wrapper around a tensor; however, variables can have their content modified
/// whereas tensors are immutable.
#[derive(Clone, Debug)]
pub struct Var(Tensor);

impl std::fmt::Display for Var {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::ops::Deref for Var {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl Var {
    /// Creates a new variable of the specified shape and dtype, filled with zeros.
    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new variable of the specified shape and dtype, filled with ones.
    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    // Converts a tensor to a variable; if the tensor is already a variable it is returned as-is.
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        if t.is_variable() {
            Ok(Self(t.clone()))
        } else {
            let inner = t.make_var()?;
            Ok(Self(inner))
        }
    }

    /// Creates a new variable with values sampled uniformly between `lo` and `up`.
    pub fn rand_f64<S: Into<Shape>>(
        lo: f64,
        up: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new variable with values sampled from a normal distribution with the
    /// specified `mean` and standard deviation `std`.
    pub fn randn_f64<S: Into<Shape>>(
        mean: f64,
        std: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    /// Typed variant of [`Var::rand_f64`]: the dtype is derived from the bounds.
    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
        lo: T,
        up: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
    }

    /// Typed variant of [`Var::randn_f64`]: the dtype is derived from `mean` and `std`.
    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new tensor on the specified device using the content and shape of the input.
    /// This is similar to `Tensor::new` but the resulting tensor is a variable.
    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
        let shape = array.shape()?;
        let inner = Tensor::new_impl(array, shape, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new variable from a vector of values and a shape.
    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
        data: Vec<D>,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new variable from a slice of values and a shape.
    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
        array: &[D],
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
        Ok(Self(inner))
    }

    /// Returns the inner tensor, detached from the gradient computation graph.
    pub fn as_detached_tensor(&self) -> Tensor {
        self.0.detach()
    }

    /// Returns a reference to the inner tensor.
    pub fn as_tensor(&self) -> &Tensor {
        &self.0
    }

    /// Consumes this `Var` and returns the underlying tensor.
    pub fn into_inner(self) -> Tensor {
        self.0
    }

    /// Sets the content of the inner tensor; this does not require a mutable reference as inner
    /// mutability is used.
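    ///
    /// A minimal usage sketch; the `candle_core` crate name and the CPU device in this example
    /// are assumptions about how the crate is consumed, not something defined in this file.
    ///
    /// ```
    /// # use candle_core::{DType, Device, Result, Tensor, Var};
    /// # fn main() -> Result<()> {
    /// // Create a trainable 2x2 variable and overwrite its contents in place.
    /// let v = Var::zeros((2, 2), DType::F32, &Device::Cpu)?;
    /// v.set(&Tensor::ones((2, 2), DType::F32, &Device::Cpu)?)?;
    /// # Ok(())
    /// # }
    /// ```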
    pub fn set(&self, src: &Tensor) -> Result<()> {
        if self.same_storage(src) {
            let msg = "cannot set a variable to a tensor that is derived from its value";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (mut dst, layout) = self.storage_mut_and_layout();
        if !layout.is_contiguous() {
            let msg = "cannot set a non-contiguous variable";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (src, src_l) = src.storage_and_layout();
        if layout.shape() != src_l.shape() {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: layout.shape().clone(),
                rhs: src_l.shape().clone(),
                op: "set",
            }
            .bt())?
        }
        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
        Ok(())
    }
}
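
// A minimal sanity test sketching how `Var::set` is used. It assumes a CPU device is always
// available and relies on `Tensor::ones` and `Tensor::to_vec2` from the rest of the crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_overwrites_variable_contents() -> Result<()> {
        let v = Var::zeros((2, 2), DType::F32, &Device::Cpu)?;
        v.set(&Tensor::ones((2, 2), DType::F32, &Device::Cpu)?)?;
        assert_eq!(v.to_vec2::<f32>()?, [[1f32, 1.0], [1.0, 1.0]]);
        Ok(())
    }
}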