summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/metal_backend.rs18
1 files changed, 8 insertions, 10 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0af11a3d..27b2824f 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -59,6 +59,8 @@ impl From<String> for MetalError {
}
}
+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+
#[derive(Clone)]
pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
@@ -103,7 +105,7 @@ pub struct MetalDevice {
///
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1).
- buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
+ buffers: AllocatedBuffers,
}
impl std::fmt::Debug for MetalDevice {
@@ -258,7 +260,7 @@ impl MetalDevice {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(s) > 1)
- .map(|s| Arc::clone(s))
+ .map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
@@ -270,7 +272,7 @@ impl MetalDevice {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
- descriptor.set_capture_device(&self);
+ descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
@@ -1021,10 +1023,10 @@ impl BackendStorage for MetalStorage {
&self.device.kernels,
name,
(b, m, n, k),
- &lhs_l.stride(),
+ lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
- &rhs_l.stride(),
+ rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
@@ -1274,11 +1276,7 @@ impl BackendDevice for MetalDevice {
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
}?;
- Ok(Self::Storage::new(
- buffer.into(),
- self.clone(),
- storage.dtype(),
- ))
+ Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
}
fn rand_uniform(