summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshua <gpg@isthisa.email>2024-06-12 08:15:32 +0200
committerGitHub <noreply@github.com>2024-06-12 07:15:32 +0100
commit2b10aaa05d3752186899bd5b5364d92164edc7ef (patch)
tree857bfb4c2e157e6e866160748d76c9703a337554
parent9f804af29db1273f6580cc8d68b3f7a808f91ee6 (diff)
downloadcandle-2b10aaa05d3752186899bd5b5364d92164edc7ef.tar.gz
candle-2b10aaa05d3752186899bd5b5364d92164edc7ef.tar.bz2
candle-2b10aaa05d3752186899bd5b5364d92164edc7ef.zip
implement Slice op (#2260)
-rw-r--r--candle-onnx/src/eval.rs80
-rw-r--r--candle-onnx/tests/ops.rs135
2 files changed, 215 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index f52e6c5c..10a3b937 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
+ DataType::Bool => Some(DType::U8),
_ => None,
}
}
@@ -1053,6 +1054,85 @@ fn simple_eval_(
),
}
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
+ "Slice" => {
+ let data = get(&node.input[0])?;
+ let starts = get(&node.input[1])?;
+ let ends = get(&node.input[2])?;
+ let default_axes;
+ let default_steps;
+ let axes: &Tensor;
+ let steps: &Tensor;
+ // If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
+ // they are set to [1, ..., 1] of length len(starts)
+ match node.input.len() {
+ 3 => {
+ let len = starts.dims()[0];
+ default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
+ axes = default_axes.as_ref().unwrap();
+ default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
+ steps = default_steps.as_ref().unwrap();
+ }
+ 4 => {
+ let len = starts.dims()[0];
+ axes = get(&node.input[3])?;
+ default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
+ steps = default_steps.as_ref().unwrap();
+ }
+ 5 => {
+ steps = get(&node.input[4])?;
+ axes = get(&node.input[3])?;
+ }
+ _ => bail!(
+ "Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
+ node.input.len(),
+ node
+ ),
+ }
+
+ let mut out = data.clone();
+ for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
+ // All negative elements of axes are made non-negative by
+ // adding r to them, where r = rank(input).
+ let axis = if axis < 0 {
+ axis + data.rank() as i64
+ } else {
+ axis
+ } as usize;
+
+ let data_dim = data.dims()[axis] as i64;
+ let mut s = starts.get(i)?.to_scalar::<i64>()?;
+ let mut e = ends.get(i)?.to_scalar::<i64>()?;
+ // All negative values in starts[i] and ends[i] have
+ // dims[axes[i]] added to them, where dims are the
+ // dimensions of input.
+ if s < 0 {
+ s += data_dim;
+ }
+ if e < 0 {
+ e += data_dim;
+ }
+
+ let p = steps.get(i)?.to_scalar::<i64>()?;
+ // starts[i] is clamped into the range [0, dims[axes[i]]]
+ // for positive stepping and [0, dims[axes[i]]-1] for
+ // negative stepping.
+ // for positive stepping ends[axes[i]] is clamped to
+ // [0, dims[axes[i]]], while for negative stepping it is
+ // clamped to [-1, dims[axes[i]]-1].
+ if p >= 0 {
+ s = s.clamp(0, data_dim);
+ e = e.clamp(0, data_dim);
+ } else {
+ s = s.clamp(0, data_dim - 1);
+ e = e.clamp(-1, data_dim - 1);
+ }
+
+ let indexes = Tensor::arange_step(s, e, p, data.device())?;
+ out = out.index_select(&indexes, axis)?
+ }
+ values.insert(node.output[0].clone(), out);
+ }
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index b4299af1..82d38aa4 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> {
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
Ok(())
}
+
+#[test]
+fn test_slice() -> Result<()> {
+ let model = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Slice".to_string(),
+ input: vec![
+ "data".to_string(),
+ "starts".to_string(),
+ "ends".to_string(),
+ "axes".to_string(),
+ "steps".to_string(),
+ ],
+ output: vec!["result".to_string()],
+ ..NodeProto::default()
+ }],
+ input: ["data", "starts", "ends", "axes", "steps"]
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ r#type: None,
+ doc_string: "".to_string(),
+ })
+ .collect(),
+ output: ["result"]
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ r#type: None,
+ doc_string: "".to_string(),
+ })
+ .collect(),
+ ..GraphProto::default()
+ }));
+
+ /*
+ data = [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ ]
+ axes = [0, 1]
+ starts = [1, 0]
+ ends = [2, 3]
+ steps = [1, 2]
+ result = [
+ [5, 7],
+ ]
+ */
+
+ let outputs = candle_onnx::simple_eval(
+ &model,
+ HashMap::from_iter([
+ (
+ "data".to_string(),
+ Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
+ ),
+ (
+ "starts".to_string(),
+ Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
+ ),
+ (
+ "ends".to_string(),
+ Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
+ ),
+ (
+ "axes".to_string(),
+ Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
+ ),
+ (
+ "steps".to_string(),
+ Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
+ ),
+ ]),
+ )?;
+ let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
+ assert_eq!(actual, vec![vec![5i64, 7]]);
+
+ /*
+ data = [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ ]
+ starts = [0, 1]
+ ends = [-1, 1000]
+ result = [
+ [2, 3, 4],
+ ]
+ */
+ let model = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Slice".to_string(),
+ input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
+ output: vec!["result".to_string()],
+ ..NodeProto::default()
+ }],
+ input: ["data", "starts", "ends"]
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ r#type: None,
+ doc_string: "".to_string(),
+ })
+ .collect(),
+ output: ["result"]
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ r#type: None,
+ doc_string: "".to_string(),
+ })
+ .collect(),
+ ..GraphProto::default()
+ }));
+ let outputs = candle_onnx::simple_eval(
+ &model,
+ HashMap::from_iter([
+ (
+ "data".to_string(),
+ Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
+ ),
+ (
+ "starts".to_string(),
+ Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
+ ),
+ (
+ "ends".to_string(),
+ Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
+ ),
+ ]),
+ )?;
+ let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
+ assert_eq!(actual, vec![vec![2i64, 3, 4]]);
+
+ Ok(())
+}