summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
author yinqiwen <yinqiwen@gmail.com> 2024-04-01 18:10:08 +0800
committer GitHub <noreply@github.com> 2024-04-01 12:10:08 +0200
commit 5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8 (patch)
tree b3601f0478fcc10b2bb97ae2bf9a77f27325ae1e /candle-nn
parent 888c09a3dbf8413c3aa76076e49cf52460334bbd (diff)
downloadcandle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.tar.gz
candle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.tar.bz2
candle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.zip
Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
* Quantized models (AWQ, SqueezeLLM, ...) contain tensors of multiple data types; use 'get_with_hints_dtype' to load a tensor with a given dtype.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs19
1 files changed, 15 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index bf090219..7de46044 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
- let path = self.path(name);
- self.data
- .backend
- .get(s.into(), &path, hints, self.data.dtype, &self.data.device)
+ self.get_with_hints_dtype(s, name, hints, self.data.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
self.get_with_hints(s, name, Default::default())
}
+
+ /// Retrieve the tensor associated with the given name & dtype at the current path.
+ pub fn get_with_hints_dtype<S: Into<Shape>>(
+ &self,
+ s: S,
+ name: &str,
+ hints: B::Hints,
+ dtype: DType,
+ ) -> Result<Tensor> {
+ let path = self.path(name);
+ self.data
+ .backend
+ .get(s.into(), &path, hints, dtype, &self.data.device)
+ }
}
struct Zeros;