diff options
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 2ddd610b..d2cc09f4 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -168,17 +168,22 @@ pub trait EncoderProvider { fn encoder(&self) -> Self::Encoder<'_>; } -pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef); +pub struct WrappedEncoder<'a> { + inner: &'a ComputeCommandEncoderRef, + end_encoding_on_drop: bool, +} impl<'a> Drop for WrappedEncoder<'a> { fn drop(&mut self) { - self.0.end_encoding() + if self.end_encoding_on_drop { + self.inner.end_encoding() + } } } impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> { fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { - self.0 + self.inner } } @@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer { where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { - WrappedEncoder(self.new_compute_command_encoder()) + WrappedEncoder { + inner: self.new_compute_command_encoder(), + end_encoding_on_drop: true, + } } } @@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef { where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { - WrappedEncoder(self.new_compute_command_encoder()) + WrappedEncoder { + inner: self.new_compute_command_encoder(), + end_encoding_on_drop: true, + } + } +} + +impl EncoderProvider for &ComputeCommandEncoderRef { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder(&self) -> Self::Encoder<'_> { + WrappedEncoder { + inner: self, + end_encoding_on_drop: false, + } } } |