summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/llama/weights.rs2
-rw-r--r--candle-core/src/safetensors.rs14
2 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index cc3fccd4..c3364cef 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -105,7 +105,7 @@ impl Llama {
) -> Result<Self> {
let handles: Vec<_> = filenames
.iter()
- .map(candle::safetensors::MmapedFile::new)
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index b80a756a..99e11c60 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -12,6 +12,10 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
Tensor::from_slice(data, view.shape(), device)
} else {
let mut c = Vec::with_capacity(elem_count);
+ // SAFETY: We just created c, so the allocated memory is necessarily
+ // contiguous and non overlapping with the view's data.
+ // We're downgrading the `c` pointer from T to u8, which removes alignment
+ // constraints.
unsafe {
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
c.set_len(elem_count)
@@ -42,9 +46,15 @@ pub struct SafeTensors<'a>(st::SafeTensors<'a>);
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
- pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ /// Creates a wrapper around a memory mapped file from which you can retrieve
+ /// tensors using [`MmapedFile::deserialize`]
+ ///
+ /// # Safety
+ ///
+ /// The unsafe is inherited from [`memmap2::MmapOptions`].
+ pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?;
- let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+ let mmap = memmap2::MmapOptions::new().map(&file)?;
Ok(Self(mmap))
}