summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-11 09:16:04 +0200
committerGitHub <noreply@github.com>2023-10-11 08:16:04 +0100
commit37dbbff261f1641db6dc868fc4dded5f8cb25a1f (patch)
tree9ebaa391b90b5936f915c919f0f99015d058b593 /candle-core/src
parent9fea56d28e5f99529da8ed8df1eb508b0f163cc3 (diff)
downloadcandle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.tar.gz
candle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.tar.bz2
candle-37dbbff261f1641db6dc868fc4dded5f8cb25a1f.zip
Use full tensors for zeros and ones (#1071)
* Only optimize float tensors. * Use full tensors for zeros and ones.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/tensor.rs22
1 files changed, 6 insertions, 16 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 7295c350..e2c97af2 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -177,14 +177,9 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
- if is_variable {
- let shape = shape.into();
- let storage = device.ones(&shape, dtype)?;
- Ok(from_storage(storage, shape, none, is_variable))
- } else {
- let storage = device.ones(&crate::shape::SCALAR, dtype)?;
- from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
- }
+ let shape = shape.into();
+ let storage = device.ones(&shape, dtype)?;
+ Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor filled with ones.
@@ -222,14 +217,9 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
- if is_variable {
- let shape = shape.into();
- let storage = device.zeros(&shape, dtype)?;
- Ok(from_storage(storage, shape, none, is_variable))
- } else {
- let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
- from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
- }
+ let shape = shape.into();
+ let storage = device.zeros(&shape, dtype)?;
+ Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor filled with zeros.