1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// Variables are wrappers around tensors whose content can be modified. They are typically used
// to hold weights that get updated by gradient descent.
// We do not expose a public way to wrap an existing tensor in a variable, as this would break the
// invariant that the tensor within a variable always has `is_variable` set to `true`.
use crate::{DType, Device, Error, Result, Shape, Tensor};
/// A variable is a wrapper around a tensor, however variables can have their content modified
/// whereas tensors are immutable.
///
/// The wrapped tensor is a private field, so outside of this module a `Var` can only be obtained
/// through the constructors defined on this type. The content is overwritten with [`Var::set`],
/// and a `Var` dereferences to [`Tensor`] so tensor methods can be called on it directly.
#[derive(Clone, Debug)]
pub struct Var(Tensor);
impl std::fmt::Display for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::ops::Deref for Var {
    type Target = Tensor;

    /// Gives read access to the wrapped tensor so that tensor methods can be called on a `Var`.
    fn deref(&self) -> &Tensor {
        // Spelled out in fully-qualified form to make explicit which `AsRef` impl is used.
        <Tensor as AsRef<Tensor>>::as_ref(&self.0)
    }
}
impl Var {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
Ok(Self(inner))
}
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
let inner = Tensor::ones_impl(shape, dtype, device, true)?;
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
lo: f64,
up: f64,
) -> Result<Self> {
let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
Ok(Self(inner))
}
pub fn randn<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
Ok(Self(inner))
}
/// Creates a new tensor on the specified device using the content and shape of the input.
/// This is similar to `new` but the resulting tensor is a variable.
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
let shape = array.shape()?;
let inner = Tensor::new_impl(array, shape, device, true)?;
Ok(Self(inner))
}
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
) -> Result<Self> {
let inner = Tensor::from_vec_impl(data, shape, device, true)?;
Ok(Self(inner))
}
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let inner = Tensor::new_impl(array, shape.into(), device, true)?;
Ok(Self(inner))
}
pub fn as_tensor(&self) -> &Tensor {
&self.0
}
/// Consumes this `Var` and return the underlying tensor.
pub fn into_inner(self) -> Tensor {
self.0
}
/// Sets the content of the inner tensor, this does not require a mutable reference as inner
/// mutability is used.
pub fn set(&self, src: &Tensor) -> Result<()> {
if self.same_storage(src) {
let msg = "cannot set a variable to a tensor that is derived from its value";
Err(Error::CannotSetVar { msg }.bt())?
}
let (mut dst, layout) = self.storage_mut_and_layout();
if !layout.is_contiguous() {
let msg = "cannot set a non-contiguous variable";
Err(Error::CannotSetVar { msg }.bt())?
}
let (src, src_l) = src.storage_and_layout();
if layout.shape() != src_l.shape() {
Err(Error::ShapeMismatchBinaryOp {
lhs: layout.shape().clone(),
rhs: src_l.shape().clone(),
op: "set",
}
.bt())?
}
src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
Ok(())
}
}
|