summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_t5.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/quantized_t5.rs')
-rw-r--r--candle-transformers/src/models/quantized_t5.rs27
1 files changed, 1 insertions, 26 deletions
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 398e82a7..1426df39 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -2,38 +2,13 @@
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
use serde::Deserialize;
use std::sync::Arc;
-#[derive(Debug)]
-pub struct Embedding {
- inner: candle_nn::Embedding,
- span: tracing::Span,
-}
-
-impl Embedding {
- pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
- let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
- let inner = candle_nn::Embedding::new(embeddings, d2);
- let span = tracing::span!(tracing::Level::TRACE, "embedding");
- Ok(Self { inner, span })
- }
-
- pub fn embeddings(&self) -> &Tensor {
- self.inner.embeddings()
- }
-}
-
-impl Module for Embedding {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(xs)
- }
-}
-
fn default_relative_attention_max_distance() -> usize {
128
}