summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-30 19:31:14 +0200
committerGitHub <noreply@github.com>2024-09-30 19:31:14 +0200
commit683ab698def755c24cec9987069d25efcf831fc4 (patch)
tree84d0bd8ad2f5d7a00f67050c83520326d947b2fe /candle-nn
parent2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (diff)
downloadcandle-683ab698def755c24cec9987069d25efcf831fc4.tar.gz
candle-683ab698def755c24cec9987069d25efcf831fc4.tar.bz2
candle-683ab698def755c24cec9987069d25efcf831fc4.zip
Add Pixtral. (#2521)
* Add Pixtral. * More pixtral vision encoder. * Sketch a pixtral example. * Sketch a pixtral example. * Better image loading. * Support loading images embedded in safetensor files. * Clippy fixes. * Add the llava multimodal adapter. * Add more of the llava bits. * Add the pixtral config. * More pixtral inference. * Add the text generation bits. * Get the example to work. * Bugfix. * Run some bits of the model in f32. * Blessed version :) * Better rope frequency computations. * README update.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs36
1 files changed, 22 insertions, 14 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index f6e6160b..00669468 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -14,6 +14,7 @@ use std::sync::Arc;
pub struct VarBuilderArgs<'a, B: Backend> {
data: Arc<TensorData<B>>,
path: Vec<String>,
+ pub dtype: DType,
_phantom: std::marker::PhantomData<&'a B>,
}
@@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: self.path.clone(),
+ dtype: self.dtype,
_phantom: self._phantom,
}
}
@@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
- pub dtype: DType,
pub device: Device,
}
@@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
- dtype,
device: dev.clone(),
};
Self {
data: Arc::new(data),
path: vec![],
+ dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![],
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![prefix.to_string()],
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path,
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
/// The dtype used by default.
pub fn dtype(&self) -> DType {
- self.data.dtype
+ self.dtype
+ }
+
+ /// Clone the VarBuilder tweaking its dtype
+ pub fn to_dtype(&self, dtype: DType) -> Self {
+ Self {
+ data: self.data.clone(),
+ path: self.path.clone(),
+ dtype,
+ _phantom: std::marker::PhantomData,
+ }
}
fn path(&self, tensor_name: &str) -> String {
@@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
- self.get_with_hints_dtype(s, name, hints, self.data.dtype)
+ self.get_with_hints_dtype(s, name, hints, self.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
@@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
dtype: DType,
device: Device,
) -> Self {
- let data = TensorData {
- backend,
- dtype,
- device,
- };
+ let data = TensorData { backend, device };
Self {
data: Arc::new(data),
path: vec![],
+ dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
- let data = TensorData {
- backend,
- dtype,
- device,
- };
+ let data = TensorData { backend, device };
Self {
data: Arc::new(data),
+ dtype,
path,
_phantom: std::marker::PhantomData,
}