summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/tensor-tools.rs21
-rw-r--r--candle-core/src/quantized/gguf_file.rs12
2 files changed, 30 insertions, 3 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 19c9d0ca..d5f7dd57 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -6,6 +6,7 @@ enum Format {
Safetensors,
Npz,
Ggml,
+ Gguf,
Pth,
Pickle,
}
@@ -16,11 +17,12 @@ impl Format {
.extension()
.and_then(|e| e.to_str())
.and_then(|e| match e {
- // We don't infer any format for .bin as it can be used for ggml or pytorch.
+ // We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
"safetensors" | "safetensor" => Some(Self::Safetensors),
"npz" => Some(Self::Npz),
"pth" | "pt" => Some(Self::Pth),
"ggml" => Some(Self::Ggml),
+ "gguf" => Some(Self::Gguf),
_ => None,
})
}
@@ -121,6 +123,23 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
}
}
+ Format::Gguf => {
+ let mut file = std::fs::File::open(file)?;
+ let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
+ if verbose {
+ let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
+ metadata.sort_by(|a, b| a.0.cmp(&b.0));
+ println!("metadata entries ({})", metadata.len());
+ for (key, value) in metadata.iter() {
+ println!(" {key}: {value:?}");
+ }
+ }
+ let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
+ tensors.sort_by(|a, b| a.0.cmp(&b.0));
+ for (name, info) in tensors.iter() {
+ println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
+ }
+ }
}
Ok(())
}
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 781e3a8d..3f13b7de 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -7,7 +7,7 @@ use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
-pub const DEFAULT_ALIGNMENT: usize = 32;
+pub const DEFAULT_ALIGNMENT: u64 = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
@@ -208,7 +208,15 @@ impl Content {
);
}
let position = reader.stream_position()?;
- let alignment = DEFAULT_ALIGNMENT as u64;
+ let alignment = match metadata.get("general.alignment") {
+ Some(Value::U8(v)) => *v as u64,
+ Some(Value::U16(v)) => *v as u64,
+ Some(Value::U32(v)) => *v as u64,
+ Some(Value::I8(v)) if *v >= 0 => *v as u64,
+ Some(Value::I16(v)) if *v >= 0 => *v as u64,
+ Some(Value::I32(v)) if *v >= 0 => *v as u64,
+ _ => DEFAULT_ALIGNMENT,
+ };
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
Ok(Self {
magic,