path: root/candle-transformers/src/models/chatglm.rs
author	Laurent Mazare <laurent.mazare@gmail.com>	2024-02-22 09:35:28 +0100
committer	GitHub <noreply@github.com>	2024-02-22 09:35:28 +0100
commit	c753f72c8552ba3e108bd3f1a04971e8abbf3012 (patch)
tree	dbd3f076b9c01811dd58ce6e30122d594b617b6f /candle-transformers/src/models/chatglm.rs
parent	8013b50829c4256d2a04b7b1acd3de90d9a95650 (diff)
Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.
* Fix the cuda tests.
Diffstat (limited to 'candle-transformers/src/models/chatglm.rs')
-rw-r--r--	candle-transformers/src/models/chatglm.rs | 10 +---------
1 file changed, 1 insertion(+), 9 deletions(-)
diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs
index 95466b34..0686b34e 100644
--- a/candle-transformers/src/models/chatglm.rs
+++ b/candle-transformers/src/models/chatglm.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::Linear;
+use crate::models::with_tracing::{linear_b as linear, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -51,14 +51,6 @@ impl Config {
     }
 }
 
-fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
-    if bias {
-        crate::models::with_tracing::linear(in_dim, out_dim, vb)
-    } else {
-        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
-    }
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     cache: Tensor,
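
The refactor replaces the per-model helper removed above with a shared with_tracing::linear_b, imported under the alias linear so existing call sites keep compiling unchanged. A minimal sketch of that shared helper, assuming it mirrors the removed function (the actual definition lives in candle-transformers/src/models/with_tracing.rs):

use candle::Result;
use candle_nn::VarBuilder;

// Sketch of with_tracing::linear_b, assumed to mirror the removed helper.
// Linear, linear and linear_no_bias are the tracing-aware items defined
// alongside it in with_tracing.rs.
pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    if bias {
        // Load both a weight and a bias tensor from the VarBuilder.
        linear(in_dim, out_dim, vb)
    } else {
        // Skip loading the bias tensor entirely.
        linear_no_bias(in_dim, out_dim, vb)
    }
}

Because of the `linear_b as linear` alias in the import, a call site such as `linear(in_dim, out_dim, bias_flag, vb)?` behaves exactly as before; the change only moves the bias dispatch out of chatglm.rs so other models (such as gemma, per this commit's subject) can share it.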