diff options
-rw-r--r-- | candle-core/examples/llama/weights.rs | 2 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 14 |
2 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index cc3fccd4..c3364cef 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -105,7 +105,7 @@ impl Llama { ) -> Result<Self> { let handles: Vec<_> = filenames .iter() - .map(candle::safetensors::MmapedFile::new) + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) .collect::<Result<Vec<_>>>()?; let tensors: Vec<_> = handles .iter() diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b80a756a..99e11c60 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -12,6 +12,10 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T Tensor::from_slice(data, view.shape(), device) } else { let mut c = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. unsafe { std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); c.set_len(elem_count) @@ -42,9 +46,15 @@ pub struct SafeTensors<'a>(st::SafeTensors<'a>); pub struct MmapedFile(memmap2::Mmap); impl MmapedFile { - pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + /// Creates a wrapper around a memory mapped file from which you can retrieve + /// tensors using [`MmapedFile::deserialize`] + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { let file = std::fs::File::open(p)?; - let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; + let mmap = memmap2::MmapOptions::new().map(&file)?; Ok(Self(mmap)) } |