summaryrefslogtreecommitdiff
path: root/candle-nn/src/var_builder.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-07-26 11:16:04 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-07-27 09:58:47 +0200
commit7c7e6ba201d0270f5ac689c20f16f59e00ed4d01 (patch)
tree134efa4d30f7dcaac74b1ec858692c50671a5953 /candle-nn/src/var_builder.rs
parent1553b58fe59a29fe808b9b4d43a6502046ce26dd (diff)
downloadcandle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.tar.gz
candle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.tar.bz2
candle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.zip
Removing inner dependency on safetensors.
Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r--candle-nn/src/var_builder.rs10
1 files changed, 6 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index b02d216b..1466f6d0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,6 +1,5 @@
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
-use safetensors::slice::IndexOp;
-use safetensors::tensor::SafeTensors;
+use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
use std::sync::Arc;
@@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
#[derive(Clone)]
pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>,
- pub path: Vec<String>,
+ path: Vec<String>,
}
impl<'a> VarBuilder<'a> {
@@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
shape[dim] = block_size;
- Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+ let dtype: DType = dtype.try_into()?;
+
+ let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+ Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
_ => unimplemented!(),
};