diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-07-24 15:29:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-24 16:29:56 +0200 |
commit | ddafc61055601002622778b7762c15bd60057c1f (patch) | |
tree | 5363cf002fb93d9c0368140c1775721aa06d98bd /candle-metal-kernels/src/utils.rs | |
parent | a925ae6bc659d1b40570b5068b6913d38e75b12e (diff) | |
download | candle-ddafc61055601002622778b7762c15bd60057c1f.tar.gz candle-ddafc61055601002622778b7762c15bd60057c1f.tar.bz2 candle-ddafc61055601002622778b7762c15bd60057c1f.zip |
Use RAII for terminating the encoding. (#2353)
Diffstat (limited to 'candle-metal-kernels/src/utils.rs')
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 40 |
1 files changed, 28 insertions, 12 deletions
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 4ef2162c..b42bcff0 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -162,24 +162,40 @@ macro_rules! set_params { } pub trait EncoderProvider { - fn encoder(&self) -> &ComputeCommandEncoderRef; - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef); + type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a>; } -impl EncoderProvider for &metal::CommandBuffer { - fn encoder(&self) -> &ComputeCommandEncoderRef { - self.new_compute_command_encoder() +pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef); + +impl<'a> Drop for WrappedEncoder<'a> { + fn drop(&mut self) { + self.0.end_encoding() } - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { - enc.end_encoding() +} + +impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> { + fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { + &self.0 } } -impl EncoderProvider for &metal::CommandBufferRef { - fn encoder(&self) -> &ComputeCommandEncoderRef { - self.new_compute_command_encoder() +impl EncoderProvider for &metal::CommandBuffer { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a> { + WrappedEncoder(self.new_compute_command_encoder()) } - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { - enc.end_encoding() +} + +impl EncoderProvider for &metal::CommandBufferRef { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a> { + WrappedEncoder(self.new_compute_command_encoder()) } } |