diff options
author | ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 11:19:49 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-12 11:19:49 +0100 |
commit | a3d92ab226ffc33743f4388a814d7dfe7fbe2809 (patch) | |
tree | ed32163806b360515ec98436dc7d05150bd726ec /candle-metal-kernels | |
parent | e90bcdcc7c51dd85037055b59f22568100d801f0 (diff) | |
download | candle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.tar.gz candle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.tar.bz2 candle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.zip |
Metal: Activate bfloat affine and add benchmark (#1543)
* Use cfg to separate benchmark results based on features
* Add bfloat affine and benchmarks
* Fix flops calculation
* Remove allow pragma
* Avoid some unnecessary returns.
* Improve benchmarks layout
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 14 |
1 file changed, 7 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 3d8e7f0d..a4484998 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index( using namespace metal; -#define AFFINE(FN_NAME, TYPENAME) \ +#define AFFINE(FN_NAME, T) \ kernel void FN_NAME( \ constant size_t &dim, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[id]) * mul + add); \ + output[id] = T(fma(float(input[id]), mul, add)); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \ constant size_t *strides, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ + output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ } #define POWF(FN_NAME, TYPENAME) \ |