summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorBaye Dieng <bayedieng98@gmail.com>2023-12-29 09:46:24 +0000
committerBaye Dieng <bayedieng98@gmail.com>2023-12-29 09:46:24 +0000
commitcc06ba2294374653cb61654c482513ef7f9b4c88 (patch)
treedaa4aeef7432d9c47ae7771026fd8f49d7da6b98 /candle-metal-kernels
parentb59b1b2bb67b77c58ef897ebd0c548d198871897 (diff)
downloadcandle-cc06ba2294374653cb61654c482513ef7f9b4c88.tar.gz
candle-cc06ba2294374653cb61654c482513ef7f9b4c88.tar.bz2
candle-cc06ba2294374653cb61654c482513ef7f9b4c88.zip
fix bad pattern matching and function name
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/unary.metal6
2 files changed, 4 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e3f9397e..94479882 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -165,7 +165,7 @@ macro_rules! ops{
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, urecip);
+ ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip);
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 46a2b0fe..826b9045 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -19,7 +19,7 @@ METAL_FUNC uint get_strided_index(
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
-template <typename T> METAL_FUNC T urecip(T in){ return T(1.0 / in); }
+template <typename T> METAL_FUNC T recip(T in){ return T(1.0 / in); }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
@@ -108,7 +108,7 @@ UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
-UNARY_OP(urecip)
+UNARY_OP(recip)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@@ -130,7 +130,7 @@ BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
-BFLOAT_UNARY_OP(urecip)
+BFLOAT_UNARY_OP(recip)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif