summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs36
1 files changed, 22 insertions, 14 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index f6e6160b..00669468 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -14,6 +14,7 @@ use std::sync::Arc;
pub struct VarBuilderArgs<'a, B: Backend> {
data: Arc<TensorData<B>>,
path: Vec<String>,
+ pub dtype: DType,
_phantom: std::marker::PhantomData<&'a B>,
}
@@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: self.path.clone(),
+ dtype: self.dtype,
_phantom: self._phantom,
}
}
@@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
- pub dtype: DType,
pub device: Device,
}
@@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
- dtype,
device: dev.clone(),
};
Self {
data: Arc::new(data),
path: vec![],
+ dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![],
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![prefix.to_string()],
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path,
+ dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
/// The dtype used by default.
pub fn dtype(&self) -> DType {
- self.data.dtype
+ self.dtype
+ }
+
+ /// Clone the VarBuilder tweaking its dtype
+ pub fn to_dtype(&self, dtype: DType) -> Self {
+ Self {
+ data: self.data.clone(),
+ path: self.path.clone(),
+ dtype,
+ _phantom: std::marker::PhantomData,
+ }
}
fn path(&self, tensor_name: &str) -> String {
@@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
- self.get_with_hints_dtype(s, name, hints, self.data.dtype)
+ self.get_with_hints_dtype(s, name, hints, self.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
@@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
dtype: DType,
device: Device,
) -> Self {
- let data = TensorData {
- backend,
- dtype,
- device,
- };
+ let data = TensorData { backend, device };
Self {
data: Arc::new(data),
path: vec![],
+ dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
- let data = TensorData {
- backend,
- dtype,
- device,
- };
+ let data = TensorData { backend, device };
Self {
data: Arc::new(data),
+ dtype,
path,
_phantom: std::marker::PhantomData,
}