author     Nicolas Patry <patry.nicolas@protonmail.com>   2023-07-26 11:16:04 +0200
committer  Nicolas Patry <patry.nicolas@protonmail.com>   2023-07-27 09:58:47 +0200
commit     7c7e6ba201d0270f5ac689c20f16f59e00ed4d01 (patch)
tree       134efa4d30f7dcaac74b1ec858692c50671a5953
parent     1553b58fe59a29fe808b9b4d43a6502046ce26dd (diff)
Removing inner dependency on safetensors.
-rw-r--r--  candle-core/src/safetensors.rs                        27
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs  23
-rw-r--r--  candle-nn/src/var_builder.rs                          10
-rw-r--r--  candle-wasm-examples/whisper/Cargo.toml                2
4 files changed, 30 insertions(+), 32 deletions(-)
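
Overview: this commit removes safetensors-specific types from candle-core's public tensor API. `Tensor::from_safetensors_slice`, which took a `safetensors::slice::SliceIterator` and an `st::Dtype`, becomes `Tensor::from_raw_buffer`, which takes plain bytes and candle's own `DType`; the work of flattening the slice iterator moves to the caller in candle-nn's `VarBuilder`. Independently, the multiprocess llama example swaps `Rc<Comm>` for `Arc<Comm>`, and the whisper wasm example's Cargo.toml reclassifies `safetensors` as an app-level dependency.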
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index e81fe184..dee57b37 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -1,7 +1,6 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType};
-use safetensors::slice::SliceIterator;
use safetensors::tensor as st;
-use safetensors::tensor::{Dtype, SafeTensors};
+use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
impl From<DType> for st::Dtype {
@@ -118,26 +117,24 @@ impl<'a> Load for st::TensorView<'a> {
}
impl Tensor {
- pub fn from_safetensors_slice(
- iterator: SliceIterator,
- dtype: Dtype,
+ pub fn from_raw_buffer(
+ data: &[u8],
+ dtype: DType,
shape: &[usize],
device: &Device,
) -> Result<Self> {
- let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
match dtype {
- st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
- st::Dtype::U32 => convert_slice::<u8>(&data, shape, device),
- st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
- st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
- st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
- st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
- dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
+ DType::U8 => convert_slice::<u8>(data, shape, device),
+ DType::U32 => convert_slice::<u32>(data, shape, device),
+ DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
+ DType::F16 => convert_slice::<half::f16>(data, shape, device),
+ DType::F32 => convert_slice::<f32>(data, shape, device),
+ DType::F64 => convert_slice::<f64>(data, shape, device),
}
}
}
-pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
@@ -149,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
}
}
-pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
+fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
// TODO: This makes an unnecessary copy when the tensor is on the cpu.
let tensor = tensor.flatten_all()?;
match tensor.dtype() {
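
Note on this file: the new entry point takes raw little-endian bytes plus candle's own `DType`, so callers no longer need safetensors types at all. The rewrite also fixes the old arm that decoded U32 data through `convert_slice::<u8>` (the unchanged `convert` function above still shows that pattern in its U32 arm). A minimal usage sketch, assuming this commit's `candle` API and a CPU device:

    use candle::{DType, Device, Tensor};

    fn demo() -> candle::Result<Tensor> {
        // Encode four f32 values as little-endian bytes; from_raw_buffer
        // reinterprets the buffer according to the requested DType.
        let bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        // The shape [2, 2] must match the element count implied by the dtype.
        Tensor::from_raw_buffer(&bytes, DType::F32, &[2, 2], &Device::Cpu)
    }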
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index e902734f..becaa879 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::collections::HashMap;
-use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
struct TensorParallelRowLinear {
linear: Linear,
- comm: Rc<Comm>,
+ comm: Arc<Comm>,
}
struct AllReduce {
- comm: Rc<Comm>,
+ comm: Arc<Comm>,
}
impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
}
}
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
x.custom_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
- fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+ fn new(linear: Linear, comm: Arc<Comm>) -> Self {
Self { linear, comm }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
}
impl TensorParallelColumnLinear {
- fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+ fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 0, rank, size)?;
Ok(Self::new(Linear::new(weight, None)))
}
- fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+ fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
}
impl TensorParallelRowLinear {
- fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+ fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
}
}
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
let qkv_proj = TensorParallelColumnLinear::load_multi(
vb.clone(),
&["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
self.c_proj.forward(&x)
}
- fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+ fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
Ok(x)
}
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
logits.to_dtype(DType::F32)
}
- pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+ pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
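
Note on this file: the `Rc<Comm>` to `Arc<Comm>` swap matters because `Rc` is `!Send`, so a struct holding one cannot cross a thread boundary, which this multiprocess example (and, plausibly, the bounds on candle's `CustomOp1`) requires; `Arc` uses atomic reference counting and is `Send + Sync` whenever its payload is. A minimal sketch with a placeholder type standing in for `cudarc::nccl::safe::Comm`:

    use std::sync::Arc;
    use std::thread;

    struct Comm; // placeholder; the real type is cudarc::nccl::safe::Comm

    fn main() {
        let comm = Arc::new(Comm);
        let worker = {
            let comm = Arc::clone(&comm); // cheap atomic refcount bump
            thread::spawn(move || {
                let _shared: &Comm = &comm; // Arc<Comm> is Send, so this compiles
            })
        };
        worker.join().unwrap();
        // The same code with std::rc::Rc fails to compile:
        // error[E0277]: `Rc<Comm>` cannot be sent between threads safely
    }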
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index b02d216b..1466f6d0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,6 +1,5 @@
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
-use safetensors::slice::IndexOp;
-use safetensors::tensor::SafeTensors;
+use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
use std::sync::Arc;
@@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
#[derive(Clone)]
pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>,
- pub path: Vec<String>,
+ path: Vec<String>,
}
impl<'a> VarBuilder<'a> {
@@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
shape[dim] = block_size;
- Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+ let dtype: DType = dtype.try_into()?;
+
+ let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+ Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
_ => unimplemented!(),
};
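
Note on this file: the safetensors-specific pieces now stay local to candle-nn. The `SliceIterator` is flattened into contiguous bytes, and the `st::Dtype` is converted to candle's `DType` via `try_into()`, which presumably relies on a `TryFrom<st::Dtype> for DType` impl in candle-core. A sketch of what that conversion looks like, using only the variants and the `Error::UnsupportedSafeTensorDtype` case visible in the diff above:

    use candle::{DType, Error, Result};
    use safetensors::tensor::Dtype;

    // Mirrors the dtype mapping in candle-core/src/safetensors.rs.
    fn to_candle_dtype(dtype: Dtype) -> Result<DType> {
        match dtype {
            Dtype::U8 => Ok(DType::U8),
            Dtype::U32 => Ok(DType::U32),
            Dtype::BF16 => Ok(DType::BF16),
            Dtype::F16 => Ok(DType::F16),
            Dtype::F32 => Ok(DType::F32),
            Dtype::F64 => Ok(DType::F64),
            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
        }
    }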
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 4ebb2788..b51d4052 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -15,7 +15,6 @@ candle = { path = "../../candle-core" }
candle-nn = { path = "../../candle-nn" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
-safetensors = { workspace = true }
# App crates.
anyhow = { workspace = true }
@@ -24,6 +23,7 @@ rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
wav = { workspace = true }
+safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
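
Note on this file: the whisper example still depends on `safetensors` (it reads model files itself), so the Cargo.toml change is a reclassification rather than a removal. The crate moves from the candle-adjacent dependencies down into the "# App crates." group, making clear it is an application-level dependency rather than something the candle crates pull in.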