summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/utils.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-07-24 15:29:56 +0100
committerGitHub <noreply@github.com>2024-07-24 16:29:56 +0200
commitddafc61055601002622778b7762c15bd60057c1f (patch)
tree5363cf002fb93d9c0368140c1775721aa06d98bd /candle-metal-kernels/src/utils.rs
parenta925ae6bc659d1b40570b5068b6913d38e75b12e (diff)
downloadcandle-ddafc61055601002622778b7762c15bd60057c1f.tar.gz
candle-ddafc61055601002622778b7762c15bd60057c1f.tar.bz2
candle-ddafc61055601002622778b7762c15bd60057c1f.zip
Use RAII for terminating the encoding. (#2353)
Diffstat (limited to 'candle-metal-kernels/src/utils.rs')
-rw-r--r--candle-metal-kernels/src/utils.rs40
1 files changed, 28 insertions, 12 deletions
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 4ef2162c..b42bcff0 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -162,24 +162,40 @@ macro_rules! set_params {
}
pub trait EncoderProvider {
- fn encoder(&self) -> &ComputeCommandEncoderRef;
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
+ type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
}
-impl EncoderProvider for &metal::CommandBuffer {
- fn encoder(&self) -> &ComputeCommandEncoderRef {
- self.new_compute_command_encoder()
+pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+
+impl<'a> Drop for WrappedEncoder<'a> {
+ fn drop(&mut self) {
+ self.0.end_encoding()
}
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
- enc.end_encoding()
+}
+
+impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
+ fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
+ &self.0
}
}
-impl EncoderProvider for &metal::CommandBufferRef {
- fn encoder(&self) -> &ComputeCommandEncoderRef {
- self.new_compute_command_encoder()
+impl EncoderProvider for &metal::CommandBuffer {
+ type Encoder<'a> = WrappedEncoder<'a>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+ WrappedEncoder(self.new_compute_command_encoder())
}
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
- enc.end_encoding()
+}
+
+impl EncoderProvider for &metal::CommandBufferRef {
+ type Encoder<'a> = WrappedEncoder<'a>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+ WrappedEncoder(self.new_compute_command_encoder())
}
}