summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-21 13:18:42 +0200
committerGitHub <noreply@github.com>2024-09-21 13:18:42 +0200
commitaf2104078f25865a02e18712986ee4b988d7affb (patch)
tree1d1e636792958e09139a19499780bc37602d056a /candle-metal-kernels/src
parent5fc4f177273c0c2435f9faeb6c3c1ea3d92bdf4e (diff)
downloadcandle-af2104078f25865a02e18712986ee4b988d7affb.tar.gz
candle-af2104078f25865a02e18712986ee4b988d7affb.tar.bz2
candle-af2104078f25865a02e18712986ee4b988d7affb.zip
Metal commands refactoring (#2489)
* Split out the commands part of the metal device. * Make most fields private. * Move the allocator back. * Rework the encoder provider type.
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r--candle-metal-kernels/src/utils.rs33
1 files changed, 28 insertions, 5 deletions
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 2ddd610b..d2cc09f4 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -168,17 +168,22 @@ pub trait EncoderProvider {
fn encoder(&self) -> Self::Encoder<'_>;
}
-pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+pub struct WrappedEncoder<'a> {
+ inner: &'a ComputeCommandEncoderRef,
+ end_encoding_on_drop: bool,
+}
impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) {
- self.0.end_encoding()
+ if self.end_encoding_on_drop {
+ self.inner.end_encoding()
+ }
}
}
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
- self.0
+ self.inner
}
}
@@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer {
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
- WrappedEncoder(self.new_compute_command_encoder())
+ WrappedEncoder {
+ inner: self.new_compute_command_encoder(),
+ end_encoding_on_drop: true,
+ }
}
}
@@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef {
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
- WrappedEncoder(self.new_compute_command_encoder())
+ WrappedEncoder {
+ inner: self.new_compute_command_encoder(),
+ end_encoding_on_drop: true,
+ }
+ }
+}
+
+impl EncoderProvider for &ComputeCommandEncoderRef {
+ type Encoder<'a> = WrappedEncoder<'a>
+ where
+ Self: 'a;
+ fn encoder(&self) -> Self::Encoder<'_> {
+ WrappedEncoder {
+ inner: self,
+ end_encoding_on_drop: false,
+ }
}
}