diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/var_builder.rs | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index b02d216b..1466f6d0 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,6 +1,5 @@ use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; -use safetensors::slice::IndexOp; -use safetensors::tensor::SafeTensors; +use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; use std::sync::Arc; @@ -70,7 +69,7 @@ impl<'a> TensorData<'a> { #[derive(Clone)] pub struct VarBuilder<'a> { data: Arc<TensorData<'a>>, - pub path: Vec<String>, + path: Vec<String>, } impl<'a> VarBuilder<'a> { @@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> { shape[dim] = block_size; - Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)? + let dtype: DType = dtype.try_into()?; + + let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); + Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)? } _ => unimplemented!(), }; |