summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs22
1 file changed, 14 insertions, 8 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 78e0554a..65fb6d77 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -971,7 +971,7 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
- "RandomUniform" => {
+ random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
@@ -979,36 +979,42 @@ pub fn simple_eval(
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
- "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
+ "unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
- "unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
+ "unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
},
Err(_) => {
bail!(
- "unsupported 'dtype' value {dt:?} for RandomUniform {}",
+ "unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
};
- let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
- let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
- bail!("seed for RandomUniform is currently not supported")
+ bail!("seed for {random_type} is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
- let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
+ let output = if random_type == "RandomUniform" {
+ let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
+ let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
+ Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
+ } else {
+ let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
+ let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
+ Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
+ };
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),