summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized/metal.rs
diff options
context:
space:
mode:
authorivarflakstad <69173633+ivarflakstad@users.noreply.github.com>2024-03-07 09:42:34 +0100
committerGitHub <noreply@github.com>2024-03-07 09:42:34 +0100
commit0c09d10f320df7c23fc231f7f400967b03d9b9da (patch)
tree194297a74d4f1cf695a32a9bee6c518e607adfe1 /candle-core/src/quantized/metal.rs
parent8a99cf7dd2e0d2ff9cb18232272dad380d887f2d (diff)
downloadcandle-0c09d10f320df7c23fc231f7f400967b03d9b9da.tar.gz
candle-0c09d10f320df7c23fc231f7f400967b03d9b9da.tar.bz2
candle-0c09d10f320df7c23fc231f7f400967b03d9b9da.zip
Improve metal buffer usage (#1807)
* Improve metal buffer usage * Clone cpu storage when loading to reduce wait_until_complete calls * Use powers of two for buffer sizes so reuse is more likely. * Select best available buffer by size. * Add count to MetalStorage -> can use buffer with different size Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co> * Simplify new buffer creation without blit copy. Revert &[] -> Vec * Add documentation on newBufferWithBytes safety / synchronization * Drop unused buffers after command buffer is done syncing. --------- Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>
Diffstat (limited to 'candle-core/src/quantized/metal.rs')
-rw-r--r--candle-core/src/quantized/metal.rs9
1 files changed, 7 insertions, 2 deletions
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index af1cf369..7be0f74e 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -106,7 +106,12 @@ impl QMetalStorage {
}
let buffer = self.device.new_buffer_with_data(&out)?;
- Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+ Ok(MetalStorage::new(
+ buffer,
+ self.device.clone(),
+ elem_count,
+ DType::F32,
+ ))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
@@ -170,7 +175,7 @@ impl QMetalStorage {
&dst,
)
.map_err(MetalError::from)?;
- let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+ let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
}