summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/tmp/affine.rs
blob: cd019056c7ea99844e4fb2a03f1644126ea4dfc5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

/// Entry point: benchmarks the affine kernel on random `f32` inputs of
/// 1 000, 10 000 and 100 000 elements.
///
/// All inputs are generated up front, then a table header is printed,
/// followed by one result row per input size.
///
/// # Panics
/// Panics if `Device::system_default()` returns `None`.
fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    // Produces `len` random f32 samples.
    let random_f32s = |len: usize| -> Vec<f32> {
        (0..len).map(|_| rand::random::<f32>()).collect()
    };

    // Inputs are built smallest-first, before anything is printed.
    let inputs = [1000, 10000, 100000].map(random_f32s);

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    for samples in &inputs {
        run_affine_bench(&device, &kernels, samples);
    }
}

/// Times repeated `call_affine` invocations over `v` and prints one result row.
///
/// The contents of `v` are copied into an input buffer, and an output buffer
/// of the same byte size is allocated. `call_affine` is then invoked 10000
/// times against a single command buffer, which is committed and waited on.
/// The measured interval covers the calls, the commit and the wait.
///
/// The printed row contains the last path segment of `T`'s type name, the
/// label `"affine"`, the element count, the run count, the total elapsed time
/// and the per-run average.
///
/// # Panics
/// Panics if any `call_affine` call returns an error.
///
/// NOTE(review): the kernel name is always `"affine_float"` regardless of `T`;
/// presumably only `f32` slices are meaningful here — confirm against callers.
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
    let iterations = 10000;
    let mul: f32 = 1.2345;
    let add: f32 = 2.3456;

    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    // Input and output buffers share the same byte length.
    let byte_len = core::mem::size_of_val(v) as u64;
    let input = device.new_buffer_with_data(v.as_ptr().cast::<core::ffi::c_void>(), byte_len, options);
    let mut output = device.new_buffer(byte_len, options);

    let elapsed = autoreleasepool(|| {
        let cmd_buf = command_queue.new_command_buffer();
        // The timer starts after the command buffer exists.
        let timer = Instant::now();
        (0..iterations).for_each(|_| {
            call_affine(
                device,
                cmd_buf,
                kernels,
                "affine_float",
                v.len(),
                &input,
                &mut output,
                mul,
                add,
            )
            .unwrap();
        });
        cmd_buf.commit();
        cmd_buf.wait_until_completed();

        timer.elapsed()
    });

    // Keep only the final path segment of the type name (e.g. `f32`).
    let dtype = type_name::<T>().rsplit("::").next().unwrap();
    let average = elapsed / iterations;
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
        dtype,
        "affine",
        v.len(),
        iterations,
        elapsed,
        average
    );
}