summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/backend.rs2
-rw-r--r--candle-core/src/cpu_backend.rs48
-rw-r--r--candle-core/src/cuda_backend.rs4
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
-rw-r--r--candle-core/src/storage.rs17
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs10
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs14
7 files changed, 81 insertions, 18 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 345db0e5..307b56dc 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,6 +37,8 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4aa2f880..401a2c0e 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -633,6 +633,45 @@ impl Map1 for Affine {
}
}
+struct AvgPool2D((usize, usize), (usize, usize));
+
+impl Map1 for AvgPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ let scale = 1f64 / (k_h * k_w) as f64;
+ let scale = T::from_f64(scale);
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut sum = T::zero();
+ for m in 0..k_h {
+ for n in 0..k_w {
+ sum += src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ dst[h_idx * w_out + w_idx] = sum * scale;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@@ -1529,6 +1568,15 @@ impl BackendStorage for CpuStorage {
Affine(mul, add).map(self, layout)
}
+ fn avg_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ AvgPool2D(kernel_size, stride).map(self, layout)
+ }
+
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7b4b358d..e71ecfce 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1381,6 +1381,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 17d4a22e..2d5f955c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -119,6 +119,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
}
impl crate::backend::BackendDevice for CudaDevice {
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index cbca4fc4..47df689c 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -268,11 +268,20 @@ impl Storage {
pub(crate) fn avg_pool2d(
&self,
- _layout: &Layout,
- _kernel_size: (usize, usize),
- _stride: (usize, usize),
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
) -> Result<Self> {
- todo!()
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
}
pub(crate) fn upsample_nearest2d(
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
- token_embedding + position_embedding
+ token_embedding.broadcast_add(&position_embedding)
}
}
@@ -161,9 +161,9 @@ impl ClipAttention {
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
- let attn_weights =
- (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
- + causal_attention_mask)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
- .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2203b03a..d8327c0e 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -57,13 +57,9 @@ struct Args {
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,
- #[arg(
- long,
- value_name = "FILE",
- default_value = "data/bpe_simple_vocab_16e6.txt"
- )]
- /// The file specifying the vocabulary to used for tokenization.
- vocab_file: String,
+ #[arg(long, value_name = "FILE")]
+ /// The file specifying the tokenizer to used for tokenization.
+ tokenizer: String,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
@@ -165,7 +161,7 @@ fn run(args: Args) -> Result<()> {
height,
width,
n_steps,
- vocab_file,
+ tokenizer,
final_image,
sliced_attention_size,
num_samples,
@@ -184,7 +180,7 @@ fn run(args: Args) -> Result<()> {
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
println!("Running with prompt \"{prompt}\".");
let tokens = tokenizer
.encode(prompt, true)