diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-13 12:09:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-13 12:09:41 +0200 |
commit | 8b8fb630dfe46246a10052eb1ebd9ae0f35900e4 (patch) | |
tree | 7d7b977a48d317da6464752b1630597afd8d04a7 /candle-nn | |
parent | fb805b8ca2c9413ad9227800328145434a08eaca (diff) | |
download | candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.tar.gz candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.tar.bz2 candle-8b8fb630dfe46246a10052eb1ebd9ae0f35900e4.zip |
Add a convenient way to rename tensors accessed through a varbuilder. (#2052)
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 93 |
1 files changed, 93 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 7de46044..5539370a 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -498,6 +498,53 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } + + /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before + /// passing the new names to the inner VarBuilder. + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// + /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; + /// let tensors: std::collections::HashMap<_, _> = [ + /// ("foo".to_string(), a), + /// ] + /// .into_iter() + /// .collect(); + /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu); + /// assert!(vb.contains_tensor("foo")); + /// assert!(vb.get((2, 3), "foo").is_ok()); + /// assert!(!vb.contains_tensor("bar")); + /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() }); + /// assert!(vb.contains_tensor("bar")); + /// assert!(vb.contains_tensor("foo")); + /// assert!(vb.get((2, 3), "bar").is_ok()); + /// assert!(vb.get((2, 3), "foo").is_ok()); + /// assert!(!vb.contains_tensor("baz")); + /// # Ok::<(), candle::Error>(()) + /// ``` + pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self { + let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f); + self.rename(f) + } + + pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self { + let dtype = self.dtype(); + let device = self.device().clone(); + let path = self.path.clone(); + let backend = Rename::new(self, renamer); + let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend); + let data = TensorData { + backend, + dtype, + device, + }; + Self { + data: Arc::new(data), + path, + _phantom: std::marker::PhantomData, + } + } } pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors); @@ -618,3 +665,49 @@ impl Backend for ShardedSafeTensors { self.0.get(name).is_ok() } } + +/// This traits specifies a way to rename the queried names into names that are stored in an inner +/// VarBuilder. +pub trait Renamer { + /// This is applied to the name obtained by a name call and the resulting name is passed to the + /// inner VarBuilder. + fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>; +} + +pub struct Rename<'a, R: Renamer> { + inner: VarBuilder<'a>, + renamer: R, +} + +impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> { + fn get( + &self, + s: Shape, + name: &str, + h: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result<Tensor> { + let name = self.renamer.rename(name); + self.inner + .get_with_hints_dtype(s, &name, h, dtype)? + .to_device(dev) + } + + fn contains_tensor(&self, name: &str) -> bool { + let name = self.renamer.rename(name); + self.inner.contains_tensor(&name) + } +} + +impl<'a, R: Renamer> Rename<'a, R> { + pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self { + Self { inner, renamer } + } +} + +impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> { + fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> { + std::borrow::Cow::Owned(self(v)) + } +} |