summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-nn/src/var_builder.rs41
1 files changed, 41 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 24832bc7..cbd238dd 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -191,6 +191,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
}
struct Zeros;
+
impl SimpleBackend for Zeros {
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
Tensor::zeros(s, dtype, dev)
@@ -325,6 +326,39 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
}
+impl SimpleBackend for candle::pickle::PthTensors {
+ fn get(
+ &self,
+ s: Shape,
+ path: &str,
+ _: crate::Init,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Tensor> {
+ let tensor = match self.get(path)? {
+ None => Err(Error::CannotFindTensor {
+ path: path.to_string(),
+ }
+ .bt())?,
+ Some(tensor) => tensor,
+ };
+ let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
+ if tensor.shape() != &s {
+ Err(candle::Error::UnexpectedShape {
+ msg: format!("shape mismatch for {path}"),
+ expected: s,
+ got: tensor.shape().clone(),
+ }
+ .bt())?
+ }
+ Ok(tensor)
+ }
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.get(name).map_or(false, |v| v.is_some())
+ }
+}
+
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
fn get(
&self,
@@ -438,9 +472,16 @@ impl<'a> VarBuilder<'a> {
let npz = candle::npy::NpzTensors::new(p)?;
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
}
+
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
+ pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
+ let pth = candle::pickle::PthTensors::new(p)?;
+ Ok(Self::new(Box::new(pth), dtype, dev.clone()))
+ }
}
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
+
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
impl ShardedSafeTensors {