summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-29 11:56:40 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-29 11:56:40 +0100
commit2741b39ad37ecb58c110459739ee174fae5f1fa4 (patch)
tree7dce00b52392a2176725a5a6f6987fd095aaabd8 /candle-core/src/backprop.rs
parent3872dc4751c45b625d71c6652c2854a3cc695fb3 (diff)
downloadcandle-2741b39ad37ecb58c110459739ee174fae5f1fa4.tar.gz
candle-2741b39ad37ecb58c110459739ee174fae5f1fa4.tar.bz2
candle-2741b39ad37ecb58c110459739ee174fae5f1fa4.zip
Use broadcasted scalars for const tensors.
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7801b878..45448505 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -99,7 +99,7 @@ impl Tensor {
pub fn backward(&self) -> Result<GradStore> {
let sorted_nodes = self.sorted_nodes();
let mut grads = GradStore::new();
- grads.insert(self, self.ones_like()?);
+ grads.insert(self, self.ones_like()?.contiguous()?);
for node in sorted_nodes.iter() {
if node.is_variable() {
continue;