1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
|
//! VarBuilder for loading GGUF files.
//!
//! VarBuilder is a utility to store quantized tensors from a [GGUF model file](https://huggingface.co/docs/hub/gguf).
//! These tensors can be loaded from disk using `from_gguf` or from an in-memory
//! buffer using `from_gguf_buffer`.
use candle::quantized::QTensor;
use candle::{Device, Result, Shape};
use std::sync::Arc;
/// VarBuilder specialized for QTensors.
///
/// Holds a shared map of named quantized tensors together with a name prefix
/// and a device. Cloning is cheap with respect to the tensors: the map and
/// every tensor in it sit behind an `Arc`, so a clone shares them.
#[derive(Clone)]
pub struct VarBuilder {
// Tensor name -> quantized tensor; shared between all builders derived from
// the same loader call.
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
// Prefix segments accumulated through `pp`; they are joined with "." and
// prepended to the name when a tensor is looked up with `get`/`get_no_shape`.
path: Vec<String>,
// Device that was passed to the loader; returned by `device()`.
device: Device,
}
impl VarBuilder {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut file, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: device.clone(),
})
}
pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: device.clone(),
})
}
pub fn pp<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
device: self.device.clone(),
}
}
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
}
}
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {path}")
}
Some(qtensor) => {
let shape = s.into();
if qtensor.shape() != &shape {
candle::bail!(
"shape mismatch for {name}, got {:?}, expected {shape:?}",
qtensor.shape()
)
}
Ok(qtensor.clone())
}
}
}
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {name}")
}
Some(qtensor) => Ok(qtensor.clone()),
}
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn contains_key(&self, key: &str) -> bool {
self.data.contains_key(key)
}
}
|