author    MilkFather <31627231+MilkFather@users.noreply.github.com>    2024-04-29 17:04:43 +0800
committer GitHub <noreply@github.com>                                  2024-04-29 11:04:43 +0200
commit    3bbb88fcb463a6bdbb0e71c7b2d211dd02681493
tree      6077a6f41f7b1ed97b6f44d5e8305126c0d5f5a5 /candle-nn
parent    ed7b99f525ab898aa677fe1f4446e345ac74f4ec
Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op
* small fix
* add as a method on `Tensor`
* implement gradient calculation for sigmoid
* add sigmoid tests
* we should have a specialized op for this
* fix clippy
* fix clippy 2
* Revert all previous commits in favor of a `CustomOp` based solution
* use `CustomOp1` implementation
* fix rustfmt
* experimental add metal impl
* add cuda kernel impl
* fix fmt
* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/ops.rs     188
-rw-r--r--  candle-nn/tests/ops.rs    11
2 files changed, 197 insertions, 2 deletions
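
The change routes `candle_nn::ops::sigmoid` through a dedicated `CustomOp1` whose `bwd` returns the analytic gradient sigmoid(x) * (1 - sigmoid(x)) computed from the stored forward result, instead of backpropagating through the previous neg/exp/recip chain. A minimal out-of-tree sketch of how the op and its gradient can be exercised, assuming `candle-core` and `candle-nn` as dependencies (the input values and tolerance below are illustrative, not taken from the commit):

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Track the input as a Var so that backward() records a gradient for it.
    let x = Var::new(&[[3f32, 1., 4.], [1., 5., 9.]], &dev)?;

    // Forward pass goes through the new specialized op via the public helper.
    let y = candle_nn::ops::sigmoid(x.as_tensor())?;

    // Backward pass exercises the fixed gradient: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
    let grads = y.sum_all()?.backward()?;
    let grad_x = grads.get(x.as_tensor()).expect("no gradient for x");

    // Reference gradient rebuilt from the forward output, mirroring the `bwd` implementation.
    let reference = y.ones_like()?.sub(&y)?.mul(&y)?;
    let diff = grad_x.sub(&reference)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-6);
    Ok(())
}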
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 7fc26c3f..eabc95d8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
&xs[0].silu()? * &xs[1]
}
+struct Sigmoid;
+
+impl candle::CustomOp1 for Sigmoid {
+ fn name(&self) -> &'static str {
+ "sigmoid"
+ }
+
+ fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+ use candle::backend::BackendStorage;
+
+ fn fwd<T: num_traits::Float>(v: T) -> T {
+ (v.neg().exp() + T::one()).recip()
+ }
+
+ // FIXME: using `candle::map_dtype` causes compilation errors.
+ let storage = match storage {
+ CpuStorage::BF16(slice) => {
+ CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
+ }
+ CpuStorage::F16(slice) => {
+ CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
+ }
+ CpuStorage::F32(slice) => {
+ CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
+ }
+ CpuStorage::F64(slice) => {
+ CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
+ }
+ _ => Err(candle::Error::UnsupportedDTypeForOp(
+ storage.dtype(),
+ self.name(),
+ ))?,
+ };
+ Ok((storage, layout.shape().clone()))
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ storage: &candle::CudaStorage,
+ layout: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::backend::BackendStorage;
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+ };
+ use candle::cuda_backend::SlicePtrOrNull;
+ use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ struct S;
+ impl Map1 for S {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el_count as u32);
+ let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+
+ let params = (el_count, dims.len(), &ds, src, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+ }
+
+ let dev = storage.device();
+ let slice = S.map(&storage.slice, dev, layout)?;
+ let dst = candle::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, layout.shape().clone()))
+ }
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::backend::BackendStorage;
+ use candle::MetalError;
+ let device = storage.device();
+ let dtype = storage.dtype();
+ let shape = layout.shape();
+ let el_count = shape.elem_count();
+ let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
+ let command_buffer = device.command_buffer()?;
+ command_buffer.set_label("sigmoid");
+ let src = candle_metal_kernels::BufferOffset {
+ buffer: storage.buffer(),
+ offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
+ };
+
+ match (el_count % 2, dtype, layout.is_contiguous()) {
+ (0, DType::BF16 | DType::F16, true) => {
+ use candle_metal_kernels::unary::contiguous_tiled;
+ let kernel_name = match dtype {
+ DType::F16 => contiguous_tiled::sigmoid::HALF,
+ DType::F32 => contiguous_tiled::sigmoid::FLOAT,
+ DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
+ dtype => {
+ candle::bail!(
+ "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
+ )
+ }
+ };
+ candle_metal_kernels::call_unary_contiguous_tiled(
+ device.metal_device(),
+ &command_buffer,
+ device.kernels(),
+ kernel_name,
+ el_count,
+ src,
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ }
+ (_, _, true) => {
+ use candle_metal_kernels::unary::contiguous;
+ let kernel_name = match dtype {
+ DType::F16 => contiguous::sigmoid::HALF,
+ DType::F32 => contiguous::sigmoid::FLOAT,
+ DType::BF16 => contiguous::sigmoid::BFLOAT,
+ dtype => {
+ candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
+ }
+ };
+ candle_metal_kernels::call_unary_contiguous(
+ device.metal_device(),
+ &command_buffer,
+ device.kernels(),
+ kernel_name,
+ el_count,
+ src,
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ }
+ (_, _, false) => {
+ use candle_metal_kernels::unary::strided;
+ let kernel_name = match dtype {
+ DType::F16 => strided::sigmoid::HALF,
+ DType::F32 => strided::sigmoid::FLOAT,
+ DType::BF16 => strided::sigmoid::BFLOAT,
+ dtype => {
+ candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
+ }
+ };
+ let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
+ candle_metal_kernels::call_unary_strided(
+ device.metal_device(),
+ &command_buffer,
+ device.kernels(),
+ kernel_name,
+ layout.dims(),
+ src,
+ layout.stride(),
+ dst,
+ )
+ .map_err(MetalError::from)?;
+ }
+ }
+
+ let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
+ Ok((new_storage, layout.shape().clone()))
+ }
+
+ fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
+ // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
+ let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
+ Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
+ }
+}
+
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
- // TODO: Should we have a specialized op for this?
- (xs.neg()?.exp()? + 1.0)?.recip()
+ xs.apply_op1(Sigmoid)
}
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
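
The `bwd` implementation above relies on the identity d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), which avoids re-evaluating the exponential during backpropagation. A standalone plain-Rust sanity check of that identity against a central finite difference of the scalar map used in `cpu_fwd` (illustration only, not code from the commit):

// Scalar sigmoid mirroring the `fwd` helper in `cpu_fwd` above.
fn sigmoid(v: f64) -> f64 {
    ((-v).exp() + 1.0).recip()
}

fn main() {
    let eps = 1e-6;
    for &x in &[-4.0f64, -1.0, 0.0, 0.5, 3.0] {
        // Analytic gradient, as returned by `bwd`.
        let analytic = sigmoid(x) * (1.0 - sigmoid(x));
        // Central finite-difference approximation of the derivative.
        let numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2.0 * eps);
        assert!((analytic - numeric).abs() < 1e-8);
    }
}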
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 24a49d06..f9cfe46d 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -170,8 +170,19 @@ fn rope_thd(device: &Device) -> Result<()> {
Ok(())
}
+fn sigmoid(device: &Device) -> Result<()> {
+ let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+ let tensor = Tensor::new(data, device)?;
+ let s1 = candle_nn::ops::sigmoid(&tensor)?;
+ let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
+ let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ assert_eq!(diff, 0.);
+ Ok(())
+}
+
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
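
The `test_device!` macro expands the new `sigmoid` test into CPU, CUDA, and Metal variants; a typical invocation such as `cargo test -p candle-nn sigmoid` runs the CPU case, with the GPU variants gated behind the crate's `cuda` and `metal` features (command shown as a usual way to run the suite, not taken from the commit).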