diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-03 12:00:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-03 12:00:35 +0200 |
commit | d0d530dfdce04d5fb656b10b4eb1bfd26dea37e8 (patch) | |
tree | 0d12f1d27dbb0996def62dc316415b5f9a760c8a | |
parent | 48089005f6585c2f603f0d930a4a94b5aa18402f (diff) | |
parent | 81cec86e758390e5f025a6e93888673b003fb4c8 (diff) | |
download | candle-d0d530dfdce04d5fb656b10b4eb1bfd26dea37e8.tar.gz candle-d0d530dfdce04d5fb656b10b4eb1bfd26dea37e8.tar.bz2 candle-d0d530dfdce04d5fb656b10b4eb1bfd26dea37e8.zip |
Merge pull request #59 from LaurentMazare/safety
Adding a bit more docs around safety.
-rw-r--r-- | candle-core/examples/llama/weights.rs | 2 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 14 |
2 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index cc3fccd4..c3364cef 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -105,7 +105,7 @@ impl Llama { ) -> Result<Self> { let handles: Vec<_> = filenames .iter() - .map(candle::safetensors::MmapedFile::new) + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) .collect::<Result<Vec<_>>>()?; let tensors: Vec<_> = handles .iter() diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b80a756a..99e11c60 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -12,6 +12,10 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T Tensor::from_slice(data, view.shape(), device) } else { let mut c = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. unsafe { std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); c.set_len(elem_count) @@ -42,9 +46,15 @@ pub struct SafeTensors<'a>(st::SafeTensors<'a>); pub struct MmapedFile(memmap2::Mmap); impl MmapedFile { - pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + /// Creates a wrapper around a memory mapped file from which you can retrieve + /// tensors using [`MmapedFile::deserialize`] + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { let file = std::fs::File::open(p)?; - let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; + let mmap = memmap2::MmapOptions::new().map(&file)?; Ok(Self(mmap)) } |