summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend.rs2
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/unary.metal7
-rw-r--r--sd_final.pngbin0 -> 336984 bytes
4 files changed, 8 insertions, 3 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 6d8afab1..48b58e76 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -648,6 +648,7 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
+ ("urecip", DType::F32) => contiguous::round::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
@@ -662,6 +663,7 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
+ ("urecip", DType::F16) => contiguous::round::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index dd97a86d..e3f9397e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -165,7 +165,7 @@ macro_rules! ops{
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
+ ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, urecip);
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 04fa37a9..46a2b0fe 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index(
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
+template <typename T> METAL_FUNC T urecip(T in){ return T(1.0 / in); }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
+
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
@@ -57,8 +59,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
-
-
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@@ -108,6 +108,8 @@ UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
+UNARY_OP(urecip)
+
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -128,6 +130,7 @@ BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
+BFLOAT_UNARY_OP(urecip)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif
diff --git a/sd_final.png b/sd_final.png
new file mode 100644
index 00000000..cad1029d
--- /dev/null
+++ b/sd_final.png
Binary files differ