diff options
author | yinqiwen <yinqiwen@gmail.com> | 2024-04-01 18:10:08 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-01 12:10:08 +0200 |
commit | 5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8 (patch) | |
tree | b3601f0478fcc10b2bb97ae2bf9a77f27325ae1e /candle-nn | |
parent | 888c09a3dbf8413c3aa76076e49cf52460334bbd (diff) | |
download | candle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.tar.gz candle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.tar.bz2 candle-5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8.zip |
Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
* Quantized models (AWQ, SqueezeLLM, ...) contain tensors of multiple data types; use 'get_with_hints_dtype' to load a tensor with a given dtype.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 19 |
1 file changed, 15 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index bf090219..7de46044 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result<Tensor> { - let path = self.path(name); - self.data - .backend - .get(s.into(), &path, hints, self.data.dtype, &self.data.device) + self.get_with_hints_dtype(s, name, hints, self.data.dtype) } /// Retrieve the tensor associated with the given name at the current path. pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> { self.get_with_hints(s, name, Default::default()) } + + /// Retrieve the tensor associated with the given name & dtype at the current path. + pub fn get_with_hints_dtype<S: Into<Shape>>( + &self, + s: S, + name: &str, + hints: B::Hints, + dtype: DType, + ) -> Result<Tensor> { + let path = self.path(name); + self.data + .backend + .get(s.into(), &path, hints, dtype, &self.data.device) + } } struct Zeros; |