summaryrefslogtreecommitdiff
path: root/candle-core/src/test_utils.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-27 09:42:22 +0100
committerGitHub <noreply@github.com>2023-08-27 09:42:22 +0100
commit5320aa6b7d339ff594d3886dd29634ea8cde6f17 (patch)
treeb8dc53eaea3966c288ac3c4a597b4d36c2deefa4 /candle-core/src/test_utils.rs
parenta8b39dd7b784b3c3cdd4d228813bd48b2d0d79bb (diff)
downloadcandle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.tar.gz
candle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.tar.bz2
candle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.zip
Move the test-utils bits to a shared place. (#619)
Diffstat (limited to 'candle-core/src/test_utils.rs')
-rw-r--r--candle-core/src/test_utils.rs56
1 files changed, 56 insertions, 0 deletions
diff --git a/candle-core/src/test_utils.rs b/candle-core/src/test_utils.rs
new file mode 100644
index 00000000..8ff73fc0
--- /dev/null
+++ b/candle-core/src/test_utils.rs
@@ -0,0 +1,56 @@
+use crate::{Result, Tensor};
+
+#[macro_export]
+macro_rules! test_device {
+ // TODO: Switch to generating the two last arguments automatically once concat_idents is
+ // stable. https://github.com/rust-lang/rust/issues/29599
+ ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+ #[test]
+ fn $test_cpu() -> Result<()> {
+ $fn_name(&Device::Cpu)
+ }
+
+ #[cfg(feature = "cuda")]
+ #[test]
+ fn $test_cuda() -> Result<()> {
+ $fn_name(&Device::new_cuda(0)?)
+ }
+ };
+}
+
+pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec0::<f32>()?;
+ Ok(f32::round(t * b) / b)
+}
+
+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec1::<f32>()?;
+ let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+ Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec2::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect();
+ Ok(t)
+}
+
+pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec3::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| {
+ t.iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect()
+ })
+ .collect();
+ Ok(t)
+}