diff options
Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r-- | candle-nn/src/var_builder.rs | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index f6e6160b..00669468 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -14,6 +14,7 @@ use std::sync::Arc; pub struct VarBuilderArgs<'a, B: Backend> { data: Arc<TensorData<B>>, path: Vec<String>, + pub dtype: DType, _phantom: std::marker::PhantomData<&'a B>, } @@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: self.path.clone(), + dtype: self.dtype, _phantom: self._phantom, } } @@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>; struct TensorData<B: Backend> { backend: B, - pub dtype: DType, pub device: Device, } @@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { backend, - dtype, device: dev.clone(), }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path: vec![prefix.to_string()], + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: self.data.clone(), path, + dtype: self.dtype, _phantom: std::marker::PhantomData, } } @@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { /// The dtype used by default. pub fn dtype(&self) -> DType { - self.data.dtype + self.dtype + } + + /// Clone the VarBuilder tweaking its dtype + pub fn to_dtype(&self, dtype: DType) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + dtype, + _phantom: std::marker::PhantomData, + } } fn path(&self, tensor_name: &str) -> String { @@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result<Tensor> { - self.get_with_hints_dtype(s, name, hints, self.data.dtype) + self.get_with_hints_dtype(s, name, hints, self.dtype) } /// Retrieve the tensor associated with the given name at the current path. @@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> { dtype: DType, device: Device, ) -> Self { - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), path: vec![], + dtype, _phantom: std::marker::PhantomData, } } @@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> { let path = self.path.clone(); let backend = Rename::new(self, renamer); let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend); - let data = TensorData { - backend, - dtype, - device, - }; + let data = TensorData { backend, device }; Self { data: Arc::new(data), + dtype, path, _phantom: std::marker::PhantomData, } |