diff options
Diffstat (limited to 'candle-transformers/src/models/with_tracing.rs')
-rw-r--r-- | candle-transformers/src/models/with_tracing.rs | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 39258085..a657011c 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -14,6 +14,13 @@ impl Embedding { Ok(Self { inner, span }) } + pub fn from_weights(weights: Tensor) -> Result<Self> { + let (_in_size, out_size) = weights.dims2()?; + let inner = candle_nn::Embedding::new(weights, out_size); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + pub fn embeddings(&self) -> &Tensor { self.inner.embeddings() } |