 candle-onnx/src/eval.rs  | 213 ++++++++++++++++++++++++++++++++++++++++++---
 candle-onnx/tests/ops.rs | 329 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 528 insertions(+), 14 deletions(-)
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 5b66a743..9c22eeab 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -323,6 +323,13 @@ fn simple_eval_(
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
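+        // ONNX represents an omitted optional input either by leaving it out
+        // of the input list or by passing an empty string in its place, so
+        // treat both cases as absent.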
+ let get_opt = |i: usize| {
+ node.input
+ .get(i)
+ .filter(|s: &&String| !s.is_empty())
+ .map(|s| get(s))
+ };
+
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@@ -608,15 +615,13 @@ fn simple_eval_(
}
"Clip" => {
let xs = get(&node.input[0])?;
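+            // min (input 1) and max (input 2) are optional; clamp with
+            // whichever bounds are present.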
- let xs = if node.input.len() >= 2 {
- let mins = get(&node.input[1])?;
- xs.broadcast_maximum(mins)?
+ let xs = if let Some(mins) = get_opt(1) {
+ xs.broadcast_maximum(mins?)?
} else {
xs.clone()
};
- let xs = if node.input.len() >= 3 {
- let maxs = get(&node.input[2])?;
- xs.broadcast_minimum(maxs)?
+ let xs = if let Some(maxs) = get_opt(2) {
+ xs.broadcast_minimum(maxs?)?
} else {
xs.clone()
};
@@ -759,7 +764,14 @@ fn simple_eval_(
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
- let output = cond.where_cond(a, b)?;
+
+ // where_cond requires that all inputs are the same shape.
+ // In contrast, the Where op in ONNX only requires that they are broadcastable.
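+            // e.g. cond of shape [3, 1], a of shape [1, 4] and a scalar b all
+            // broadcast to a common [3, 4].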
+ let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
+ let cond = cond.broadcast_as(shape.clone())?;
+ let a = a.broadcast_as(shape.clone())?;
+ let b = b.broadcast_as(shape)?;
+ let output = cond.where_cond(&a, &b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
@@ -962,6 +974,7 @@ fn simple_eval_(
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
+
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@@ -1199,6 +1212,152 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
+        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
+ // Version 18 impl
+ "Split" => {
+ let input_tensor = get(&node.input[0])?;
+ let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
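+            // The axis attribute may be negative, counting from the back;
+            // normalize_axis maps it to a non-negative index.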
+ let axis = input_tensor.normalize_axis(axis)?;
+
+ // Determine split sizes
+ let splits = if node.input.len() > 1 {
+ // If the split tensor is provided, use it to determine sizes
+ let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
+ split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
+ } else {
+ let num_outputs = if let Some(&num_outputs_attrib) =
+ get_attr_opt::<i64>(node, "num_outputs")?
+ {
+ num_outputs_attrib as usize
+ } else {
+ node.output.len()
+ };
+
+ let input_dim = input_tensor.dim(axis)?;
+
+                // Per the ONNX spec, when the input is not evenly splittable
+                // the last chunk is the smaller one: every chunk holds
+                // ceil(input_dim / num_outputs) elements until the input is
+                // exhausted, e.g. 7 elements split into 2 outputs as [4, 3].
+                let chunk = (input_dim + num_outputs - 1) / num_outputs;
+                let mut split_sizes = Vec::with_capacity(num_outputs);
+                let mut remaining = input_dim;
+                for _ in 0..num_outputs {
+                    let size = usize::min(chunk, remaining);
+                    split_sizes.push(size);
+                    remaining -= size;
+                }
+
+ split_sizes
+ };
+
+ // Perform the split operation
+ let mut outputs = vec![];
+ let mut start = 0;
+ for &size in &splits {
+ let end = start + size;
+ let slice = input_tensor.narrow(axis, start, size)?;
+ outputs.push(slice);
+ start = end;
+ }
+
+ // Insert the split outputs into the values map
+ for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
+ values.insert(output.clone(), slice);
+ }
+ }
+        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
+        // Version 13 impl
+        "Expand" => {
+            // Unlike a plain broadcast, Expand allows the output shape to
+            // differ from the specified shape: a target dimension of 1 is
+            // overridden by the input's actual dimension.
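+            // e.g. expanding a [3, 1] input with shape [2, 1, 6] yields a
+            // [2, 3, 6] output.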
+ let input_tensor = get(&node.input[0])?;
+ let input_shape = get(&node.input[1])?;
+
+ // Check that the shape tensor is 1D
+ if input_shape.rank() != 1 {
+ bail!(
+ "Expand expects 'shape' input to be 1D tensor: {:?}",
+ input_shape
+ );
+ }
+ let input_tensor_dims = input_tensor.dims();
+ let input_shape_dims = input_shape
+ .to_vec1::<i64>()?
+ .into_iter()
+ .map(|x| x as usize)
+ .collect::<Vec<_>>();
+
+ let target_shape =
+ broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
+
+ let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
+
+ values.insert(node.output[0].clone(), expanded_tensor);
+ }
+        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
+ // Version 13 impl
+ "ReduceSum" => {
+ let input = get(&node.input[0])?;
+ let axes = get_opt(1);
+ let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
+ let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
+ .copied()
+ .unwrap_or(0);
+
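+            // When no axes are given: noop_with_empty_axes == 1 turns the op
+            // into an identity, otherwise the reduction runs over all axes.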
+            let axes = match axes {
+                Some(axes) => axes?
+                    .to_vec1::<i64>()?
+                    .into_iter()
+                    // entries may be negative, counting from the back
+                    .map(|x| input.normalize_axis(x))
+                    .collect::<Result<Vec<_>>>()?,
+ None => {
+ if noop_with_empty_axes == 1 {
+ vec![]
+ } else {
+ (0..input.rank()).collect()
+ }
+ }
+ };
+
+ let output = if keepdims == 1 {
+ input.sum_keepdim(axes)?
+ } else {
+ input.sum(axes)?
+ };
+
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
+ // Version 18 impl
+ "ReduceL2" => {
+ let input = get(&node.input[0])?;
+ let axes = get_opt(1);
+ let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
+ let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
+ .copied()
+ .unwrap_or(0);
+
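+            // ReduceL2 computes sqrt(sum(x^2)) over the selected axes.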
+            let axes = match axes {
+                Some(axes) => axes?
+                    .to_vec1::<i64>()?
+                    .into_iter()
+                    // entries may be negative, counting from the back
+                    .map(|x| input.normalize_axis(x))
+                    .collect::<Result<Vec<_>>>()?,
+                None => {
+                    if noop_with_empty_axes == 1 {
+                        vec![]
+                    } else {
+                        (0..input.rank()).collect()
+                    }
+                }
+            };
+
+            let output = if axes.is_empty() && noop_with_empty_axes == 1 {
+                // Nothing to reduce: pass the input through unchanged rather
+                // than computing sqrt(x^2), which would strip the sign of
+                // negative entries.
+                input.clone()
+            } else {
+                let input_sq = input.sqr()?;
+                if keepdims == 1 {
+                    input_sq.sum_keepdim(axes)?.sqrt()?
+                } else {
+                    input_sq.sum(axes)?.sqrt()?
+                }
+            };
+
+ values.insert(node.output[0].clone(), output);
+ }
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
@@ -1395,13 +1554,6 @@ fn simple_eval_(
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
- let get_opt = |i: usize| {
- node.input
- .get(i)
- .filter(|s: &&String| !s.is_empty())
- .map(|s| get(s))
- };
-
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
@@ -1580,3 +1732,36 @@ fn simple_eval_(
})
.collect()
}
+
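+/// Computes the shape that `shape_a` and `shape_b` broadcast to under
+/// numpy-style broadcasting: the shapes are right-aligned and each pair of
+/// dimensions must either match or contain a 1.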
+fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
+ let (longest, shortest) = if shape_a.len() > shape_b.len() {
+ (shape_a, shape_b)
+ } else {
+ (shape_b, shape_a)
+ };
+ let diff = longest.len() - shortest.len();
+ let mut target_shape = longest[0..diff].to_vec();
+ for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
+ if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
+ target_shape.push(usize::max(*dim1, *dim2));
+ } else {
+            bail!(
+                "broadcast_shape: incompatible shapes for broadcast, {:?} and {:?}",
+                shape_a,
+                shape_b
+            );
+ }
+ }
+ Ok(target_shape)
+}
+
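+/// Folds `broadcast_shape` over a list of shapes to find the shape they all
+/// broadcast to, e.g. [3, 1], [1, 4] and [] give [3, 4].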
+fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
+ if shapes.is_empty() {
+ return Ok(Vec::new());
+ }
+ let mut shape_out = shapes[0].to_vec();
+ for shape in shapes[1..].iter() {
+ shape_out = broadcast_shape(&shape_out, shape)?;
+ }
+ Ok(shape_out)
+}
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 51ee037e..55d6fb86 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_expand_dim_changed() -> Result<()> {
+ // Create a manual graph for the Expand operation
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Expand".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec!["data".to_string(), "new_shape".to_string()],
+ output: vec!["expanded".to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ input: vec![
+ ValueInfoProto {
+ name: "data".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ValueInfoProto {
+ name: "new_shape".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: "expanded".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ ..GraphProto::default()
+ }));
+
+ // Input tensor with shape [3, 1]
+ let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
+
+ // New shape tensor: [2, 1, 6]
+ let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
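+    // The 1 requested in the middle is overridden by the input's dim of 3,
+    // hence "dim_changed": the output shape is [2, 3, 6].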
+
+ // Expected output after expansion
+ let expected = Tensor::from_vec(
+ vec![
+ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
+ 2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
+ 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
+ 3.0f32, 3.0f32, 3.0f32,
+ ],
+ (2, 3, 6),
+ &Device::Cpu,
+ )?;
+
+ // Execute the model evaluation
+ let inputs = HashMap::from_iter([
+ ("data".to_string(), data),
+ ("new_shape".to_string(), new_shape),
+ ]);
+ let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+ // Retrieve and compare the result
+ let expanded = result.get("expanded").expect("Output 'expanded' not found");
+
+ assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+
+ Ok(())
+}
+
+fn make_graph_helper(
+ op_name: &str,
+ inputs: &[&str],
+ outputs: &[&str],
+ attribs: Vec<AttributeProto>,
+) -> ModelProto {
+ create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: op_name.to_string(),
+ domain: "".to_string(),
+ attribute: attribs,
+ input: inputs.iter().map(|s| s.to_string()).collect(),
+ output: outputs.iter().map(|s| s.to_string()).collect(),
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+        input: inputs
+            .iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ ..ValueInfoProto::default()
+ })
+ .collect(),
+        output: outputs
+            .iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ ..ValueInfoProto::default()
+ })
+ .collect(),
+ ..GraphProto::default()
+ }))
+}
+
+#[test]
+fn test_expand_dim_unchanged() -> Result<()> {
+ // Create a manual graph for the Expand operation
+ let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
+
+ // Input tensor with shape [3, 1] and dtype f32
+ let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
+
+ // New shape tensor: [3, 4]
+ let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
+
+ // Expected output after expansion, dtype f32
+ let expected = Tensor::from_vec(
+ vec![
+ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
+ 3.0f32,
+ ],
+ (3, 4),
+ &Device::Cpu,
+ )?;
+
+ // Execute the model evaluation
+ let inputs = HashMap::from_iter([
+ ("data".to_string(), data),
+ ("new_shape".to_string(), new_shape),
+ ]);
+ let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+ // Retrieve and compare the result
+ let expanded = result.get("expanded").expect("Output 'expanded' not found");
+ assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+
+ Ok(())
+}
+
+fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
+ let attribs = vec![AttributeProto {
+ name: "axis".to_string(),
+ r#type: AttributeType::Int.into(),
+ i: axis,
+ ..AttributeProto::default()
+ }];
+
+ make_graph_helper("Split", inputs, outputs, attribs)
+}
+
+#[test]
+fn test_split_equal_parts_1d_opset13() -> Result<()> {
+ let input = Tensor::from_vec(
+ vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
+ (6,),
+ &Device::Cpu,
+ )?;
+ let mut inputs = HashMap::new();
+ inputs.insert("input".to_string(), input);
+
+ {
+ let manual_graph =
+ make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
+ assert_eq!(eval.len(), 3);
+
+ let out1 = eval.get("output_1").expect("Output 'output_1' not found");
+ let out2 = eval.get("output_2").expect("Output 'output_2' not found");
+ let out3 = eval.get("output_3").expect("Output 'output_3' not found");
+
+ assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
+ assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
+ assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
+ }
+
+ {
+ let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
+ inputs.insert("split".to_string(), splits);
+
+ let manual_graph =
+ make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 2);
+
+ let out1 = eval.get("output_1").expect("Output 'output_1' not found");
+ let out2 = eval.get("output_2").expect("Output 'output_2' not found");
+
+ assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
+ assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
+ }
+ Ok(())
+}
+
+fn make_reduce_sum_graph_helper(
+ inputs: &[&str],
+ outputs: &[&str],
+ keepdims: Option<i64>,
+ noop_with_empty_axes: Option<i64>,
+) -> ModelProto {
+ let mut attribs = vec![];
+ if let Some(keepdims) = keepdims {
+ attribs.push(AttributeProto {
+ name: "keepdims".to_string(),
+ r#type: AttributeType::Int.into(),
+ i: keepdims,
+ ..AttributeProto::default()
+ });
+ }
+ if let Some(noop_with_empty_axes) = noop_with_empty_axes {
+ attribs.push(AttributeProto {
+ name: "noop_with_empty_axes".to_string(),
+            r#type: AttributeType::Int.into(),
+ i: noop_with_empty_axes,
+ ..AttributeProto::default()
+ });
+ }
+ make_graph_helper("ReduceSum", inputs, outputs, attribs)
+}
+
+#[test]
+fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
+ let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
+
+ // Test with example data
+ {
+ let data = Tensor::from_vec(
+ vec![
+ 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+        // No "axes" input is provided, so the reduction defaults to all axes.
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
+
+ assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+ }
+
+ {
+ let data = Tensor::from_vec(
+ vec![
+ -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data.clone());
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = data.sum_all()?.reshape((1, 1, 1))?;
+
+ assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+ }
+
+ Ok(())
+}
+
+#[test]
+fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
+ let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
+
+ // Test with example data
+ {
+ let data = Tensor::from_vec(
+ vec![
+ 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+ let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data);
+ inputs.insert("axes".to_string(), axes);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = Tensor::from_vec(
+ vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
+ (3, 2),
+ &Device::Cpu,
+ )?;
+
+ assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+ }
+
+    // Test with mixed positive and negative data
+    {
+        let data = Tensor::from_vec(
+ vec![
+ -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+ let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data.clone());
+ inputs.insert("axes".to_string(), axes);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+
+ // Calculate expected result
+ let expected = data.sum(1)?;
+
+ assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+ }
+
+ Ok(())
+}