author     laurent <laurent.mazare@gmail.com>  2023-06-29 11:56:40 +0100
committer  laurent <laurent.mazare@gmail.com>  2023-06-29 11:56:40 +0100
commit     2741b39ad37ecb58c110459739ee174fae5f1fa4 (patch)
tree       7dce00b52392a2176725a5a6f6987fd095aaabd8 /candle-core/src
parent     3872dc4751c45b625d71c6652c2854a3cc695fb3 (diff)
Use broadcasted scalars for const tensors.
Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backprop.rs |  2 +-
-rw-r--r--  candle-core/src/shape.rs    |  2 ++
-rw-r--r--  candle-core/src/tensor.rs   | 17 +++++++----------
3 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7801b878..45448505 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -99,7 +99,7 @@ impl Tensor {
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         let mut grads = GradStore::new();
-        grads.insert(self, self.ones_like()?);
+        grads.insert(self, self.ones_like()?.contiguous()?);
         for node in sorted_nodes.iter() {
             if node.is_variable() {
                 continue;
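The extra `contiguous()` on the seed gradient matters because, after the tensor.rs change further down, `ones_like` hands back a broadcast view of a single scalar; accumulating gradients requires dense, writable storage. A minimal sketch of that distinction, assuming the crate is consumed as `candle` and that `Device::Cpu` and the tuple `Into<Shape>` impls are available (none of which this diff introduces):

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // After this commit, `ones` only allocates one scalar and broadcasts it to the
    // requested shape, so the resulting view shares a single element.
    let x = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    // What `backward` now does for its seed gradient: materialize a dense copy.
    let seed = x.ones_like()?.contiguous()?;
    // `seed` owns a real 2x3 buffer that gradient accumulation can use.
    println!("{:?}", seed.shape());
    Ok(())
}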
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index efea723b..cc068004 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -3,6 +3,8 @@ use crate::{Error, Result};
 #[derive(Clone, PartialEq, Eq)]
 pub struct Shape(Vec<usize>);
 
+pub const SCALAR: Shape = Shape(vec![]);
+
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", &self.dims())
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4b9b3306..6586834c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -132,6 +130,8 @@ impl Tensor {
     }
 
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
         Self::ones_impl(shape, dtype, device, true)
     }
 
@@ -139,16 +139,14 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.zeros(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
         &self.layout
     }
 
-    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-    pub fn stride_tmp(&self) -> &[usize] {
+    pub fn stride(&self) -> &[usize] {
         self.layout.stride()
     }
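Taken together with the shape.rs hunk, `ones` and `zeros` now allocate a single scalar (`crate::shape::SCALAR`) and broadcast it to the requested shape, and the temporary `stride_tmp` accessor becomes `stride`. A rough sketch of the resulting behaviour, again assuming the `candle` crate name, `Device::Cpu`, and the tuple `Into<Shape>` impls; the printed strides are what a broadcast layout is expected to report, not something asserted by this diff:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // A "constant" tensor is now a broadcast of one scalar element, so its
    // strides should all be zero rather than describing a dense 2x3 buffer.
    let z = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    println!("broadcast strides: {:?}", z.stride()); // expected: [0, 0]

    // `contiguous` materializes real storage when a dense buffer is needed,
    // which is exactly what `backward` does for its seed gradient above.
    let dense = z.contiguous()?;
    println!("dense strides: {:?}", dense.stride()); // expected: [3, 1]
    Ok(())
}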