summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-05 07:22:46 +0200
committerGitHub <noreply@github.com>2024-05-05 07:22:46 +0200
commit01794dc16ef8d896933d61e9bd9b8a981cd51930 (patch)
treebdf97cc5675952720a4348340b5d3c895381fd76 /candle-core
parenta75cd8164fb0b8377d101fc8526783d4abd18f12 (diff)
downloadcandle-01794dc16ef8d896933d61e9bd9b8a981cd51930.tar.gz
candle-01794dc16ef8d896933d61e9bd9b8a981cd51930.tar.bz2
candle-01794dc16ef8d896933d61e9bd9b8a981cd51930.zip
Use write rather than try-write on the metal rw-locks. (#2162)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/metal_backend/device.rs12
-rw-r--r--candle-core/src/metal_backend/mod.rs8
2 files changed, 13 insertions, 7 deletions
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 44af7649..785fe621 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -100,11 +100,11 @@ impl MetalDevice {
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
- let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
+ let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
- .try_write()
+ .write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
@@ -119,7 +119,7 @@ impl MetalDevice {
}
pub fn wait_until_completed(&self) -> Result<()> {
- let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
+ let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
@@ -179,7 +179,7 @@ impl MetalDevice {
size,
MTLResourceOptions::StorageModeManaged,
);
- let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+ let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@@ -232,7 +232,7 @@ impl MetalDevice {
}
fn drop_unused_buffers(&self) -> Result<()> {
- let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+ let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
@@ -251,7 +251,7 @@ impl MetalDevice {
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
- let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+ let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index e00566ca..9273eda8 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
@@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
}
}
+impl<T> From<PoisonError<T>> for MetalError {
+ fn from(p: PoisonError<T>) -> Self {
+ MetalError::LockError(LockError::Poisoned(p.to_string()))
+ }
+}
+
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {