summary refs log tree commit diff
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-09-08 20:13:29 +0100
committer GitHub <noreply@github.com> 2023-09-08 20:13:29 +0100
commit acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4 (patch)
tree d36b6aa116a95559d9a2d2e40b79e89b6f5537cf
parent 0906acab9186fbb14a2268e12dd66c13b0877f3e (diff)
download candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.gz
candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.bz2
candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.zip
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values.
* Add some time measurement.
-rw-r--r-- candle-core/src/lib.rs 1
-rw-r--r-- candle-core/src/scalar.rs 23
-rw-r--r-- candle-core/src/tensor.rs 24
-rw-r--r-- candle-examples/examples/segment-anything/main.rs 7
-rw-r--r-- candle-examples/examples/segment-anything/model_prompt_encoder.rs 6
5 files changed, 49 insertions, 12 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a0347416..3504b0a6 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -59,6 +59,7 @@ mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
+pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs
new file mode 100644
index 00000000..43e1f4c8
--- /dev/null
+++ b/candle-core/src/scalar.rs
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+ Tensor(Tensor),
+ Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+ fn to_tensor_scalar(self) -> Result<TensorScalar>;
+}
+
+impl TensorOrScalar for &Tensor {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ Ok(TensorScalar::Tensor(self.clone()))
+ }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+ Ok(TensorScalar::Scalar(scalar))
+ }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 6bb3d740..8ad9322b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
+use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -776,8 +777,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -786,36 +794,36 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
+ pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 4627248c..ce8e3bb4 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> {
}
} else {
let point = Some((args.point_x, args.point_y));
+ let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+ println!(
+ "mask generated in {:.2}s",
+ start_time.elapsed().as_secs_f32()
+ );
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
- let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
+ let mask = (mask.ge(0f32)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index 40cc6e36..7bbe8419 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -161,21 +161,21 @@ impl PromptEncoder {
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(0f32)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
- let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
+ let labels0 = labels.eq(0f32)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
- let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
+ let labels1 = labels.eq(1f32)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,