summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/falcon.rs
diff options
context:
space:
mode:
author: Jani Monoses <jani.monoses@gmail.com> 2024-01-10 22:36:27 +0200
committer: GitHub <noreply@github.com> 2024-01-10 21:36:27 +0100
commit: 63944714f267bd3824c548ffcaaaef5e29c4066e (patch)
tree: e75dc6cc7673a0208ffed3d643b95cce6c9e15f1 /candle-transformers/src/models/falcon.rs
parent: d3bdd788cfdcf49b6ea539b77647b82a0b979db0 (diff)
download: candle-63944714f267bd3824c548ffcaaaef5e29c4066e.tar.gz
          candle-63944714f267bd3824c548ffcaaaef5e29c4066e.tar.bz2
          candle-63944714f267bd3824c548ffcaaaef5e29c4066e.zip
Use candle_nn::embedding instead of local copies in a few models. (#1562)
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r-- candle-transformers/src/models/falcon.rs | 7
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 6ede136a..ef5a92fc 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Debug)]
pub struct Config {