summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorivarflakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 11:19:49 +0100
committerGitHub <noreply@github.com>2024-01-12 11:19:49 +0100
commita3d92ab226ffc33743f4388a814d7dfe7fbe2809 (patch)
treeed32163806b360515ec98436dc7d05150bd726ec /candle-metal-kernels
parente90bcdcc7c51dd85037055b59f22568100d801f0 (diff)
downloadcandle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.tar.gz
candle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.tar.bz2
candle-a3d92ab226ffc33743f4388a814d7dfe7fbe2809.zip
Metal: Activate bfloat affine and add benchmark (#1543)
* Use cfg to separate benchmark results based on features * Add bfloat affine and benchmarks * Fix flops calculation * Remove allow pragma * Avoid some unnecessary returns. * Improve benchmarks layout --------- Co-authored-by: Laurent <laurent.mazare@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/affine.metal | 14
1 files changed, 7 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 3d8e7f0d..a4484998 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
using namespace metal;
-#define AFFINE(FN_NAME, TYPENAME) \
+#define AFFINE(FN_NAME, T) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
- device const TYPENAME *input, \
- device TYPENAME *output, \
+ device const T *input, \
+ device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
- output[id] = TYPENAME(float(input[id]) * mul + add); \
+ output[id] = T(fma(float(input[id]), mul, add)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
@@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
- device const TYPENAME *input, \
- device TYPENAME *output, \
+ device const T *input, \
+ device T *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
- output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
+ output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
}
#define POWF(FN_NAME, TYPENAME) \