summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/with_tracing.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/with_tracing.rs')
-rw-r--r--candle-transformers/src/models/with_tracing.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 39258085..a657011c 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -14,6 +14,13 @@ impl Embedding {
Ok(Self { inner, span })
}
+ pub fn from_weights(weights: Tensor) -> Result<Self> {
+ let (_in_size, out_size) = weights.dims2()?;
+ let inner = candle_nn::Embedding::new(weights, out_size);
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
pub fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}