 README.md                 |  1 +
 candle-core/src/tensor.rs | 22 ++++++++++++++++++++++
 candle-nn/src/conv.rs     | 16 ++++++++++++++++
 candle-nn/src/func.rs     | 27 +++++++++++++++++++++++++++
 candle-nn/src/lib.rs      |  4 +++-
 5 files changed, 69 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 7b98dca8..e20d328f 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ And then head over to
 - Whisper (multi-lingual support).
 - Stable Diffusion.
 - Computer Vision: DINOv2.
+- File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4ea66186..050e593a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -487,6 +487,28 @@ impl Tensor {
         self.to_scalar::<S>()
     }
 
+    /// Repeat this tensor along the specified dimensions.
+    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
+        // Similar to PyTorch, we extend the number of dimensions of self if needed.
+        let repeats = shape.into();
+        let repeats = repeats.dims();
+        let mut inp = if self.rank() < repeats.len() {
+            let mut shape = self.dims().to_vec();
+            while shape.len() < repeats.len() {
+                shape.push(1)
+            }
+            self.reshape(shape)?
+        } else {
+            self.clone()
+        };
+        for (idx, &repeat) in repeats.iter().enumerate() {
+            if repeat > 1 {
+                inp = Tensor::cat(&vec![&inp; repeat], idx)?
+            }
+        }
+        Ok(inp)
+    }
+
     /// This operation multiplies the input tensor by `mul` then adds `add` and returns the
     /// result. The input values `mul` and `add` are cast to the appropriate type, so some
     /// rounding might be performed.
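
A quick sketch of how the new `repeat` behaves (hypothetical usage, written against the `candle` crate name used elsewhere in this diff): a repeat spec with the same rank concatenates along each dimension, while a longer spec first pads the shape with trailing 1-sized dimensions before repeating.

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        // Same rank: repeat twice along dim 0 and three times along dim 1.
        assert_eq!(t.repeat((2, 3))?.dims(), &[4, 6]);
        // Longer spec: the (2, 2) shape is first reshaped to (2, 2, 1),
        // then repeated along dim 0, giving (4, 2, 1).
        assert_eq!(t.repeat((2, 1, 1))?.dims(), &[4, 2, 1]);
        Ok(())
    }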
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 5057d2ef..df9818ab 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -141,3 +141,19 @@ pub fn conv2d(
     let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
     Ok(Conv2d::new(ws, Some(bs), cfg))
 }
+
+pub fn conv2d_no_bias(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: Conv2dConfig,
+    vs: crate::VarBuilder,
+) -> Result<Conv2d> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init(
+        (out_channels, in_channels, kernel_size, kernel_size),
+        "weight",
+        init_ws,
+    )?;
+    Ok(Conv2d::new(ws, None, cfg))
+}
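
A minimal sketch of `conv2d_no_bias` in use, assuming a `VarMap`-backed `VarBuilder` (`from_varmap`) and the `pp` prefixing convention used by the existing `conv2d` helper; the channel counts and shapes are illustrative only:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{conv2d_no_bias, Conv2dConfig, Module, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        // A 3x3 convolution from 3 to 16 channels, padding 1, and no bias term.
        let cfg = Conv2dConfig { padding: 1, ..Default::default() };
        let conv = conv2d_no_bias(3, 16, 3, cfg, vb.pp("conv1"))?;
        // NCHW input: a batch of one 3-channel 32x32 image.
        let xs = Tensor::zeros((1, 3, 32, 32), DType::F32, &dev)?;
        assert_eq!(conv.forward(&xs)?.dims(), &[1, 16, 32, 32]);
        Ok(())
    }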
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs
new file mode 100644
index 00000000..c9ec287a
--- /dev/null
+++ b/candle-nn/src/func.rs
@@ -0,0 +1,27 @@
+//! Layers defined by closures.
+use candle::{Result, Tensor};
+
+/// A layer defined by a simple closure.
+pub struct Func<'a> {
+    #[allow(clippy::type_complexity)]
+    f: Box<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send>,
+}
+
+impl<'a> std::fmt::Debug for Func<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "func")
+    }
+}
+
+pub fn func<'a, F>(f: F) -> Func<'a>
+where
+    F: 'a + Fn(&Tensor) -> Result<Tensor> + Send,
+{
+    Func { f: Box::new(f) }
+}
+
+impl<'a> super::Module for Func<'a> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        (*self.f)(xs)
+    }
+}
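
`Func` makes it easy to drop an ad-hoc transformation into a model wherever a layer is expected; a small illustrative sketch (the closure itself is hypothetical):

    use candle::{Device, Result, Tensor};
    use candle_nn::{func, Module};

    fn main() -> Result<()> {
        // Any Fn(&Tensor) -> Result<Tensor> + Send closure can act as a layer.
        let double_then_relu = func(|xs| (xs * 2.)?.relu());
        let xs = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
        let ys = double_then_relu.forward(&xs)?;
        assert_eq!(ys.to_vec1::<f32>()?, [0., 0., 4.]);
        Ok(())
    }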
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e195ac67..34e2dbed 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -4,6 +4,7 @@ pub mod activation;
 pub mod batch_norm;
 pub mod conv;
 pub mod embedding;
+pub mod func;
 pub mod group_norm;
 pub mod init;
 pub mod layer_norm;
@@ -16,8 +17,9 @@ pub mod var_map;
 pub use activation::Activation;
 pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
-pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
+pub use conv::{conv1d, conv2d, conv2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 pub use embedding::{embedding, Embedding};
+pub use func::{func, Func};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
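
With these crate-root re-exports the pieces compose; a final sketch, again with hypothetical names and shapes, wrapping the bias-free convolution and an activation into a single closure-backed layer:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{conv2d_no_bias, func, Conv2dConfig, Module, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let conv = conv2d_no_bias(3, 8, 3, Conv2dConfig::default(), vb.pp("stem"))?;
        // Package conv + relu as one layer; `move` hands the conv to the closure.
        let block = func(move |xs| conv.forward(xs)?.relu());
        let xs = Tensor::zeros((1, 3, 8, 8), DType::F32, &dev)?;
        // 3x3 kernel with no padding: 8 - 3 + 1 = 6.
        assert_eq!(block.forward(&xs)?.dims(), &[1, 8, 6, 6]);
        Ok(())
    }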