diff options
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r-- | candle-onnx/src/eval.rs | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 9c22eeab..de3e1010 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType; use crate::onnx::tensor_proto::DataType; use crate::onnx::{self, GraphProto}; use candle::{bail, DType, Device, Result, Tensor}; -use std::{collections::HashMap, usize}; +use std::collections::HashMap; pub type Value = Tensor; @@ -321,7 +321,7 @@ fn simple_eval_( for node in graph.node.iter() { let get = |input_name: &str| match values.get(input_name) { Some(value) => Ok(value), - None => bail!("cannot find {input_name} for op {}", node.name), + None => bail!("cannot find {input_name} for op '{}'", node.name), }; let get_opt = |i: usize| { node.input @@ -362,7 +362,7 @@ fn simple_eval_( // HACK: current implementation of broadcast_pow cannot handle negative base, // so we use powf where we can, which *does* correctly handle negative base. if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() { - let output = input0.powf(exp as f64)?; + let output = input0.powf(exp)?; values.insert(node.output[0].clone(), output); } else { let output = input0.broadcast_pow(input1)?; @@ -643,7 +643,7 @@ fn simple_eval_( let mask = indices.lt(&zeros)?; mask.to_dtype(indices.dtype())? .broadcast_mul(&max)? - .add(&indices)? + .add(indices)? }; // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices @@ -767,7 +767,7 @@ fn simple_eval_( // where_cond requires that all inputs are the same shape. // In contrast, the Where op in ONNX only requires that they are broadcastable. - let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?; + let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?; let cond = cond.broadcast_as(shape.clone())?; let a = a.broadcast_as(shape.clone())?; let b = b.broadcast_as(shape)?; @@ -1283,8 +1283,7 @@ fn simple_eval_( .map(|x| x as usize) .collect::<Vec<_>>(); - let target_shape = - broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?; + let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?; let expanded_tensor = input_tensor.broadcast_as(target_shape)?; @@ -1301,12 +1300,12 @@ fn simple_eval_( .unwrap_or(0); let axes = match axes { - Some(axes) => axes? + Some(Ok(axes)) => axes .to_vec1::<i64>()? .into_iter() .map(|x| x as usize) .collect::<Vec<_>>(), - None => { + Some(Err(_)) | None => { if noop_with_empty_axes == 1 { vec![] } else { @@ -1640,7 +1639,7 @@ fn simple_eval_( let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size] let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size] let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size] - let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?; + let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?; let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?; let wb = b.index_select(&idx_wb, 0)?; let rb = b.index_select(&idx_rb, 0)?; @@ -1649,8 +1648,8 @@ fn simple_eval_( // w, r, wb, rb are all iofc but lstm expects ifco // so we need to move some stuff around - let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?; - let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?; + let idx_i = Tensor::arange(0, hidden_size, x.device())?; + let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?; let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?; let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?; let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?; @@ -1674,7 +1673,7 @@ fn simple_eval_( )?; let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c); - let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" { + let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" { Some(vec![]) } else { None @@ -1688,7 +1687,7 @@ fn simple_eval_( } assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped"); - if let Some(name) = node.output.get(0) { + if let Some(name) = node.output.first() { let h_acc = h_acc.as_ref().unwrap(); let h_acc = lstm.states_to_tensor(h_acc)?; let h_acc = h_acc.reshape(( |