diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_t5.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs | 27 |
1 files changed, 1 insertions, 26 deletions
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 398e82a7..1426df39 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -2,38 +2,13 @@
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
 use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::Embedding;
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
 use serde::Deserialize;
 use std::sync::Arc;
 
-#[derive(Debug)]
-pub struct Embedding {
-    inner: candle_nn::Embedding,
-    span: tracing::Span,
-}
-
-impl Embedding {
-    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
-        let inner = candle_nn::Embedding::new(embeddings, d2);
-        let span = tracing::span!(tracing::Level::TRACE, "embedding");
-        Ok(Self { inner, span })
-    }
-
-    pub fn embeddings(&self) -> &Tensor {
-        self.inner.embeddings()
-    }
-}
-
-impl Module for Embedding {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
 fn default_relative_attention_max_distance() -> usize {
     128
 }