diff options
Diffstat (limited to 'candle-transformers/src/models/with_tracing.rs')
-rw-r--r-- | candle-transformers/src/models/with_tracing.rs | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 69654139..39258085 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -32,6 +32,14 @@ pub struct Linear { span: tracing::Span, } +impl Linear { + pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self { + let inner = candle_nn::Linear::new(weights, bias); + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Self { inner, span } + } +} + pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> { let inner = candle_nn::linear(d1, d2, vb)?; let span = tracing::span!(tracing::Level::TRACE, "linear"); |