summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/mod.rs1
-rw-r--r--candle-examples/examples/bert/model.rs2
-rw-r--r--candle-examples/examples/bigcode/model.rs2
-rw-r--r--candle-examples/examples/falcon/model.rs2
-rw-r--r--candle-examples/examples/llama/model.rs2
-rw-r--r--candle-examples/examples/llama2-c/model.rs2
-rw-r--r--candle-examples/examples/llama_multiprocess/model.rs2
-rw-r--r--candle-examples/examples/mnist-training/main.rs2
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs1
-rw-r--r--candle-examples/examples/musicgen/musicgen_model.rs1
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs1
-rw-r--r--candle-examples/examples/quantized/main.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/attention.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/embeddings.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs1
-rw-r--r--candle-examples/examples/stable-diffusion/vae.rs1
-rw-r--r--candle-examples/examples/whisper/model.rs2
-rw-r--r--candle-nn/examples/basic_optimizer.rs2
-rw-r--r--candle-nn/src/activation.rs4
-rw-r--r--candle-nn/src/conv.rs8
-rw-r--r--candle-nn/src/embedding.rs4
-rw-r--r--candle-nn/src/group_norm.rs4
-rw-r--r--candle-nn/src/layer_norm.rs10
-rw-r--r--candle-nn/src/lib.rs21
-rw-r--r--candle-nn/src/linear.rs6
-rw-r--r--candle-nn/tests/group_norm.rs2
-rw-r--r--candle-nn/tests/layer_norm.rs2
-rw-r--r--candle-nn/tests/optim.rs2
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs2
-rw-r--r--candle-wasm-examples/whisper/src/model.rs2
33 files changed, 70 insertions, 28 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 9f22d717..00ead0cd 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -184,6 +184,7 @@ impl QTensor {
}
}
+#[derive(Debug)]
pub struct QMatMul(QTensor);
impl QMatMul {
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index b2438e71..3f164a3a 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 99f5bb5a..1e63956b 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 1c77cbaf..b638dd51 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 13eb7390..86d13bdb 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index aae9673a..9b982ddd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,6 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index fa8f9abf..1e7cafa2 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::rc::Rc;
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index d9e596ce..8472bb55 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index eaf4ca05..9c966497 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -1,6 +1,7 @@
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
+use candle_nn::Module;
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 01266e63..b955205f 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -4,6 +4,7 @@ use crate::nn::{
use crate::{encodec_model, t5_model};
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
+use candle_nn::Module;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index ef65df39..613b4112 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -4,6 +4,7 @@
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::{DType, Tensor, D};
+use candle_nn::Module;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a67a5a03..7da7cf1c 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -7,7 +7,7 @@ use tokenizers::Tokenizer;
use candle::quantized::ggml_file::Content;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::Embedding;
+use candle_nn::{Embedding, Module};
use candle_transformers::generation::LogitsProcessor;
const MAX_SEQ_LEN: usize = 4096;
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index d981b6f4..255ce857 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -1,6 +1,7 @@
//! Attention Based Building Blocks
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
+use candle_nn::Module;
#[derive(Debug)]
struct GeGlu {
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 29591f55..2d450d99 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -7,6 +7,7 @@
//! https://github.com/openai/CLIP
use candle::{DType, Device, Result, Tensor, D};
use candle_nn as nn;
+use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
index c94f24f8..88a153ae 100644
--- a/candle-examples/examples/stable-diffusion/embeddings.rs
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -1,6 +1,7 @@
#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;
+use candle_nn::Module;
#[derive(Debug)]
pub struct TimestepEmbedding {
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
index b6a628be..94f436c8 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -8,6 +8,7 @@
use crate::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
+use candle_nn::Module;
/// Configuration for a ResNet block.
#[derive(Debug, Clone, Copy)]
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 0fa2f31a..6f568113 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -7,6 +7,7 @@ use crate::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
+use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 308e577d..5602a9ad 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -1,4 +1,5 @@
use candle::{Device, Result, Tensor};
+use candle_nn::Module;
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
index 7a10d932..abba39fa 100644
--- a/candle-examples/examples/stable-diffusion/vae.rs
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -10,6 +10,7 @@ use crate::unet_2d_blocks::{
};
use candle::{Result, Tensor};
use candle_nn as nn;
+use candle_nn::Module;
#[derive(Debug, Clone)]
struct EncoderConfig {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 00d5707e..553bd93b 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,5 +1,5 @@
use candle::{Device, IndexOp, Result, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
diff --git a/candle-nn/examples/basic_optimizer.rs b/candle-nn/examples/basic_optimizer.rs
index 3c5665e8..cd5824dd 100644
--- a/candle-nn/examples/basic_optimizer.rs
+++ b/candle-nn/examples/basic_optimizer.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
fn gen_data() -> Result<(Tensor, Tensor)> {
// Generate some sample linear data.
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 9554e68a..0db3edc9 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -7,8 +7,8 @@ pub enum Activation {
Elu(f64),
}
-impl Activation {
- pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+impl super::Module for Activation {
+ fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 67a80417..5057d2ef 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -35,8 +35,10 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv1d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
@@ -84,8 +86,10 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv2d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index f4ba88e7..918c1805 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -18,8 +18,10 @@ impl Embedding {
pub fn embeddings(&self) -> &Tensor {
&self.embeddings
}
+}
- pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+impl crate::Module for Embedding {
+ fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index ac77db4b..e85c4379 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -34,8 +34,10 @@ impl GroupNorm {
num_groups,
})
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for GroupNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_shape = x.dims();
if x_shape.len() <= 2 {
candle::bail!("input rank for GroupNorm should be at least 3");
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 17cdef3d..61fbe2d2 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -8,7 +8,7 @@
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
-//! use candle_nn::LayerNorm;
+//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
@@ -95,8 +95,10 @@ impl LayerNorm {
eps,
}
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for LayerNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
@@ -152,8 +154,10 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
+}
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for RmsNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
}
}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index c04e8ff4..da63d592 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,5 +1,5 @@
-// For now this crate shares its error type with candle-core. We may introduce some separate
-// error type if needed or add some specialized cases on the candle-core side.
+use candle::{Result, Tensor};
+
pub mod activation;
pub mod conv;
pub mod embedding;
@@ -21,3 +21,20 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};
+
+// A simple trait defining a module with forward method using a single argument.
+pub trait Module: std::fmt::Debug {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+
+ /// Change the module to use training mode vs eval mode.
+ ///
+ /// The default implementation does nothing as this is only used for a couple modules such as
+ /// dropout or batch-normalization.
+ fn set_training(&mut self, _training: bool) {}
+}
+
+impl Module for candle::quantized::QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward(xs)
+ }
+}
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index a0bd925a..a7bd1028 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -7,7 +7,7 @@
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
-//! use candle_nn::Linear;
+//! use candle_nn::{Linear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
@@ -29,8 +29,10 @@ impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
+}
- pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl super::Module for Linear {
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs
index f3ef2455..3a906c9d 100644
--- a/candle-nn/tests/group_norm.rs
+++ b/candle-nn/tests/group_norm.rs
@@ -23,7 +23,7 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::GroupNorm;
+use candle_nn::{GroupNorm, Module};
mod test_utils;
use test_utils::to_vec3_round;
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs
index 3a300cec..849b4987 100644
--- a/candle-nn/tests/layer_norm.rs
+++ b/candle-nn/tests/layer_norm.rs
@@ -3,7 +3,7 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::LayerNorm;
+use candle_nn::{LayerNorm, Module};
#[test]
fn layer_norm() -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 1327ae91..899745d4 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -6,7 +6,7 @@ use test_utils::{to_vec0_round, to_vec2_round};
use anyhow::Result;
use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
#[test]
fn sgd_optim() -> Result<()> {
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 2c867793..3fedb1d3 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 9f3d92f5..3470c3d6 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -3,7 +3,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation: