diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-11 09:16:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-11 08:16:04 +0100 |
commit | 37dbbff261f1641db6dc868fc4dded5f8cb25a1f (patch) | |
tree | 9ebaa391b90b5936f915c919f0f99015d058b593 | |
parent | 9fea56d28e5f99529da8ed8df1eb508b0f163cc3 (diff) | |
download | candle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.tar.gz candle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.tar.bz2 candle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.zip |
Use full tensors for zeros and ones (#1071)
* Only optimize float tensors.
* Use full tensors for zeros and ones.
-rw-r--r-- | candle-core/src/tensor.rs | 22 |
1 files changed, 6 insertions, 16 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 7295c350..e2c97af2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -177,14 +177,9 @@ impl Tensor { is_variable: bool, ) -> Result<Self> { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.ones(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.ones(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.ones(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with ones. @@ -222,14 +217,9 @@ impl Tensor { is_variable: bool, ) -> Result<Self> { let none = BackpropOp::none(); - if is_variable { - let shape = shape.into(); - let storage = device.zeros(&shape, dtype)?; - Ok(from_storage(storage, shape, none, is_variable)) - } else { - let storage = device.zeros(&crate::shape::SCALAR, dtype)?; - from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape) - } + let shape = shape.into(); + let storage = device.zeros(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) } /// Creates a new tensor filled with zeros. |