diff options
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 22 | ||||
-rw-r--r-- | candle-nn/src/conv.rs | 16 | ||||
-rw-r--r-- | candle-nn/src/func.rs | 27 | ||||
-rw-r--r-- | candle-nn/src/lib.rs | 4 |
5 files changed, 69 insertions, 1 deletions
@@ -84,6 +84,7 @@ And then head over to - Whisper (multi-lingual support). - Stable Diffusion. - Computer Vision: DINOv2. +- File formats: load models from safetensors, npz, ggml, or PyTorch files. - Serverless (on CPU), small and fast deployments. - Quantization support using the llama.cpp quantized types. diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 4ea66186..050e593a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -487,6 +487,28 @@ impl Tensor { self.to_scalar::<S>() } + /// Repeat this tensor along the specified dimensions. + pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> { + // Similar to PyTorch, we extend the number of dimensions of self if needed. + let repeats = shape.into(); + let repeats = repeats.dims(); + let mut inp = if self.rank() < repeats.len() { + let mut shape = self.dims().to_vec(); + while shape.len() < repeats.len() { + shape.push(1) + } + self.reshape(shape)? + } else { + self.clone() + }; + for (idx, &repeat) in repeats.iter().enumerate() { + if repeat > 1 { + inp = Tensor::cat(&vec![&inp; repeat], idx)? + } + } + Ok(inp) + } + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// The input values `mul` and `add` are casted to the appropriate type so some rounding might /// be performed. diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 5057d2ef..df9818ab 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -141,3 +141,19 @@ pub fn conv2d( let bs = vs.get_or_init(out_channels, "bias", init_bs)?; Ok(Conv2d::new(ws, Some(bs), cfg)) } + +pub fn conv2d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv2dConfig, + vs: crate::VarBuilder, +) -> Result<Conv2d> { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vs.get_or_init( + (out_channels, in_channels, kernel_size, kernel_size), + "weight", + init_ws, + )?; + Ok(Conv2d::new(ws, None, cfg)) +} diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs new file mode 100644 index 00000000..c9ec287a --- /dev/null +++ b/candle-nn/src/func.rs @@ -0,0 +1,27 @@ +//! Layers defined by closures. +use candle::{Result, Tensor}; + +/// A layer defined by a simple closure. +pub struct Func<'a> { + #[allow(clippy::type_complexity)] + f: Box<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send>, +} + +impl<'a> std::fmt::Debug for Func<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func<'a, F>(f: F) -> Func<'a> +where + F: 'a + Fn(&Tensor) -> Result<Tensor> + Send, +{ + Func { f: Box::new(f) } +} + +impl<'a> super::Module for Func<'a> { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + (*self.f)(xs) + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index e195ac67..34e2dbed 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -4,6 +4,7 @@ pub mod activation; pub mod batch_norm; pub mod conv; pub mod embedding; +pub mod func; pub mod group_norm; pub mod init; pub mod layer_norm; @@ -16,8 +17,9 @@ pub mod var_map; pub use activation::Activation; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; -pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; +pub use conv::{conv1d, conv2d, conv2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; pub use embedding::{embedding, Embedding}; +pub use func::{func, Func}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; |