diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-30 19:31:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-30 19:31:14 +0200 |
commit | 683ab698def755c24cec9987069d25efcf831fc4 (patch) | |
tree | 84d0bd8ad2f5d7a00f67050c83520326d947b2fe /candle-nn | |
parent | 2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (diff) | |
download | candle-683ab698def755c24cec9987069d25efcf831fc4.tar.gz candle-683ab698def755c24cec9987069d25efcf831fc4.tar.bz2 candle-683ab698def755c24cec9987069d25efcf831fc4.zip |
Add Pixtral. (#2521)
* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 36 |
1 file changed, 22 insertions, 14 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index f6e6160b..00669468 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -14,6 +14,7 @@ use std::sync::Arc; pub struct VarBuilderArgs<'a, B: Backend> { data: Arc<TensorData<B>>, path: Vec<String>, + pub dtype: DType, _phantom: std::marker::PhantomData<&'a B>, } @@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: self.path.clone(), + dtype: self.dtype, _phantom: self._phantom, } } @@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>; struct TensorData<B: Backend> { backend: B, - pub dtype: DType, pub device: Device, } @@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { backend, - dtype, device: dev.clone(), }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![prefix.to_string()], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path, + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { /// The dtype used by default. 
pub fn dtype(&self) -> DType { - self.data.dtype + self.dtype + } + + /// Clone the VarBuilder tweaking its dtype + pub fn to_dtype(&self, dtype: DType) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + dtype, + _phantom: std::marker::PhantomData, + } } fn path(&self, tensor_name: &str) -> String { @@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result<Tensor> { - self.get_with_hints_dtype(s, name, hints, self.data.dtype) + self.get_with_hints_dtype(s, name, hints, self.dtype) } /// Retrieve the tensor associated with the given name at the current path. @@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> { dtype: DType, device: Device, ) -> Self { - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> { let path = self.path.clone(); let backend = Rename::new(self, renamer); let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend); - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), + dtype, path, _phantom: std::marker::PhantomData, } |