diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-27 09:42:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-27 09:42:22 +0100 |
commit | 5320aa6b7d339ff594d3886dd29634ea8cde6f17 (patch) | |
tree | b8dc53eaea3966c288ac3c4a597b4d36c2deefa4 /candle-core/src/test_utils.rs | |
parent | a8b39dd7b784b3c3cdd4d228813bd48b2d0d79bb (diff) | |
download | candle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.tar.gz candle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.tar.bz2 candle-5320aa6b7d339ff594d3886dd29634ea8cde6f17.zip |
Move the test-utils bits to a shared place. (#619)
Diffstat (limited to 'candle-core/src/test_utils.rs')
-rw-r--r-- | candle-core/src/test_utils.rs | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/candle-core/src/test_utils.rs b/candle-core/src/test_utils.rs new file mode 100644 index 00000000..8ff73fc0 --- /dev/null +++ b/candle-core/src/test_utils.rs @@ -0,0 +1,56 @@ +use crate::{Result, Tensor}; + +#[macro_export] +macro_rules! test_device { + // TODO: Switch to generating the two last arguments automatically once concat_idents is + // stable. https://github.com/rust-lang/rust/issues/29599 + ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => { + #[test] + fn $test_cpu() -> Result<()> { + $fn_name(&Device::Cpu) + } + + #[cfg(feature = "cuda")] + #[test] + fn $test_cuda() -> Result<()> { + $fn_name(&Device::new_cuda(0)?) + } + }; +} + +pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> { + let b = 10f32.powi(digits); + let t = t.to_vec0::<f32>()?; + Ok(f32::round(t * b) / b) +} + +pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> { + let b = 10f32.powi(digits); + let t = t.to_vec1::<f32>()?; + let t = t.iter().map(|t| f32::round(t * b) / b).collect(); + Ok(t) +} + +pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::<f32>()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) +} + +pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::<f32>()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} |