diff options
-rw-r--r-- | candle-core/src/metal_backend.rs  | 18
-rw-r--r-- | candle-metal-kernels/src/tests.rs |  1
-rw-r--r-- | candle-nn/src/ops.rs              |  6
3 files changed, 11 insertions, 14 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0af11a3d..27b2824f 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -59,6 +59,8 @@ impl From<String> for MetalError {
     }
 }
 
+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+
 #[derive(Clone)]
 pub struct MetalDevice {
     /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
@@ -103,7 +105,7 @@ pub struct MetalDevice {
     ///
     /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
-    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
+    buffers: AllocatedBuffers,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -258,7 +260,7 @@ impl MetalDevice {
                 let newbuffers = subbuffers
                     .iter()
                     .filter(|s| Arc::strong_count(s) > 1)
-                    .map(|s| Arc::clone(s))
+                    .map(Arc::clone)
                     .collect();
                 *subbuffers = newbuffers;
             }
@@ -270,7 +272,7 @@ impl MetalDevice {
         let capture = metal::CaptureManager::shared();
         let descriptor = metal::CaptureDescriptor::new();
         descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(&self);
+        descriptor.set_capture_device(self);
         descriptor.set_output_url(path);
 
         capture
@@ -1021,10 +1023,10 @@ impl BackendStorage for MetalStorage {
                 &self.device.kernels,
                 name,
                 (b, m, n, k),
-                &lhs_l.stride(),
+                lhs_l.stride(),
                 lhs_l.start_offset() * self.dtype.size_in_bytes(),
                 &self.buffer,
-                &rhs_l.stride(),
+                rhs_l.stride(),
                 rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
                 &rhs.buffer,
                 &buffer,
@@ -1274,11 +1276,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
         }?;
-        Ok(Self::Storage::new(
-            buffer.into(),
-            self.clone(),
-            storage.dtype(),
-        ))
+        Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
     }
 
     fn rand_uniform(
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8d5a2624..1b3153b1 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -574,7 +574,6 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
 
     let options = MTLResourceOptions::StorageModeManaged;
     let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
-    let num_dims = 1;
     let dims = vec![v.len()];
     let strides = vec![1];
     call_reduce_strided(
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 816eff42..abe33350 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -226,17 +226,17 @@ impl candle::CustomOp1 for SoftmaxLastDim {
 
         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,
-            &kernels,
+            kernels,
             name,
             elem_count,
             last_dim,
             storage.buffer(),
             layout.start_offset() * storage.dtype().size_in_bytes(),
-            &mut output,
+            &output,
         )
         .unwrap();
         let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());