diff options
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/benches/bench_main.rs | 7 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/affine.rs | 43 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/mod.rs | 2 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/where_cond.rs | 64 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 3 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 12 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 16 |
7 files changed, 143 insertions, 4 deletions
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 8913df4f..162e3f2b 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,9 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches, benchmarks::random::benches); +criterion_main!( + benchmarks::affine::benches, + benchmarks::matmul::benches, + benchmarks::random::benches, + benchmarks::where_cond::benches +); diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs new file mode 100644 index 00000000..eded9f57 --- /dev/null +++ b/candle-core/benches/benchmarks/affine.rs @@ -0,0 +1,43 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.affine(12.34, 56.78).unwrap(); +} + +fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_affine_benchmark(c, &device, DType::F32, "affine_f32"); + run_affine_benchmark(c, &device, DType::F16, "affine_f16"); + run_affine_benchmark(c, &device, DType::BF16, "affine_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index eb20ea70..c45effee 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,5 +1,7 @@ +pub(crate) mod affine; pub(crate) mod matmul; pub(crate) mod random; +pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs new file mode 100644 index 00000000..c517dcf5 --- /dev/null +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -0,0 +1,64 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr<const N: usize>() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::<SIZE>(); + +fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap(); + + let elements = B * M * K; + // E.g. 2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 48250233..8a75bd7c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", + DType::BF16 => "affine_bf16", dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( @@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32_strided", DType::F16 => "affine_f16_strided", + DType::BF16 => "affine_bf16_strided", dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine_strided( @@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e83fb55b..33bab1b6 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { } #[test] -fn logsumexp() -> Result<()> { +fn log_sum_exp() -> Result<()> { let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; - let output = input.logsumexp(D::Minus1)?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; assert_close(&output, &expected, 0.00001)?; Ok(()) } + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 4)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + ); + Ok(()) +} |