summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/t5.rs50
1 files changed, 43 insertions, 7 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index b1f3a3aa..fd2720d3 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -18,12 +18,15 @@ fn default_use_cache() -> bool {
true
}
+fn default_tie_word_embeddings() -> bool {
+ true
+}
+
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
- let result = Tensor::from_slice(&mask, (size, size), device)?;
- Ok(result)
+ Tensor::from_slice(&mask, (size, size), device)
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -50,6 +53,8 @@ pub struct Config {
initializer_factor: f64,
#[serde(default)]
feed_forward_proj: Activation,
+ #[serde(default = "default_tie_word_embeddings")]
+ tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
is_decoder: bool,
is_encoder_decoder: bool,
@@ -75,6 +80,7 @@ impl Default for Config {
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
@@ -94,6 +100,7 @@ impl Config {
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
@@ -611,6 +618,9 @@ impl T5EncoderModel {
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,
+ d_model: usize,
+ tie_word_embeddings: bool,
+ lm_head: Option<Linear>,
shared: Arc<Embedding>,
device: Device,
}
@@ -618,6 +628,7 @@ pub struct T5ForConditionalGeneration {
impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
+ let d_model = cfg.d_model;
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
@@ -633,9 +644,23 @@ impl T5ForConditionalGeneration {
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+ let tie_word_embeddings = cfg.tie_word_embeddings;
+ let lm_head = if tie_word_embeddings {
+ None
+ } else {
+ Some(linear_no_bias(
+ cfg.d_model,
+ cfg.vocab_size,
+ vb.pp("lm_head"),
+ )?)
+ };
+
Ok(Self {
encoder,
decoder,
+ d_model,
+ tie_word_embeddings,
+ lm_head,
shared,
device: vb.device().clone(),
})
@@ -653,12 +678,23 @@ impl T5ForConditionalGeneration {
let decoder_output = self
.decoder
.forward(decoder_input_ids, Some(encoder_output))?;
- let sequence_output = decoder_output
+
+ let scaling_factor = if self.tie_word_embeddings {
+ // Rescale output before projecting on vocab
+ // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ (self.d_model as f64).sqrt()
+ } else {
+ 1.0
+ };
+ let sequence_output = ((decoder_output
.narrow(1, decoder_output.dim(1)? - 1, 1)?
- .squeeze(1)?;
- // TODO: check cfg.tie_word_embeddings to load from model instead.
- let lm_head_weights = self.shared.embeddings().t()?;
- let output = sequence_output.matmul(&lm_head_weights)?;
+ .squeeze(1)?)
+ * scaling_factor)?;
+ let output = match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ };
+
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
Ok(output)
}