summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-10-28 17:51:19 +0200
committer: GitHub <noreply@github.com> 2023-10-28 16:51:19 +0100
commit: 95a857cf57c56a34ecdaae5372f2a13ebd900001 (patch)
tree: 9b0bac74758528addfdd27db331d3dcbae20f3ac
parent: 612f5b81561150ca6651368c245ac2065c04159a (diff)
download: candle-95a857cf57c56a34ecdaae5372f2a13ebd900001.tar.gz
candle-95a857cf57c56a34ecdaae5372f2a13ebd900001.tar.bz2
candle-95a857cf57c56a34ecdaae5372f2a13ebd900001.zip
Move the llama2-c model in transformers. (#1205)
-rw-r--r-- candle-examples/examples/llama2-c/main.rs | 6
-rw-r--r-- candle-transformers/Cargo.toml | 1
-rw-r--r-- candle-transformers/src/models/llama2_c.rs (renamed from candle-examples/examples/llama2-c/model.rs) | 0
-rw-r--r-- candle-transformers/src/models/llama2_c_weights.rs (renamed from candle-examples/examples/llama2-c/weights.rs) | 5
-rw-r--r-- candle-transformers/src/models/mod.rs | 3
-rw-r--r-- candle-transformers/src/models/quantized_llama2_c.rs (renamed from candle-examples/examples/llama2-c/qmodel.rs) | 6
6 files changed, 12 insertions, 9 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 77dbc677..a3f01ae2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -6,10 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-mod model;
-mod qmodel;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
-mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 5af7e55d..e7290be6 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -11,6 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
+byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-transformers/src/models/llama2_c.rs
index 07a6e2f2..07a6e2f2 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-transformers/src/models/llama2_c.rs
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-transformers/src/models/llama2_c_weights.rs
index b78418ce..e5a8bb88 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-transformers/src/models/llama2_c_weights.rs
@@ -1,9 +1,8 @@
-use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Shape, Tensor};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;
-use crate::model::Config;
+use super::llama2_c::Config;
pub struct TransformerWeights {
// token embedding table
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index f722e93b..c59bd880 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -8,6 +8,8 @@ pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
pub mod llama;
+pub mod llama2_c;
+pub mod llama2_c_weights;
pub mod mistral;
pub mod mixformer;
pub mod mpt;
@@ -15,6 +17,7 @@ pub mod persimmon;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
+pub mod quantized_llama2_c;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-transformers/src/models/quantized_llama2_c.rs
index 07db146e..68ebee0d 100644
--- a/candle-examples/examples/llama2-c/qmodel.rs
+++ b/candle-transformers/src/models/quantized_llama2_c.rs
@@ -1,7 +1,7 @@
-use super::model::{Cache, Config};
+use super::llama2_c::{Cache, Config};
+use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Module, Result, Tensor, D};
-use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
-pub use candle_transformers::quantized_var_builder::VarBuilder;
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?