summaryrefslogtreecommitdiff
path: root/candle-core/src/custom_op.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/custom_op.rs')
-rw-r--r--candle-core/src/custom_op.rs67
1 files changed, 67 insertions, 0 deletions
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index 3a85dba9..276e3658 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -375,3 +375,70 @@ impl Tensor {
)
}
}
+
+pub struct UgIOp1 {
+ name: &'static str,
+ #[cfg(feature = "cuda")]
+ func: cudarc::driver::CudaFunction,
+}
+
+impl UgIOp1 {
+ #[allow(unused)]
+ pub fn new(
+ name: &'static str,
+ kernel: ug::lang::ssa::Kernel,
+ device: &crate::Device,
+ ) -> Result<Self> {
+ #[cfg(feature = "cuda")]
+ {
+ let device = device.as_cuda_device()?;
+ let func = device.compile(name, kernel)?;
+ Ok(Self { name, func })
+ }
+ #[cfg(not(feature = "cuda"))]
+ {
+ Ok(Self { name })
+ }
+ }
+}
+
+impl InplaceOp1 for UgIOp1 {
+ fn name(&self) -> &'static str {
+ self.name
+ }
+
+ fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
+ crate::bail!("ug ops are only supported on cuda at the moment")
+ }
+
+ fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
+ crate::bail!("ug ops are only supported on cuda at the moment")
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
+ use crate::cuda_backend::WrapErr;
+ use cudarc::driver::LaunchAsync;
+
+ let elem_count = layout.shape().elem_count();
+ // TODO: support more dtypes.
+ let sto = sto.as_cuda_slice::<f32>()?;
+ let sto = match layout.contiguous_offsets() {
+ None => crate::bail!("input has to be contiguous"),
+ Some((o1, o2)) => sto.slice(o1..o2),
+ };
+ let params = (&sto,);
+ let (g, b) = if elem_count % 32 == 0 {
+ (elem_count / 32, 32)
+ } else {
+ (elem_count, 1)
+ };
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (g as u32, 1, 1),
+ block_dim: (b as u32, 1, 1),
+ shared_mem_bytes: 0,
+ };
+ unsafe { self.func.clone().launch(cfg, params) }.w()?;
+ Ok(())
+ }
+}