diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-04 17:58:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-04 17:58:44 +0100 |
commit | c18a856e76cad9626406c3c483a53fb5b7eeef7b (patch) | |
tree | 67c71e73d59dd5ab506d98c134492e08bd9e5e68 /candle-kernels | |
parent | 3349c892523426a00e16dd094837f5d786754ce1 (diff) | |
download | candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.gz candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.bz2 candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.zip |
Add the rounding operators. (#1030)
* Add the rounding operators.
* Avoid tracking gradients for the rounding operations.
* Add some rounding tests.
Diffstat (limited to 'candle-kernels')
-rw-r--r-- | candle-kernels/src/cuda_utils.cuh | 12 | ||||
-rw-r--r-- | candle-kernels/src/unary.cu | 12 |
2 files changed, 24 insertions, 0 deletions
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 8e46a07c..b0a85249 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -131,6 +131,12 @@ __device__ __forceinline__ float tanhg(float a) { return tanhf(a); } __device__ __forceinline__ double tanhg(double a) { return tanh(a); } __device__ __forceinline__ float erfg(float a) { return erff(a); } __device__ __forceinline__ double erfg(double a) { return erf(a); } +__device__ __forceinline__ float ceilg(float a) { return ceilf(a); } +__device__ __forceinline__ double ceilg(double a) { return ceil(a); } +__device__ __forceinline__ float floorg(float a) { return floorf(a); } +__device__ __forceinline__ double floorg(double a) { return floor(a); } +__device__ __forceinline__ float roundg(float a) { return roundf(a); } +__device__ __forceinline__ double roundg(double a) { return round(a); } __device__ __forceinline__ float normcdfg(float a) { return normcdff(a); } __device__ __forceinline__ double normcdfg(double a) { return normcdf(a); } __device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); } @@ -162,6 +168,9 @@ __device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return on __device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); } __device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); } __device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); } +__device__ __forceinline__ __half ceilg(__half a) { return __float2half(ceilf(__half2float(a))); } +__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); } +__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); } __device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); } __device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); } __device__ __forceinline__ __half logg(__half a) { return hlog(a); } @@ -180,6 +189,9 @@ __device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 __device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); } __device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); } __device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 ceilg(__nv_bfloat16 a) { return __float2bfloat16(ceilf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 floorg(__nv_bfloat16 a) { return __float2bfloat16(floorf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 roundg(__nv_bfloat16 a) { return __float2bfloat16(roundf(__bfloat162float(a))); } __device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); } __device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); } __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); } diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index d65eac17..409a337d 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -92,6 +92,9 @@ UNARY_OP(__nv_bfloat16, usin_bf16, sing(x)) UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x)) UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x)) UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x)) +UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x)) +UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x)) +UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x)) UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x)) UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x)) UNARY_OP(__nv_bfloat16, usqr_bf16, x*x) @@ -113,6 +116,9 @@ UNARY_OP(__half, usin_f16, sing(x)) UNARY_OP(__half, ucos_f16, cosg(x)) UNARY_OP(__half, utanh_f16, tanhg(x)) UNARY_OP(__half, uerf_f16, erfg(x)) +UNARY_OP(__half, uceil_f16, ceilg(x)) +UNARY_OP(__half, ufloor_f16, floorg(x)) +UNARY_OP(__half, uround_f16, roundg(x)) UNARY_OP(__half, unormcdf_f16, normcdfg(x)) UNARY_OP(__half, uabs_f16, absg(x)) UNARY_OP(__half, usqr_f16, x*x) @@ -145,6 +151,12 @@ UNARY_OP(float, utanh_f32, tanhg(x)) UNARY_OP(double, utanh_f64, tanhg(x)) UNARY_OP(float, uerf_f32, erfg(x)) UNARY_OP(double, uerf_f64, erfg(x)) +UNARY_OP(float, uceil_f32, ceilg(x)) +UNARY_OP(double, uceil_f64, ceilg(x)) +UNARY_OP(float, ufloor_f32, floorg(x)) +UNARY_OP(double, ufloor_f64, floorg(x)) +UNARY_OP(float, uround_f32, roundg(x)) +UNARY_OP(double, uround_f64, roundg(x)) UNARY_OP(float, unormcdf_f32, normcdfg(x)) UNARY_OP(double, unormcdf_f64, normcdfg(x)) UNARY_OP(float, uabs_f32, absg(x)) |