summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/with_tracing.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/with_tracing.rs')
-rw-r--r--candle-transformers/src/models/with_tracing.rs8
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 69654139..39258085 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -32,6 +32,14 @@ pub struct Linear {
span: tracing::Span,
}
+impl Linear {
+ pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+ let inner = candle_nn::Linear::new(weights, bias);
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Self { inner, span }
+ }
+}
+
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");