summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormokulus <36231852+mokulus@users.noreply.github.com>2024-05-21 21:47:32 +0200
committerGitHub <noreply@github.com>2024-05-21 21:47:32 +0200
commit7ff921c5385e1f08dc534b67a969cd06b91714d5 (patch)
treeee1cc7cd0c1131991cf15eb44bb00d8168b023e2
parent9b8537a62fe317ac07f8dfafa41b181793925490 (diff)
downloadcandle-7ff921c5385e1f08dc534b67a969cd06b91714d5.tar.gz
candle-7ff921c5385e1f08dc534b67a969cd06b91714d5.tar.bz2
candle-7ff921c5385e1f08dc534b67a969cd06b91714d5.zip
Add RandomNormal ONNX operator (#2200)
-rw-r--r--candle-onnx/src/eval.rs22
-rw-r--r--candle-onnx/tests/ops.rs144
2 files changed, 158 insertions, 8 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 78e0554a..65fb6d77 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -971,7 +971,7 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
- "RandomUniform" => {
+ random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
@@ -979,36 +979,42 @@ pub fn simple_eval(
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
- "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
+ "unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
- "unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
+ "unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
},
Err(_) => {
bail!(
- "unsupported 'dtype' value {dt:?} for RandomUniform {}",
+ "unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
};
- let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
- let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
- bail!("seed for RandomUniform is currently not supported")
+ bail!("seed for {random_type} is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
- let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
+ let output = if random_type == "RandomUniform" {
+ let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
+ let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
+ Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
+ } else {
+ let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
+ let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
+ Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
+ };
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 30e2480b..a53ad8c5 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
Ok(())
}
+// "RandomNormal"
+#[test]
+fn test_random_normal() -> Result<()> {
+ test(vec![3, 2, 1, 4], None, None)?;
+ test(vec![2, 2, 2, 2], Some(-10.0), None)?;
+ test(vec![2, 2, 2, 2], None, Some(10.0))?;
+ test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
+
+ fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
+ let att_mean = AttributeProto {
+ name: "mean".to_string(),
+ ref_attr_name: "mean".to_string(),
+ i: 0,
+ doc_string: "mean".to_string(),
+ r#type: 1, // FLOAT
+ f: mean.unwrap_or(0.0),
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_scale = AttributeProto {
+ name: "scale".to_string(),
+ ref_attr_name: "scale".to_string(),
+ i: 0,
+ doc_string: "scale".to_string(),
+ r#type: 1, // FLOAT
+ f: scale.unwrap_or(1.0),
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_shape = AttributeProto {
+ name: "shape".to_string(),
+ ref_attr_name: "shape".to_string(),
+ i: 0,
+ doc_string: "shape".to_string(),
+ r#type: 7, // INTS
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: shape,
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_dtype = AttributeProto {
+ name: "dtype".to_string(),
+ ref_attr_name: "dtype".to_string(),
+ i: 11, // DOUBLE
+ doc_string: "dtype".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let attrs = {
+ let mut mut_attrs = vec![att_shape, att_dtype];
+ if mean.is_some() {
+ mut_attrs.push(att_mean);
+ }
+ if scale.is_some() {
+ mut_attrs.push(att_scale);
+ }
+ mut_attrs
+ };
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "RandomNormal".to_string(),
+ domain: "".to_string(),
+ attribute: attrs,
+ input: vec![],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+ let data = z.flatten_all()?.to_vec1::<f64>()?;
+
+ // test if values are unique
+ for (i, a) in data.iter().enumerate() {
+ for (j, b) in data.iter().enumerate() {
+ if i == j {
+ continue;
+ };
+ assert_ne!(a, b);
+ }
+ }
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
// "Range"
#[test]
fn test_range() -> Result<()> {