summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-04-13 12:09:41 +0200
committer GitHub <noreply@github.com> 2024-04-13 12:09:41 +0200
commit 8b8fb630dfe46246a10052eb1ebd9ae0f35900e4 (patch)
tree 7d7b977a48d317da6464752b1630597afd8d04a7 /candle-nn
parent fb805b8ca2c9413ad9227800328145434a08eaca (diff)
download candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.tar.gz
candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.tar.bz2
candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.zip
Add a convenient way to rename tensors accessed through a varbuilder. (#2052)
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/src/var_builder.rs 93
1 files changed, 93 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 7de46044..5539370a 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -498,6 +498,53 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
+
+ /// Gets a VarBuilder that applies a renaming function to the tensor names it is queried for
+ /// before passing the new names to the inner VarBuilder.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ ///
+ /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
+ /// let tensors: std::collections::HashMap<_, _> = [
+ /// ("foo".to_string(), a),
+ /// ]
+ /// .into_iter()
+ /// .collect();
+ /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
+ /// assert!(vb.contains_tensor("foo"));
+ /// assert!(vb.get((2, 3), "foo").is_ok());
+ /// assert!(!vb.contains_tensor("bar"));
+ /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
+ /// assert!(vb.contains_tensor("bar"));
+ /// assert!(vb.contains_tensor("foo"));
+ /// assert!(vb.get((2, 3), "bar").is_ok());
+ /// assert!(vb.get((2, 3), "foo").is_ok());
+ /// assert!(!vb.contains_tensor("baz"));
+ /// # Ok::<(), candle::Error>(())
+ /// ```
+ pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
+ let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
+ self.rename(f) // the boxed closure is the `Renamer` handed to `rename`
+ }
+
+ pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self { // wraps `self` in a `Rename` backend
+ let dtype = self.dtype();
+ let device = self.device().clone();
+ let path = self.path.clone(); // the returned builder gets a copy of the path of `self`
+ let backend = Rename::new(self, renamer); // NOTE(review): `self` keeps its own path here; confirm names are not prefixed twice
+ let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
+ let data = TensorData {
+ backend,
+ dtype,
+ device,
+ };
+ Self {
+ data: Arc::new(data),
+ path,
+ _phantom: std::marker::PhantomData,
+ }
+ }
}
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors); // newtype over `MmapedSafetensors`
@@ -618,3 +665,49 @@ impl Backend for ShardedSafeTensors {
self.0.get(name).is_ok()
}
}
+
+/// This trait specifies a way to rename the queried names into names that are stored in an inner
+/// VarBuilder.
+pub trait Renamer {
+ /// This is applied to the name a tensor is queried with, and the resulting name is passed to the
+ /// inner VarBuilder.
+ fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
+}
+
+pub struct Rename<'a, R: Renamer> { // backend that renames names before querying `inner`
+ inner: VarBuilder<'a>, // builder that receives the renamed names
+ renamer: R, // maps a queried name to the name looked up in `inner`
+}
+
+impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
+ fn get(
+ &self,
+ s: Shape,
+ name: &str,
+ h: crate::Init,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Tensor> {
+ let name = self.renamer.rename(name); // translate the queried name before the lookup
+ self.inner
+ .get_with_hints_dtype(s, &name, h, dtype)? // `dev` is not forwarded to the inner builder
+ .to_device(dev) // it is applied to the tensor the inner builder returned
+ }
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ let name = self.renamer.rename(name); // same translation as in `get`
+ self.inner.contains_tensor(&name)
+ }
+}
+
+impl<'a, R: Renamer> Rename<'a, R> {
+ pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self { // stores both arguments unchanged
+ Self { inner, renamer }
+ }
+}
+
+impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> { // lets a boxed closure act as a `Renamer`
+ fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
+ std::borrow::Cow::Owned(self(v)) // the closure returns a `String`, so the result is always `Owned`
+ }
+}