diff options
Diffstat (limited to 'candle-onnx/src')
-rw-r--r-- | candle-onnx/src/eval.rs | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 78e0554a..65fb6d77 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -971,7 +971,7 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), output); } - "RandomUniform" => { + random_type @ ("RandomUniform" | "RandomNormal") => { let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float // type by // default @@ -979,36 +979,42 @@ pub fn simple_eval( Ok(dt) => match dtype(dt) { Some(DType::U8 | DType::U32 | DType::I64) => { bail!( - "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}", + "unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}", node.name ) } Some(dt) => dt, None => { bail!( - "unsupported 'dtype' value {dt:?} for RandomUnifrom {}", + "unsupported 'dtype' value {dt:?} for {random_type} {}", node.name ) } }, Err(_) => { bail!( - "unsupported 'dtype' value {dt:?} for RandomUniform {}", + "unsupported 'dtype' value {dt:?} for {random_type} {}", node.name ) } }; - let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0); - let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0); let seed: Option<f32> = get_attr_opt(node, "seed")?.copied(); if seed.is_some() { - bail!("seed for RandomUniform is currently not supported") + bail!("seed for {random_type} is currently not supported") }; let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")? .iter() .map(|x| *x as usize) .collect(); - let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?; + let output = if random_type == "RandomUniform" { + let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0); + let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0); + Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)? + } else { + let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0); + let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0); + Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)? + }; values.insert(node.output[0].clone(), output); } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), |