summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-04 22:17:45 +0100
committerGitHub <noreply@github.com>2023-11-04 22:17:45 +0100
commit39ad840a909e487f6a05b06c0aee81029eafbf33 (patch)
tree49508e6353aeab23916c700f19ed4dc92ee611e3 /candle-onnx/src
parentb5e4f84bed92102b742d16c3c6f5846038f0a83a (diff)
downloadcandle-39ad840a909e487f6a05b06c0aee81029eafbf33.tar.gz
candle-39ad840a909e487f6a05b06c0aee81029eafbf33.tar.bz2
candle-39ad840a909e487f6a05b06c0aee81029eafbf33.zip
Better tensor initialization in ONNX. (#1270)
* Better tensor initialization in ONNX. * MaxPool support. * Add AvgPool. * Get the squeezenet example to work.
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs113
1 files changed, 103 insertions, 10 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index bc04eb00..4d44bd8e 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -96,7 +96,20 @@ fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
match DataType::try_from(t.data_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => {
- Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
+ if dt == DType::F32 && !t.float_data.is_empty() {
+ Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
+ } else if dt == DType::F64 && !t.double_data.is_empty() {
+ Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
+ } else if dt == DType::I64 && !t.int64_data.is_empty() {
+ Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
+ } else {
+ Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ dt,
+ dims.as_slice(),
+ &Device::Cpu,
+ )
+ }
}
None => {
bail!("unsupported 'value' data-type {dt:?} for {name}")
@@ -174,17 +187,22 @@ pub fn simple_eval(
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
- // TODO: Check that there is at most a single -1, handle other neg values.
+ // TODO: Check that there is at most a single -1 or 0, handle other neg values.
+ let mut other_than_minus1 = 1usize;
+ for &v in input1.iter() {
+ if v != -1 && v != 0 {
+ other_than_minus1 *= v as usize
+ }
+ }
let input1 = input1
.iter()
- .map(|&v| {
- if v == -1 {
- input0.elem_count()
- } else {
- v as usize
- }
+ .enumerate()
+ .map(|(idx, &v)| match v {
+ -1 => Ok(input0.elem_count() / other_than_minus1),
+ 0 => input0.dim(idx),
+ _ => Ok(v as usize),
})
- .collect::<Vec<usize>>();
+ .collect::<Result<Vec<usize>>>()?;
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
@@ -235,6 +253,81 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
+ "Dropout" => {
+ let input = get(&node.input[0])?;
+ // Do not apply dropout at the moment, consider that we're only doing inference.
+ values.insert(node.output[0].clone(), input.clone());
+ }
+ "MaxPool" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
+ let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
+ let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
+ let pads = get_attr_opt::<[i64]>(node, "pads")?;
+ let strides = get_attr_opt::<[i64]>(node, "strides")?;
+ let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
+ match auto_pad {
+ None | Some("NOTSET") => (),
+ Some(s) => bail!("unsupported auto_pad {s}"),
+ };
+ if let Some(d) = dilations {
+ if d.iter().any(|&v| v != 1) {
+ bail!("MaxPool with dilation != 1, {dilations:?}")
+ }
+ }
+ if let Some(d) = pads {
+ if d.iter().any(|&v| v != 0) {
+ bail!("MaxPool with pads != 0, {pads:?}")
+ }
+ }
+ let xs = get(&node.input[0])?;
+ let (k1, k2) = match kernel_shape {
+ [k1, k2] => (*k1 as usize, *k2 as usize),
+ _ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
+ };
+ let ys = match strides {
+ None => xs.max_pool2d((k1, k2))?,
+ Some([s1, s2]) => {
+ xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
+ }
+ Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
+ };
+ values.insert(node.output[0].clone(), ys);
+ }
+ "AveragePool" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
+ let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
+ let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
+ let pads = get_attr_opt::<[i64]>(node, "pads")?;
+ let strides = get_attr_opt::<[i64]>(node, "strides")?;
+ let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
+ match auto_pad {
+ None | Some("NOTSET") => (),
+ Some(s) => bail!("unsupported auto_pad {s}"),
+ };
+ if let Some(d) = dilations {
+ if d.iter().any(|&v| v != 1) {
+ bail!("AvgPool with dilation != 1, {dilations:?}")
+ }
+ }
+ if let Some(d) = pads {
+ if d.iter().any(|&v| v != 0) {
+ bail!("AvgPool with pads != 0, {pads:?}")
+ }
+ }
+ let xs = get(&node.input[0])?;
+ let (k1, k2) = match kernel_shape {
+ [k1, k2] => (*k1 as usize, *k2 as usize),
+ _ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
+ };
+ let ys = match strides {
+ None => xs.avg_pool2d((k1, k2))?,
+ Some([s1, s2]) => {
+ xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
+ }
+ Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
+ };
+ values.insert(node.output[0].clone(), ys);
+ }
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@@ -453,7 +546,7 @@ pub fn simple_eval(
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
- op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
+ op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
graph