author | MilkFather <31627231+MilkFather@users.noreply.github.com> | 2024-04-29 17:04:43 +0800
committer | GitHub <noreply@github.com> | 2024-04-29 11:04:43 +0200
commit | 3bbb88fcb463a6bdbb0e71c7b2d211dd02681493
tree | 6077a6f41f7b1ed97b6f44d5e8305126c0d5f5a5 /candle-nn
parent | ed7b99f525ab898aa677fe1f4446e345ac74f4ec
Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op
* small fix
* add as a method on `Tensor`
* implement gradient calculation for sigmoid
* add sigmoid tests
* we should have a specialized op for this
* fix clippy
* fix clippy 2
* Revert all previous commits in favor of a `CustomOp` based solution
* use `CustomOp1` implementation
* fix rustfmt
* experimental add metal impl
* add cuda kernel impl
* fix fmt
* Add a test + reduce some cuda duplication.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
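
The gradient fix relies on the identity d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), which the new `bwd` implementation in the diff below computes from the forward result alone. As a hedged illustration only (not part of the patch; it uses candle's `Var`/`backward`/`GradStore` autograd API and a hypothetical `check_sigmoid_grad` helper name), a minimal sketch of checking that identity end to end:

```rust
use candle::{Device, Result, Var};

// Illustrative gradient check: compares the autograd gradient of
// candle_nn::ops::sigmoid against sigmoid(x) * (1 - sigmoid(x)).
fn check_sigmoid_grad() -> Result<()> {
    let x = Var::new(&[0.5f32, -1.0, 2.0], &Device::Cpu)?;
    let y = candle_nn::ops::sigmoid(x.as_tensor())?;
    // Differentiate the sum of the outputs with respect to x.
    let grads = y.sum_all()?.backward()?;
    let dx = grads.get(&x).expect("no gradient for x");
    // Expected gradient from the identity above.
    let expected = (&y * (y.ones_like()? - &y)?)?;
    let diff = (dx - &expected)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-6);
    Ok(())
}
```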
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/ops.rs | 188
-rw-r--r-- | candle-nn/tests/ops.rs | 11
2 files changed, 197 insertions, 2 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 7fc26c3f..eabc95d8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
     &xs[0].silu()? * &xs[1]
 }
 
+struct Sigmoid;
+
+impl candle::CustomOp1 for Sigmoid {
+    fn name(&self) -> &'static str {
+        "sigmoid"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        fn fwd<T: num_traits::Float>(v: T) -> T {
+            (v.neg().exp() + T::one()).recip()
+        }
+
+        // FIXME: using `candle::map_dtype` causes compilation errors.
+        let storage = match storage {
+            CpuStorage::BF16(slice) => {
+                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F16(slice) => {
+                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F32(slice) => {
+                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F64(slice) => {
+                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            _ => Err(candle::Error::UnsupportedDTypeForOp(
+                storage.dtype(),
+                self.name(),
+            ))?,
+        };
+        Ok((storage, layout.shape().clone()))
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        };
+        use candle::cuda_backend::SlicePtrOrNull;
+        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S;
+        impl Map1 for S {
+            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &Layout,
+            ) -> Result<CudaSlice<T>> {
+                let shape = layout.shape();
+                let dims = shape.dims();
+                let el_count = shape.elem_count();
+                let cfg = LaunchConfig::for_num_elems(el_count as u32);
+                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
+                let src = &src.slice(layout.start_offset()..);
+                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+
+                let params = (el_count, dims.len(), &ds, src, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(out)
+            }
+        }
+
+        let dev = storage.device();
+        let slice = S.map(&storage.slice, dev, layout)?;
+        let dst = candle::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
+    }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        use candle::MetalError;
+        let device = storage.device();
+        let dtype = storage.dtype();
+        let shape = layout.shape();
+        let el_count = shape.elem_count();
+        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
+        let command_buffer = device.command_buffer()?;
+        command_buffer.set_label("sigmoid");
+        let src = candle_metal_kernels::BufferOffset {
+            buffer: storage.buffer(),
+            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
+        };
+
+        match (el_count % 2, dtype, layout.is_contiguous()) {
+            (0, DType::BF16 | DType::F16, true) => {
+                use candle_metal_kernels::unary::contiguous_tiled;
+                let kernel_name = match dtype {
+                    DType::F16 => contiguous_tiled::sigmoid::HALF,
+                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
+                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!(
+                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
+                        )
+                    }
+                };
+                candle_metal_kernels::call_unary_contiguous_tiled(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    el_count,
+                    src,
+                    &buffer,
+                )
+                .map_err(MetalError::from)?;
+            }
+            (_, _, true) => {
+                use candle_metal_kernels::unary::contiguous;
+                let kernel_name = match dtype {
+                    DType::F16 => contiguous::sigmoid::HALF,
+                    DType::F32 => contiguous::sigmoid::FLOAT,
+                    DType::BF16 => contiguous::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
+                    }
+                };
+                candle_metal_kernels::call_unary_contiguous(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    el_count,
+                    src,
+                    &buffer,
+                )
+                .map_err(MetalError::from)?;
+            }
+            (_, _, false) => {
+                use candle_metal_kernels::unary::strided;
+                let kernel_name = match dtype {
+                    DType::F16 => strided::sigmoid::HALF,
+                    DType::F32 => strided::sigmoid::FLOAT,
+                    DType::BF16 => strided::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
+                    }
+                };
+                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
+                candle_metal_kernels::call_unary_strided(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    layout.dims(),
+                    src,
+                    layout.stride(),
+                    dst,
+                )
+                .map_err(MetalError::from)?;
+            }
+        }
+
+        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
+        Ok((new_storage, layout.shape().clone()))
+    }
+
+    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
+        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
+        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
+        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
+    }
+}
+
 pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
-    // TODO: Should we have a specialized op for this?
-    (xs.neg()?.exp()? + 1.0)?.recip()
+    xs.apply_op1(Sigmoid)
 }
 
 pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 24a49d06..f9cfe46d 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -170,8 +170,19 @@ fn rope_thd(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn sigmoid(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let s1 = candle_nn::ops::sigmoid(&tensor)?;
+    let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
+    let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
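
The public entry point is unchanged by this patch; only the implementation behind `candle_nn::ops::sigmoid` moves to the `CustomOp1`. A minimal, hedged usage sketch (tensor values and device choice are illustrative, not taken from the patch):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Same call as before the patch; it now dispatches to the specialized
    // Sigmoid CustomOp1 on CPU (and to CUDA/Metal kernels when those features are enabled).
    let device = Device::Cpu;
    let xs = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &device)?;
    let ys = candle_nn::ops::sigmoid(&xs)?;
    println!("{ys}");
    Ok(())
}
```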