//! Varbuilder for Loading gguf files //! //! VarBuilder is a utility to store quantized tensors from a [GGUF model file](https://huggingface.co/docs/hub/gguf). //! These tensors can be loaded from disk using `from_gguf` or from an in-memory //! buffer using `from_gguf_buffer`. use candle::quantized::QTensor; use candle::{Device, Result, Shape}; use std::sync::Arc; // VarBuilder specialized for QTensors #[derive(Clone)] pub struct VarBuilder { data: Arc>>, path: Vec, device: Device, } impl VarBuilder { pub fn from_gguf>(p: P, device: &Device) -> Result { let mut file = std::fs::File::open(p)?; let content = candle::quantized::gguf_file::Content::read(&mut file)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { let tensor = content.tensor(&mut file, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self { data: Arc::new(data), path: Vec::new(), device: device.clone(), }) } pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result { let mut cursor = std::io::Cursor::new(buffer); let content = candle::quantized::gguf_file::Content::read(&mut cursor)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { let tensor = content.tensor(&mut cursor, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self { data: Arc::new(data), path: Vec::new(), device: device.clone(), }) } pub fn pp(&self, s: S) -> Self { let mut path = self.path.clone(); path.push(s.to_string()); Self { data: self.data.clone(), path, device: self.device.clone(), } } fn path(&self, tensor_name: &str) -> String { if self.path.is_empty() { tensor_name.to_string() } else { [&self.path.join("."), tensor_name].join(".") } } pub fn get>(&self, s: S, name: &str) -> Result> { let path = self.path(name); match self.data.get(&path) { None => { candle::bail!("cannot find tensor {path}") } Some(qtensor) => { let shape = s.into(); if qtensor.shape() != &shape { candle::bail!( "shape mismatch for {name}, got {:?}, expected {shape:?}", qtensor.shape() ) } Ok(qtensor.clone()) } } } pub fn get_no_shape(&self, name: &str) -> Result> { let path = self.path(name); match self.data.get(&path) { None => { candle::bail!("cannot find tensor {name}") } Some(qtensor) => Ok(qtensor.clone()), } } pub fn device(&self) -> &Device { &self.device } pub fn contains_key(&self, key: &str) -> bool { self.data.contains_key(key) } }