diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-18 09:38:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-18 09:38:22 +0100 |
commit | c78ce765016392673805ed8dfafb4ae1a7b6c26f (patch) | |
tree | df7bab84b80da4754aef94f0dd73503c33bc6e44 /candle-wasm-examples | |
parent | 13401df4d141bf568a2c2056411d62060707e79b (diff) | |
download | candle-c78ce765016392673805ed8dfafb4ae1a7b6c26f.tar.gz candle-c78ce765016392673805ed8dfafb4ae1a7b6c26f.tar.bz2 candle-c78ce765016392673805ed8dfafb4ae1a7b6c26f.zip |
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.
* Use the module trait.
* Implement module for qmatmul.
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 2 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/model.rs | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 2c867793..3fedb1d3 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 9f3d92f5..3470c3d6 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -3,7 +3,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 // The names in comments correspond to the original implementation: