diff options
33 files changed, 70 insertions, 28 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 9f22d717..00ead0cd 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -184,6 +184,7 @@ impl QTensor { } } +#[derive(Debug)] pub struct QMatMul(QTensor); impl QMatMul { diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs index b2438e71..3f164a3a 100644 --- a/candle-examples/examples/bert/model.rs +++ b/candle-examples/examples/bert/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, VarBuilder}; +use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; pub const DTYPE: DType = DType::F32; diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs index 99f5bb5a..1e63956b 100644 --- a/candle-examples/examples/bigcode/model.rs +++ b/candle-examples/examples/bigcode/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { let weight = vb.get((size2, size1), "weight")?; diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index 1c77cbaf..b638dd51 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -1,6 +1,6 @@ use anyhow::Result; use candle::{DType, Device, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index 13eb7390..86d13bdb 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, VarBuilder}; +use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs index aae9673a..9b982ddd 100644 --- a/candle-examples/examples/llama2-c/model.rs +++ b/candle-examples/examples/llama2-c/model.rs @@ -1,6 +1,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::linear_no_bias as linear; -use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; +use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index fa8f9abf..1e7cafa2 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -1,6 +1,6 @@ use candle::backend::BackendStorage; use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D}; -use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; +use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use std::rc::Rc; diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index d9e596ce..8472bb55 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -5,7 +5,7 @@ extern crate intel_mkl_src; use clap::{Parser, ValueEnum}; use candle::{DType, Result, Tensor, D}; -use candle_nn::{loss, ops, Linear, VarBuilder, VarMap}; +use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap}; const IMAGE_DIM: usize = 784; const LABELS: usize = 10; diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index eaf4ca05..9c966497 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -1,6 +1,7 @@ use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder}; use anyhow::Result; use candle::{DType, IndexOp, Tensor}; +use candle_nn::Module; // Encodec Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index 01266e63..b955205f 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -4,6 +4,7 @@ use crate::nn::{ use crate::{encodec_model, t5_model}; use anyhow::Result; use candle::{DType, Device, Tensor, D}; +use candle_nn::Module; // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 #[derive(Debug, Clone, PartialEq)] diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs index ef65df39..613b4112 100644 --- a/candle-examples/examples/musicgen/t5_model.rs +++ b/candle-examples/examples/musicgen/t5_model.rs @@ -4,6 +4,7 @@ use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder}; use anyhow::Result; use candle::{DType, Tensor, D}; +use candle_nn::Module; use std::sync::Arc; #[derive(Debug, Clone, PartialEq)] diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index a67a5a03..7da7cf1c 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -7,7 +7,7 @@ use tokenizers::Tokenizer; use candle::quantized::ggml_file::Content; use candle::quantized::QTensor; use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::Embedding; +use candle_nn::{Embedding, Module}; use candle_transformers::generation::LogitsProcessor; const MAX_SEQ_LEN: usize = 4096; diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs index d981b6f4..255ce857 100644 --- a/candle-examples/examples/stable-diffusion/attention.rs +++ b/candle-examples/examples/stable-diffusion/attention.rs @@ -1,6 +1,7 @@ //! Attention Based Building Blocks use candle::{DType, IndexOp, Result, Tensor, D}; use candle_nn as nn; +use candle_nn::Module; #[derive(Debug)] struct GeGlu { diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs index 29591f55..2d450d99 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -7,6 +7,7 @@ //! https://github.com/openai/CLIP use candle::{DType, Device, Result, Tensor, D}; use candle_nn as nn; +use candle_nn::Module; #[derive(Debug, Clone, Copy)] pub enum Activation { diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs index c94f24f8..88a153ae 100644 --- a/candle-examples/examples/stable-diffusion/embeddings.rs +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use candle::{Result, Tensor, D}; use candle_nn as nn; +use candle_nn::Module; #[derive(Debug)] pub struct TimestepEmbedding { diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs index b6a628be..94f436c8 100644 --- a/candle-examples/examples/stable-diffusion/resnet.rs +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -8,6 +8,7 @@ use crate::utils::{conv2d, Conv2d}; use candle::{Result, Tensor, D}; use candle_nn as nn; +use candle_nn::Module; /// Configuration for a ResNet block. #[derive(Debug, Clone, Copy)] diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs index 0fa2f31a..6f568113 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -7,6 +7,7 @@ use crate::unet_2d_blocks::*; use crate::utils::{conv2d, Conv2d}; use candle::{Result, Tensor}; use candle_nn as nn; +use candle_nn::Module; #[derive(Debug, Clone, Copy)] pub struct BlockConfig { diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 308e577d..5602a9ad 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,4 +1,5 @@ use candle::{Device, Result, Tensor}; +use candle_nn::Module; pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { if steps < 1 { diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs index 7a10d932..abba39fa 100644 --- a/candle-examples/examples/stable-diffusion/vae.rs +++ b/candle-examples/examples/stable-diffusion/vae.rs @@ -10,6 +10,7 @@ use crate::unet_2d_blocks::{ }; use candle::{Result, Tensor}; use candle_nn as nn; +use candle_nn::Module; #[derive(Debug, Clone)] struct EncoderConfig { diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 00d5707e..553bd93b 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -1,5 +1,5 @@ use candle::{Device, IndexOp, Result, Tensor}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: diff --git a/candle-nn/examples/basic_optimizer.rs b/candle-nn/examples/basic_optimizer.rs index 3c5665e8..cd5824dd 100644 --- a/candle-nn/examples/basic_optimizer.rs +++ b/candle-nn/examples/basic_optimizer.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Result, Tensor}; -use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap}; +use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap}; fn gen_data() -> Result<(Tensor, Tensor)> { // Generate some sample linear data. diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 9554e68a..0db3edc9 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -7,8 +7,8 @@ pub enum Activation { Elu(f64), } -impl Activation { - pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { +impl super::Module for Activation { + fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { match self { Self::Gelu => xs.gelu(), Self::Relu => xs.relu(), diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 67a80417..5057d2ef 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -35,8 +35,10 @@ impl Conv1d { pub fn config(&self) -> &Conv1dConfig { &self.config } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for Conv1d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; match &self.bias { None => Ok(x), @@ -84,8 +86,10 @@ impl Conv2d { pub fn config(&self) -> &Conv2dConfig { &self.config } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for Conv2d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; match &self.bias { None => Ok(x), diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index f4ba88e7..918c1805 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -18,8 +18,10 @@ impl Embedding { pub fn embeddings(&self) -> &Tensor { &self.embeddings } +} - pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> { +impl crate::Module for Embedding { + fn forward(&self, indexes: &Tensor) -> Result<Tensor> { let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.hidden_size); let indexes = indexes.flatten_all()?; diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index ac77db4b..e85c4379 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -34,8 +34,10 @@ impl GroupNorm { num_groups, }) } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for GroupNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x_shape = x.dims(); if x_shape.len() <= 2 { candle::bail!("input rank for GroupNorm should be at least 3"); diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 17cdef3d..61fbe2d2 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -8,7 +8,7 @@ //! //! ```rust //! use candle::{Tensor, Device::Cpu}; -//! use candle_nn::LayerNorm; +//! use candle_nn::{LayerNorm, Module}; //! # fn main() -> candle::Result<()> { //! //! let w = Tensor::new(1f32, &Cpu)?; @@ -95,8 +95,10 @@ impl LayerNorm { eps, } } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for LayerNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x_dtype = x.dtype(); let internal_dtype = match x_dtype { DType::F16 | DType::BF16 => DType::F32, @@ -152,8 +154,10 @@ impl RmsNorm { pub fn into_inner(self) -> LayerNorm { self.0 } +} - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl crate::Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { self.0.forward(xs) } } diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index c04e8ff4..da63d592 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -1,5 +1,5 @@ -// For now this crate shares its error type with candle-core. We may introduce some separate -// error type if needed or add some specialized cases on the candle-core side. +use candle::{Result, Tensor}; + pub mod activation; pub mod conv; pub mod embedding; @@ -21,3 +21,20 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use linear::{linear, linear_no_bias, Linear}; pub use optim::{AdamW, ParamsAdamW, SGD}; pub use var_builder::{VarBuilder, VarMap}; + +// A simple trait defining a module with forward method using a single argument. +pub trait Module: std::fmt::Debug { + fn forward(&self, xs: &Tensor) -> Result<Tensor>; + + /// Change the module to use training mode vs eval mode. + /// + /// The default implementation does nothing as this is only used for a couple modules such as + /// dropout or batch-normalization. + fn set_training(&mut self, _training: bool) {} +} + +impl Module for candle::quantized::QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.forward(xs) + } +} diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index a0bd925a..a7bd1028 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -7,7 +7,7 @@ //! //! ```rust //! use candle::{Tensor, Device::Cpu}; -//! use candle_nn::Linear; +//! use candle_nn::{Linear, Module}; //! # fn main() -> candle::Result<()> { //! //! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?; @@ -29,8 +29,10 @@ impl Linear { pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self { Self { weight, bias } } +} - pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { +impl super::Module for Linear { + fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { let w = match x.dims() { &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, _ => self.weight.t()?, diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs index f3ef2455..3a906c9d 100644 --- a/candle-nn/tests/group_norm.rs +++ b/candle-nn/tests/group_norm.rs @@ -23,7 +23,7 @@ extern crate intel_mkl_src; use anyhow::Result; use candle::{Device, Tensor}; -use candle_nn::GroupNorm; +use candle_nn::{GroupNorm, Module}; mod test_utils; use test_utils::to_vec3_round; diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs index 3a300cec..849b4987 100644 --- a/candle-nn/tests/layer_norm.rs +++ b/candle-nn/tests/layer_norm.rs @@ -3,7 +3,7 @@ extern crate intel_mkl_src; use anyhow::Result; use candle::{Device, Tensor}; -use candle_nn::LayerNorm; +use candle_nn::{LayerNorm, Module}; #[test] fn layer_norm() -> Result<()> { diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs index 1327ae91..899745d4 100644 --- a/candle-nn/tests/optim.rs +++ b/candle-nn/tests/optim.rs @@ -6,7 +6,7 @@ use test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle::{Device, Tensor, Var}; -use candle_nn::{AdamW, Linear, ParamsAdamW, SGD}; +use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD}; #[test] fn sgd_optim() -> Result<()> { diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index 2c867793..3fedb1d3 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; +use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs index 9f3d92f5..3470c3d6 100644 --- a/candle-wasm-examples/whisper/src/model.rs +++ b/candle-wasm-examples/whisper/src/model.rs @@ -3,7 +3,7 @@ // back when using RUST_LIB_BACKTRACE=1. use anyhow::Result; use candle::{Device, Tensor}; -use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; +use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: |