summaryrefslogtreecommitdiff
path: root/candle-nn/src/embedding.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/embedding.rs')
-rw-r--r--candle-nn/src/embedding.rs12
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index a0a853b0..050123be 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -28,3 +28,15 @@ impl Embedding {
Ok(values)
}
}
+
+pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get_or_init(
+ (in_size, out_size),
+ "weight",
+ crate::Init::Randn {
+ mean: 0.,
+ stdev: 1.,
+ },
+ )?;
+ Ok(Embedding::new(embeddings, out_size))
+}