author     Laurent Mazare <laurent.mazare@gmail.com>   2023-10-28 17:51:19 +0200
committer  GitHub <noreply@github.com>                 2023-10-28 16:51:19 +0100
commit     95a857cf57c56a34ecdaae5372f2a13ebd900001 (patch)
tree       9b0bac74758528addfdd27db331d3dcbae20f3ac
parent     612f5b81561150ca6651368c245ac2065c04159a (diff)
Move the llama2-c model in transformers. (#1205)
-rw-r--r--  candle-examples/examples/llama2-c/main.rs                                                                       | 6
-rw-r--r--  candle-transformers/Cargo.toml                                                                                  | 1
-rw-r--r--  candle-transformers/src/models/llama2_c.rs (renamed from candle-examples/examples/llama2-c/model.rs)            | 0
-rw-r--r--  candle-transformers/src/models/llama2_c_weights.rs (renamed from candle-examples/examples/llama2-c/weights.rs)  | 5
-rw-r--r--  candle-transformers/src/models/mod.rs                                                                           | 3
-rw-r--r--  candle-transformers/src/models/quantized_llama2_c.rs (renamed from candle-examples/examples/llama2-c/qmodel.rs) | 6
6 files changed, 12 insertions(+), 9 deletions(-)
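
Note (not part of the commit): this change turns the llama2.c example modules into regular candle-transformers modules, and the example binary switches from local mod declarations to use-aliases, as the main.rs hunk in the full diff below shows; byteorder is added to candle-transformers because llama2_c_weights.rs reads the llama2.c binary checkpoint format via ReadBytesExt. A minimal sketch of how downstream code could reach the relocated modules, assuming the public item names (Config, Cache, Llama, TransformerWeights) are unchanged from the pre-move example files, which this diff does not show:

    // Sketch only: the module paths come from this commit; the item names are
    // assumed to match the pre-move example modules (not visible in this diff).
    use candle_transformers::models::llama2_c::{Cache, Config, Llama};
    use candle_transformers::models::llama2_c_weights::TransformerWeights;
    use candle_transformers::models::quantized_llama2_c as qmodel;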
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 77dbc677..a3f01ae2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -6,10 +6,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod model;
-mod qmodel;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
 mod training;
-mod weights;
 
 use clap::{Parser, Subcommand};
 use anyhow::{Error as E, Result};
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 5af7e55d..e7290be6 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -11,6 +11,7 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
+byteorder = { workspace = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-transformers/src/models/llama2_c.rs
index 07a6e2f2..07a6e2f2 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-transformers/src/models/llama2_c.rs
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-transformers/src/models/llama2_c_weights.rs
index b78418ce..e5a8bb88 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-transformers/src/models/llama2_c_weights.rs
@@ -1,9 +1,8 @@
-use anyhow::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Shape, Tensor};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
 use candle_nn::VarBuilder;
 
-use crate::model::Config;
+use super::llama2_c::Config;
 
 pub struct TransformerWeights {
     // token embedding table
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index f722e93b..c59bd880 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -8,6 +8,8 @@ pub mod efficientnet;
 pub mod falcon;
 pub mod jina_bert;
 pub mod llama;
+pub mod llama2_c;
+pub mod llama2_c_weights;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
@@ -15,6 +17,7 @@ pub mod persimmon;
 pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
+pub mod quantized_llama2_c;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-transformers/src/models/quantized_llama2_c.rs
index 07db146e..68ebee0d 100644
--- a/candle-examples/examples/llama2-c/qmodel.rs
+++ b/candle-transformers/src/models/quantized_llama2_c.rs
@@ -1,7 +1,7 @@
-use super::model::{Cache, Config};
+use super::llama2_c::{Cache, Config};
+use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, IndexOp, Module, Result, Tensor, D};
-use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
-pub use candle_transformers::quantized_var_builder::VarBuilder;
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
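
Note (not part of the commit): the trailing context of the quantized_llama2_c.rs hunk shows the crate's silu helper, which computes silu(x) = x / (1 + exp(-x)) = x * sigmoid(x). A hypothetical scalar check of that formula, using plain f32 rather than candle tensors:

    // Hypothetical scalar version of the silu formula from the hunk above:
    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
    fn silu_scalar(x: f32) -> f32 {
        x / ((-x).exp() + 1.0)
    }

    fn main() {
        assert_eq!(silu_scalar(0.0), 0.0); // 0 / (1 + 1) = 0
        println!("silu(1.0) = {:.3}", silu_scalar(1.0)); // approx. 0.731
    }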