diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-05 07:22:46 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-05 07:22:46 +0200 |
commit | 01794dc16ef8d896933d61e9bd9b8a981cd51930 (patch) | |
tree | bdf97cc5675952720a4348340b5d3c895381fd76 /candle-core | |
parent | a75cd8164fb0b8377d101fc8526783d4abd18f12 (diff) | |
download | candle-01794dc16ef8d896933d61e9bd9b8a981cd51930.tar.gz candle-01794dc16ef8d896933d61e9bd9b8a981cd51930.tar.bz2 candle-01794dc16ef8d896933d61e9bd9b8a981cd51930.zip |
Use write rather than try-write on the metal rw-locks. (#2162)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/metal_backend/device.rs | 12 | ||||
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 8 |
2 files changed, 13 insertions, 7 deletions
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 44af7649..785fe621 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -100,11 +100,11 @@ impl MetalDevice { } pub fn command_buffer(&self) -> Result<CommandBuffer> { - let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?; let mut command_buffer = command_buffer_lock.to_owned(); let mut index = self .command_buffer_index - .try_write() + .write() .map_err(MetalError::from)?; if *index > self.compute_per_buffer { command_buffer.commit(); @@ -119,7 +119,7 @@ impl MetalDevice { } pub fn wait_until_completed(&self) -> Result<()> { - let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?; match command_buffer.status() { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled @@ -179,7 +179,7 @@ impl MetalDevice { size, MTLResourceOptions::StorageModeManaged, ); - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; let subbuffers = buffers .entry((size, MTLResourceOptions::StorageModeManaged)) .or_insert(vec![]); @@ -232,7 +232,7 @@ impl MetalDevice { } fn drop_unused_buffers(&self) -> Result<()> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; for subbuffers in buffers.values_mut() { let newbuffers = subbuffers .iter() @@ -251,7 +251,7 @@ impl MetalDevice { option: MTLResourceOptions, _name: &str, ) -> Result<Arc<Buffer>> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; if let Some(b) = self.find_available_buffer(size, option, &buffers) { // Cloning also ensures we increment the strong count return Ok(b.clone()); diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index e00566ca..9273eda8 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; -use std::sync::{Arc, Mutex, RwLock, TryLockError}; +use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError}; mod device; pub use device::{DeviceId, MetalDevice}; @@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError { } } +impl<T> From<PoisonError<T>> for MetalError { + fn from(p: PoisonError<T>) -> Self { + MetalError::LockError(LockError::Poisoned(p.to_string())) + } +} + /// Metal related errors #[derive(thiserror::Error, Debug)] pub enum MetalError { |