diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-01-10 22:36:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-10 21:36:27 +0100 |
commit | 63944714f267bd3824c548ffcaaaef5e29c4066e (patch) | |
tree | e75dc6cc7673a0208ffed3d643b95cce6c9e15f1 /candle-transformers/src/models/falcon.rs | |
parent | d3bdd788cfdcf49b6ea539b77647b82a0b979db0 (diff) | |
download | candle-63944714f267bd3824c548ffcaaaef5e29c4066e.tar.gz candle-63944714f267bd3824c548ffcaaaef5e29c4066e.tar.bz2 candle-63944714f267bd3824c548ffcaaaef5e29c4066e.zip |
Use candle_nn::embedding instead of local copies in a few models. (#1562)
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r-- | candle-transformers/src/models/falcon.rs | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 6ede136a..ef5a92fc 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; @@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { Ok(LayerNorm::new(weight, bias, eps)) } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py #[derive(Debug)] pub struct Config { |