diff options
Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r-- | candle-nn/src/var_builder.rs | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index c593960b..c372897a 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -52,6 +52,8 @@ pub trait Backend { dtype: DType, dev: &Device, ) -> Result<Tensor>; + + fn contains_tensor(&self, name: &str) -> bool; } pub trait SimpleBackend { @@ -64,6 +66,8 @@ pub trait SimpleBackend { dtype: DType, dev: &Device, ) -> Result<Tensor>; + + fn contains_tensor(&self, name: &str) -> bool; } impl<'a> Backend for Box<dyn SimpleBackend + 'a> { @@ -78,6 +82,10 @@ impl<'a> Backend for Box<dyn SimpleBackend + 'a> { ) -> Result<Tensor> { self.as_ref().get(s, name, h, dtype, dev) } + + fn contains_tensor(&self, name: &str) -> bool { + self.as_ref().contains_tensor(name) + } } impl<'a, B: Backend> VarBuilderArgs<'a, B> { @@ -94,6 +102,8 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { } } + /// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd` + /// into a directory. pub fn push_prefix<S: ToString>(&self, s: S) -> Self { let mut path = self.path.clone(); path.push(s.to_string()); @@ -109,10 +119,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { self.push_prefix(s) } + /// The device used by default. pub fn device(&self) -> &Device { &self.data.device } + /// The dtype used by default. pub fn dtype(&self) -> DType { self.data.dtype } @@ -125,6 +137,14 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { } } + /// This returns true only if a tensor with the passed in name is available. E.g. when passed + /// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b` + /// exists. + pub fn contains_tensor(&self, tensor_name: &str) -> bool { + let path = self.path(tensor_name); + self.data.backend.contains_tensor(&path) + } + /// Retrieve the tensor associated with the given name at the current path. pub fn get_with_hints<S: Into<Shape>>( &self, @@ -149,6 +169,10 @@ impl SimpleBackend for Zeros { fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> { Tensor::zeros(s, dtype, dev) } + + fn contains_tensor(&self, _name: &str) -> bool { + true + } } impl SimpleBackend for HashMap<String, Tensor> { @@ -179,6 +203,10 @@ impl SimpleBackend for HashMap<String, Tensor> { } tensor.to_device(dev)?.to_dtype(dtype) } + + fn contains_tensor(&self, name: &str) -> bool { + self.contains_key(name) + } } impl SimpleBackend for VarMap { @@ -192,6 +220,10 @@ impl SimpleBackend for VarMap { ) -> Result<Tensor> { VarMap::get(self, s, name, h, dtype, dev) } + + fn contains_tensor(&self, name: &str) -> bool { + self.data().lock().unwrap().contains_key(name) + } } struct SafeTensorWithRouting<'a> { @@ -228,6 +260,10 @@ impl<'a> SimpleBackend for SafeTensorWithRouting<'a> { } Ok(tensor) } + + fn contains_tensor(&self, name: &str) -> bool { + self.routing.contains_key(name) + } } impl SimpleBackend for candle::npy::NpzTensors { @@ -257,6 +293,10 @@ impl SimpleBackend for candle::npy::NpzTensors { } Ok(tensor) } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).map_or(false, |v| v.is_some()) + } } impl<'a> VarBuilder<'a> { @@ -425,4 +465,8 @@ impl<'a> Backend for ShardedSafeTensors<'a> { let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype) } + + fn contains_tensor(&self, name: &str) -> bool { + self.0.routing.contains_key(name) + } } |