diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 12 | ||||
-rw-r--r-- | candle-examples/examples/falcon/main.rs | 1 | ||||
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 19 | ||||
-rw-r--r-- | candle-examples/examples/llama/var_store.rs | 10 | ||||
-rw-r--r-- | candle-examples/examples/musicgen/nn.rs | 2 |
5 files changed, 3 insertions, 41 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index d0d600ee..d8f6921e 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - #[cfg(feature = "mkl")] extern crate intel_mkl_src; @@ -86,7 +84,7 @@ impl Default for Config { } impl Config { - fn all_mini_lm_l6_v2() -> Self { + fn _all_mini_lm_l6_v2() -> Self { // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json Self { vocab_size: 30522, @@ -121,6 +119,7 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { } struct Dropout { + #[allow(dead_code)] pr: f64, } @@ -156,8 +155,6 @@ struct BertEmbeddings { token_type_embeddings: Embedding, layer_norm: LayerNorm, dropout: Dropout, - position_ids: Tensor, - token_type_ids: Tensor, } impl BertEmbeddings { @@ -182,17 +179,12 @@ impl BertEmbeddings { config.layer_norm_eps, vb.pp("LayerNorm"), )?; - let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect(); - let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?; - let token_type_ids = position_ids.zeros_like()?; Ok(Self { word_embeddings, position_embeddings: Some(position_embeddings), token_type_embeddings, layer_norm, dropout: Dropout::new(config.hidden_dropout_prob), - position_ids, - token_type_ids, }) } diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index a59a0349..3cd1d1f8 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] // TODO: Add an offline mode. #[cfg(feature = "mkl")] diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index 631ff280..f97fe219 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -28,22 +28,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { Ok(LayerNorm::new(weight, bias, eps)) } -#[derive(Debug)] -struct Dropout { - pr: f64, -} - -impl Dropout { - fn new(pr: f64) -> Self { - Self { pr } - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - // TODO - Ok(x.clone()) - } -} - fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) @@ -345,7 +329,6 @@ impl FalconAttention { struct FalconMlp { dense_h_to_4h: Linear, dense_4h_to_h: Linear, - dropout: Dropout, } impl FalconMlp { @@ -354,11 +337,9 @@ impl FalconMlp { let b = cfg.bias; let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?; let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?; - let dropout = Dropout::new(cfg.hidden_dropout); Ok(Self { dense_h_to_4h, dense_4h_to_h, - dropout, }) } diff --git a/candle-examples/examples/llama/var_store.rs b/candle-examples/examples/llama/var_store.rs index 1a22bd89..bd1114a0 100644 --- a/candle-examples/examples/llama/var_store.rs +++ b/candle-examples/examples/llama/var_store.rs @@ -1,16 +1,8 @@ use super::*; -use candle::{DType, Device, Result, Shape, Tensor}; +use candle::{Device, Result, Tensor}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; -#[allow(dead_code)] -#[derive(Clone)] -struct NamedVar { - path: String, - dtype: DType, - shape: Shape, -} - #[derive(Clone)] pub struct VarBuilder { path: Vec<String>, diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs index 81643466..5c90dd4e 100644 --- a/candle-examples/examples/musicgen/nn.rs +++ b/candle-examples/examples/musicgen/nn.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use anyhow::Result; use candle::Tensor; |