author     laurent <laurent.mazare@gmail.com>    2023-06-29 11:56:40 +0100
committer  laurent <laurent.mazare@gmail.com>    2023-06-29 11:56:40 +0100
commit     2741b39ad37ecb58c110459739ee174fae5f1fa4 (patch)
tree       7dce00b52392a2176725a5a6f6987fd095aaabd8 /candle-core/src/tensor.rs
parent     3872dc4751c45b625d71c6652c2854a3cc695fb3 (diff)
Use broadcasted scalars for const tensors.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--  candle-core/src/tensor.rs  17
1 file changed, 7 insertions, 10 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4b9b3306..6586834c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -132,6 +130,8 @@ impl Tensor {
     }
 
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
         Self::ones_impl(shape, dtype, device, true)
     }
 
@@ -139,16 +139,14 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.zeros(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
         &self.layout
     }
 
-    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-    pub fn stride_tmp(&self) -> &[usize] {
+    pub fn stride(&self) -> &[usize] {
         self.layout.stride()
     }
 
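For context, the change makes the ones/zeros constructors allocate a single scalar and broadcast it to the requested shape instead of filling a dense buffer. The sketch below is illustrative only: the BroadcastScalar type and its methods are made up for this note and are not the candle-core API. It shows the underlying idea that a broadcasted constant needs one stored element plus all-zero strides, and only materializes a full buffer on demand.

// Illustrative sketch of the broadcasted-scalar idea (hypothetical names,
// not the candle-core API): one stored element, a target shape, and
// all-zero strides so every logical index maps to that single element.
struct BroadcastScalar {
    value: f32,
    shape: Vec<usize>,
}

impl BroadcastScalar {
    fn ones(shape: &[usize]) -> Self {
        Self { value: 1.0, shape: shape.to_vec() }
    }

    fn zeros(shape: &[usize]) -> Self {
        Self { value: 0.0, shape: shape.to_vec() }
    }

    // A broadcast dimension has stride 0: moving along it never advances
    // through storage, which is why a single element suffices.
    fn stride(&self) -> Vec<usize> {
        vec![0; self.shape.len()]
    }

    fn elem_count(&self) -> usize {
        self.shape.iter().product()
    }

    // Materialize a dense buffer only when one is actually required.
    fn to_dense(&self) -> Vec<f32> {
        vec![self.value; self.elem_count()]
    }
}

fn main() {
    let ones = BroadcastScalar::ones(&[2, 3]);
    assert_eq!(ones.stride(), vec![0usize, 0]); // no per-element storage
    assert_eq!(ones.to_dense(), vec![1.0f32; 6]); // dense copy only on demand
    let zeros = BroadcastScalar::zeros(&[1024, 1024]);
    assert_eq!(zeros.elem_count(), 1024 * 1024); // still just one f32 allocated
    println!("broadcasted-scalar sketch ok");
}

This keeps constant tensors cheap regardless of size; the comment added in ones_var flags the open question of whether variables should instead get real per-element storage, since a broadcasted scalar aliases every logical element.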