summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/book-cd.yml2
-rw-r--r--Cargo.toml4
-rw-r--r--README.md6
-rw-r--r--candle-book/src/SUMMARY.md26
-rw-r--r--candle-book/src/cuda/README.md1
-rw-r--r--candle-book/src/cuda/porting.md1
-rw-r--r--candle-book/src/cuda/writing.md1
-rw-r--r--candle-book/src/error_manage.md50
-rw-r--r--candle-book/src/inference/README.md6
-rw-r--r--candle-book/src/inference/hub.md103
-rw-r--r--candle-book/src/training/serialization.md (renamed from candle-book/src/inference/serialization.md)0
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/examples/conv1d_benchmark.rs24
-rw-r--r--candle-core/src/backend.rs1
-rw-r--r--candle-core/src/backprop.rs2
-rw-r--r--candle-core/src/conv.rs7
-rw-r--r--candle-core/src/cpu_backend.rs116
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
-rw-r--r--candle-core/src/error.rs14
-rw-r--r--candle-core/src/npy.rs18
-rw-r--r--candle-core/src/op.rs7
-rw-r--r--candle-core/src/safetensors.rs26
-rw-r--r--candle-core/src/shape.rs27
-rw-r--r--candle-core/src/storage.rs18
-rw-r--r--candle-core/src/tensor.rs29
-rw-r--r--candle-core/tests/pool_tests.rs61
-rw-r--r--candle-core/tests/tensor_tests.rs11
-rw-r--r--candle-examples/Cargo.toml9
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs12
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs19
-rw-r--r--candle-examples/examples/whisper/main.rs37
-rw-r--r--candle-examples/examples/whisper/melfilters.bytes (renamed from candle-examples/examples/whisper/mel_filters.safetensors)bin64400 -> 64320 bytes
-rw-r--r--candle-examples/examples/whisper/model.rs98
-rw-r--r--candle-examples/src/lib.rs99
35 files changed, 684 insertions, 162 deletions
diff --git a/.github/workflows/book-cd.yml b/.github/workflows/book-cd.yml
index fc693a78..e8149e38 100644
--- a/.github/workflows/book-cd.yml
+++ b/.github/workflows/book-cd.yml
@@ -1,7 +1,5 @@
name: Deploy Rust book
on:
- # TODO put this back only when merging after this PR lands.
- pull_request:
push:
branches:
- main
diff --git a/Cargo.toml b/Cargo.toml
index 887be45d..850b13ef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,9 +30,10 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.5", package = "candle-gemm" }
+gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
+image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
@@ -40,6 +41,7 @@ memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
+rand_distr = "0.4.3"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
diff --git a/README.md b/README.md
index c8622b88..23908756 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ And then browse to
## Features
-- Simple syntax, looks and like PyTorch.
+- Simple syntax, looks and feels like PyTorch.
- CPU and Cuda backends, m1, f16, bf16.
- Enable serverless (CPU), small and fast deployments
- WASM support, run your models in a browser.
@@ -78,7 +78,7 @@ Cheatsheet:
| | Using PyTorch | Using Candle |
|------------|------------------------------------------|------------------------------------------------------------------|
-| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
+| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
@@ -120,7 +120,7 @@ Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors
- [dfdx](https://github.com/coreylowman/dfdx) is a formidable crate, with shapes being included
in types preventing a lot of headaches by getting compiler to complain about shape mismatch right off the bat
- However we found that some features still require nightly and writing code can be a bit dauting for non rust experts.
+ However we found that some features still require nightly and writing code can be a bit daunting for non rust experts.
We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each
other
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index ddd6e916..3432f66f 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -12,16 +12,16 @@
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- - [Serialization](inference/serialization.md)
- - [Advanced Cuda usage](inference/cuda/README.md)
- - [Writing a custom kernel](inference/cuda/writing.md)
- - [Porting a custom kernel](inference/cuda/porting.md)
-- [Error management](error_manage.md)
-- [Creating apps](apps/README.md)
- - [Creating a WASM app](apps/wasm.md)
- - [Creating a REST api webserver](apps/rest.md)
- - [Creating a desktop Tauri app](apps/dekstop.md)
-- [Training](training/README.md)
- - [MNIST](training/mnist.md)
- - [Fine-tuning](training/finetuning.md)
-- [Using MKL](advanced/mkl.md)
+- [Error management]()
+- [Advanced Cuda usage]()
+ - [Writing a custom kernel]()
+ - [Porting a custom kernel]()
+- [Using MKL]()
+- [Creating apps]()
+ - [Creating a WASM app]()
+ - [Creating a REST api webserver]()
+ - [Creating a desktop Tauri app]()
+- [Training]()
+ - [MNIST]()
+ - [Fine-tuning]()
+ - [Serialization]()
diff --git a/candle-book/src/cuda/README.md b/candle-book/src/cuda/README.md
new file mode 100644
index 00000000..68434cbf
--- /dev/null
+++ b/candle-book/src/cuda/README.md
@@ -0,0 +1 @@
+# Advanced Cuda usage
diff --git a/candle-book/src/cuda/porting.md b/candle-book/src/cuda/porting.md
new file mode 100644
index 00000000..e332146d
--- /dev/null
+++ b/candle-book/src/cuda/porting.md
@@ -0,0 +1 @@
+# Porting a custom kernel
diff --git a/candle-book/src/cuda/writing.md b/candle-book/src/cuda/writing.md
new file mode 100644
index 00000000..0fe1f3dc
--- /dev/null
+++ b/candle-book/src/cuda/writing.md
@@ -0,0 +1 @@
+# Writing a custom kernel
diff --git a/candle-book/src/error_manage.md b/candle-book/src/error_manage.md
index 042e191f..c1a16bd9 100644
--- a/candle-book/src/error_manage.md
+++ b/candle-book/src/error_manage.md
@@ -1 +1,51 @@
# Error management
+
+You might have seen in the code base a lot of `.unwrap()` or `?`.
+If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
+for more information.
+
+What's important to know though, is that if you want to know *where* a particular operation failed
+You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.
+
+Let's see on failing code:
+
+```rust,ignore
+let x = Tensor::zeros((1, 784), DType::F32, &device)?;
+let y = Tensor::zeros((1, 784), DType::F32, &device)?;
+let z = x.matmul(&y)?;
+```
+
+Will print at runtime:
+
+```bash
+Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
+```
+
+
+After adding `RUST_BACKTRACE=1`:
+
+
+```bash
+Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
+```
+
+Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+
+
+Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
+especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
+The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.
+
+## Cuda error management
+
+When running a model on Cuda, you might get a stacktrace not really representing the error.
+The reason is that CUDA is async by nature, and therefore the error might be caught while you were sending totally different kernels.
+
+One way to avoid this is to use `CUDA_LAUNCH_BLOCKING=1` as an environment variable. This will force every kernel to be launched sequentially.
+You might still however see the error happening on other kernels as the faulty kernel might exit without an error but spoiling some pointer for which the error will happen when dropping the `CudaSlice` only.
+
+
+If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html)
+This tool is like `valgrind` but for cuda. It will help locate the errors in the kernels.
+
+
diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/README.md
index c82f85e1..1b75a310 100644
--- a/candle-book/src/inference/README.md
+++ b/candle-book/src/inference/README.md
@@ -1 +1,7 @@
# Running a model
+
+
+In order to run an existing model, you will need to download and use existing weights.
+Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
+
+Let's get started by running an old model : `bert-base-uncased`.
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index 6242c070..b924b76d 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -1 +1,104 @@
# Using the hub
+
+Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
+
+```bash
+cargo add hf-hub
+```
+
+Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
+
+
+```rust
+# extern crate candle_core;
+# extern crate hf_hub;
+use hf_hub::api::sync::Api;
+use candle_core::Device;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights = repo.get("model.safetensors").unwrap();
+
+let weights = candle_core::safetensors::load(weights, &Device::Cpu);
+```
+
+We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
+
+You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true)
+
+
+## Using async
+
+`hf-hub` comes with an async API.
+
+```bash
+cargo add hf-hub --features tokio
+```
+
+```rust,ignore
+# This is tested directly in examples crate because it needs external dependencies unfortunately:
+# See [this](https://github.com/rust-lang/mdBook/issues/706)
+{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+```
+
+
+## Using in a real model.
+
+Now that we have our weights, we can use them in our bert architecture:
+
+```rust
+# extern crate candle_core;
+# extern crate candle_nn;
+# extern crate hf_hub;
+# use hf_hub::api::sync::Api;
+#
+# let api = Api::new().unwrap();
+# let repo = api.model("bert-base-uncased".to_string());
+#
+# let weights = repo.get("model.safetensors").unwrap();
+use candle_core::{Device, Tensor, DType};
+use candle_nn::Linear;
+
+let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
+
+let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
+let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
+
+let linear = Linear::new(weight.clone(), Some(bias.clone()));
+
+let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
+let output = linear.forward(&input_ids).unwrap();
+```
+
+For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
+
+## Memory mapping
+
+For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/)
+
+**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
+and will definitely be slower on network mounted disk, because it will issue more read calls.
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+```
+
+**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
+In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.
+
+
+## Tensor Parallel Sharding
+
+When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need.
+
+For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.
+
+```bash
+cargo add safetensors
+```
+
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+```
diff --git a/candle-book/src/inference/serialization.md b/candle-book/src/training/serialization.md
index 0dfc62d3..0dfc62d3 100644
--- a/candle-book/src/inference/serialization.md
+++ b/candle-book/src/training/serialization.md
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index af77a0e0..7411592e 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -22,6 +22,7 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
+rand_distr = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
diff --git a/candle-core/examples/conv1d_benchmark.rs b/candle-core/examples/conv1d_benchmark.rs
new file mode 100644
index 00000000..52fae5e8
--- /dev/null
+++ b/candle-core/examples/conv1d_benchmark.rs
@@ -0,0 +1,24 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle_core::{Device, Tensor};
+
+pub const N_ITERS: usize = 5;
+
+fn main() -> Result<()> {
+ let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
+ let res = inp.conv1d(&w, 0, 1);
+ println!("{res:?}");
+ let start = std::time::Instant::now();
+ for i in 0..N_ITERS {
+ let res = inp.conv1d(&w, 0, 1);
+ println!("{i} {res:?}");
+ }
+ println!("{:?}", start.elapsed() / N_ITERS as u32);
+ Ok(())
+}
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index a8e5ac52..4c31ca6f 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -46,6 +46,7 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 0eab508e..2a60fe30 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -88,6 +88,7 @@ impl Tensor {
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
+ | Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -172,6 +173,7 @@ impl Tensor {
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 30799459..e3fea861 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
- pub(crate) b_size: Option<usize>,
+ pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@@ -19,10 +19,7 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
- match self.b_size {
- None => vec![self.c_out, l_out],
- Some(n) => vec![n, self.c_out, l_out],
- }
+ vec![self.b_size, self.c_out, l_out]
}
}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0ec19559..d4f5fcdc 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
@@ -672,6 +674,48 @@ impl Map1 for AvgPool2D {
}
}
+struct MaxPool2D((usize, usize), (usize, usize));
+
+impl Map1 for MaxPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut largest =
+ src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
+ if largest < src[src_index + m * stride_h + n * stride_w] {
+ largest = src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ }
+ dst[h_idx * w_out + w_idx] = largest;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -990,19 +1034,14 @@ impl<'a> Map2 for Conv1D<'a> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
- let inp_stride = inp_l.stride();
- let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
- (inp_stride[0], &inp_stride[1..])
- } else {
- (0, inp_stride) // This value never gets used anyway
- };
- let k_stride = k_l.stride();
+ let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+ let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
- let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
- for b_idx in 0..p.b_size.unwrap_or(1) {
- let inp_idx = b_idx * inp_stride0;
+ for b_idx in 0..p.b_size {
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
@@ -1014,11 +1053,8 @@ impl<'a> Map2 for Conv1D<'a> {
.saturating_sub(p.padding)
.min(p.l_in - 1);
for src_c_idx in 0..p.c_in {
- let inp_idx =
- inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset * k_stride[2];
+ let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
+ let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1043,14 +1079,14 @@ impl<'a> Map2 for Conv2D<'a> {
) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
- let inp_stride = inp_l.stride();
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
- let k_stride = k_l.stride();
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_stride[0];
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * out_h * out_w;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
@@ -1069,13 +1105,13 @@ impl<'a> Map2 for Conv2D<'a> {
.min(p.i_w - 1);
for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx
- + src_c_idx * inp_stride[1]
- + src_h * inp_stride[2]
- + src_w * inp_stride[3];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset_h * k_stride[2]
- + offset_w * k_stride[3];
+ + src_c_idx * inp_s1
+ + src_h * inp_s2
+ + src_w * inp_s3;
+ let k_idx = dst_c_idx * k_s0
+ + src_c_idx * k_s1
+ + offset_h * k_s2
+ + offset_w * k_s3;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1670,6 +1706,15 @@ impl BackendStorage for CpuStorage {
AvgPool2D(kernel_size, stride).map(self, layout)
}
+ fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ MaxPool2D(kernel_size, stride).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@@ -2025,35 +2070,36 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = bf16::from_f64(std);
- let mean = bf16::from_f64(mean);
+ let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = f16::from_f64(std);
- let mean = f16::from_f64(mean);
+ let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let std = std as f32;
- let mean = mean as f32;
+ let normal =
+ rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
+ let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 727ea073..a7f63353 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
- let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
@@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage {
todo!()
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index ae4dd09f..870a87cd 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ /// Adding path information to an error.
+ #[error("path: {path:?} {inner}")]
+ WithPath {
+ inner: Box<Self>,
+ path: std::path::PathBuf,
+ },
+
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
},
}
}
+
+ pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+ Self::WithPath {
+ inner: Box::new(self),
+ path: p.as_ref().to_path_buf(),
+ }
+ }
}
#[macro_export]
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let elem_count = self.elem_count();
+ let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+ let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+ let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
- for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+ for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
- for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+ for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
- for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+ for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
- let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
- f.write_all(&data)?;
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
}
}
Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
- // re-create a zip reader each time.
+ // re-create a zip reader for each tensor.
}
impl NpzTensors {
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index aea8b733..f99d8adc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -93,6 +93,13 @@ pub enum Op {
kernel_size: (usize, usize),
stride: (usize, usize),
},
+
+ MaxPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 1880a041..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
- let st = safetensors::SafeTensors::deserialize(&data)?;
+ load_buffer(&data[..], device)
+}
+
+pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
+ let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
@@ -253,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+ path: std::path::PathBuf,
+ inner: memmap2::Mmap,
+}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -263,13 +270,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
- let file = std::fs::File::open(p)?;
- let mmap = memmap2::MmapOptions::new().map(&file)?;
- Ok(Self(mmap))
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let inner = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok(Self {
+ inner,
+ path: p.to_path_buf(),
+ })
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
- let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ let st = safetensors::SafeTensors::deserialize(&self.inner)
+ .map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index a5e21aad..83d11c09 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
+ pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
+ if dims.len() != $cnt {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: $cnt,
+ got: dims.len(),
+ shape: Shape::from(dims),
+ }
+ .bt())
+ } else {
+ Ok($dims(dims))
+ }
+ }
+
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
- if self.0.len() != $cnt {
- Err(Error::UnexpectedNumberOfDims {
- expected: $cnt,
- got: self.0.len(),
- shape: self.clone(),
- }
- .bt())
- } else {
- Ok($dims(&self.0))
- }
+ $fn_name(self.0.as_slice())
}
}
+
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 3ed38e6a..791b65dd 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -311,6 +311,24 @@ impl Storage {
}
}
+ pub(crate) fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index adba7376..c14a4e39 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -773,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = match *self.dims() {
- [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
- [c_in, l_in] => (None, c_in, l_in),
- _ => Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "input rank is not 2 or 3",
- }
- .bt())?,
- };
+ let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
@@ -872,6 +861,22 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
+ pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let (n, c, h, w) = self.dims4()?;
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
+ let h_out = (h - kernel_size.0) / stride.0 + 1;
+ let w_out = (w - kernel_size.1) / stride.1 + 1;
+ let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ });
+ let storage = self
+ .storage()
+ .max_pool2d(self.layout(), kernel_size, stride)?;
+ Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+ }
+
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
new file mode 100644
index 00000000..c8ddef97
--- /dev/null
+++ b/candle-core/tests/pool_tests.rs
@@ -0,0 +1,61 @@
+mod test_utils;
+use candle_core::{Device, Tensor};
+
+// https://github.com/huggingface/candle/issues/364
+#[test]
+fn avg_pool2d() -> anyhow::Result<()> {
+ let data: Vec<f32> = vec![
+ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
+ Ok(())
+}
+
+#[test]
+fn max_pool2d() -> anyhow::Result<()> {
+ let data: Vec<f32> = vec![
+ 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+
+ let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
+ Ok(())
+}
+
+/* This test corresponds to the following PyTorch script.
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 2, 4, 4))
+print(t.flatten())
+res = torch.nn.functional.avg_pool2d(t, 2)
+print(res)
+*/
+#[test]
+fn avg_pool2d_pytorch() -> anyhow::Result<()> {
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+ 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+ 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+ 0.2477, 1.3127,
+ ],
+ &Device::Cpu,
+ )?
+ .reshape((1, 2, 4, 4))?;
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
+ assert_eq!(
+ test_utils::to_vec3_round(pool, 4)?,
+ [
+ [[-1.1926, -0.0395], [0.2688, 0.1871]],
+ [[0.1835, -0.1606], [0.6249, 0.3217]]
+ ]
+ );
+ let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
+ assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
+ Ok(())
+}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 599c2665..0b77f1a5 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -869,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+
+// There was originally a bug on the CPU implementation for randn
+// https://github.com/huggingface/candle/issues/381
+#[test]
+fn randn_hasneg() -> Result<()> {
+ let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
+ if t.iter().all(|&v| v >= 0.) {
+ candle_core::bail!("all values in tensors are non-negative")
+ }
+ Ok(())
+}
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index f3a4e325..54eb0be6 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -23,12 +23,13 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
+image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
+hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
-hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
@@ -36,6 +37,8 @@ tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
+# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
+tokio = "1.29.1"
[build-dependencies]
anyhow = { workspace = true }
@@ -51,3 +54,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
+
+[[example]]
+name = "stable-diffusion"
+required-features = ["image"]
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 8bb3c56d..8ce0c234 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -245,10 +245,10 @@ fn run(args: Args) -> Result<()> {
if args.intermediary_images {
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
- let _image = (image * 255.)?.to_dtype(DType::U8);
- let _image_filename =
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
- // TODO: save igame
+ crate::utils::save_image(&image, image_filename)?
}
}
@@ -260,9 +260,9 @@ fn run(args: Args) -> Result<()> {
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
- let _image = (image * 255.)?.to_dtype(DType::U8);
- let _image_filename = output_filename(&final_image, idx + 1, num_samples, None);
- // TODO: save image.
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+ crate::utils::save_image(&image, image_filename)?
}
Ok(())
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 0c95cfef..ef4dd956 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -10,3 +10,22 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
+
+/// Saves an image to disk using the image crate, this expects an input with shape
+/// (c, width, height).
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, width, height) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, width, height)")
+ }
+ let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 82c45348..dfe7a27f 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
// https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO:
// - kv-cache support?
@@ -10,7 +9,7 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{safetensors::Load, DType, Device, Tensor};
+use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
-const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
-const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
@@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
-const NO_TIMESTAMP_TOKEN: u32 = 50362;
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
@@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
+#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@@ -66,6 +62,7 @@ struct DecodingResult {
compression_ratio: f64,
}
+#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
@@ -244,16 +241,24 @@ struct Args {
#[arg(long, default_value_t = 299792458)]
seed: u64,
- /// The mel filters in safetensors format.
- #[arg(
- long,
- default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
- )]
- filters: String,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
}
fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());
@@ -301,11 +306,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
- let mel_filters = mel_filters.deserialize()?;
- let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
- println!("loaded mel filters {:?}", mel_filters.shape());
- let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let mel_bytes = include_bytes!("melfilters.bytes");
+ let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
+ <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/melfilters.bytes
index 98f3af44..0874829e 100644
--- a/candle-examples/examples/whisper/mel_filters.safetensors
+++ b/candle-examples/examples/whisper/melfilters.bytes
Binary files differ
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 4d80c0c8..7015199d 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,8 +1,5 @@
-// We use anyhow rather than candle errors as it provides better support for getting the backtrace
-// back when using RUST_LIB_BACKTRACE=1.
-use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle::{Device, Result, Tensor};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -22,6 +19,7 @@ pub struct Config {
}
impl Config {
+ #[allow(dead_code)]
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
@@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
+//
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = vb.get(size2, "bias")?;
- Ok(Linear::new(weight, Some(bias)))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn conv1d(
@@ -66,32 +80,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
-fn conv1d_no_bias(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- config: Conv1dConfig,
- vb: VarBuilder,
-) -> Result<Conv1d> {
- let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
- Ok(Conv1d::new(weight, None, config))
-}
-
-struct Dropout {
- pr: f64,
-}
-
-impl Dropout {
- fn new(pr: f64) -> Self {
- Self { pr }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- // TODO
- Ok(x.clone())
- }
-}
-
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
@@ -105,10 +93,12 @@ struct MultiHeadAttention {
value: Linear,
out: Linear,
n_head: usize,
+ span: tracing::Span,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -119,10 +109,12 @@ impl MultiHeadAttention {
value,
out,
n_head,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
@@ -134,7 +126,7 @@ impl MultiHeadAttention {
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
- Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+ x.reshape(target_dims)?.transpose(1, 2)
}
fn qkv_attention(
@@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
+ span: tracing::Span,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
@@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
mlp_linear1,
mlp_linear2,
mlp_ln,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
@@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
- Ok((x + mlp)?)
+ x + mlp
}
}
@@ -234,10 +230,16 @@ pub struct AudioEncoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
+ span: tracing::Span,
+ conv1_span: tracing::Span,
+ conv2_span: tracing::Span,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+ let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+ let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@@ -264,11 +266,22 @@ impl AudioEncoder {
positional_embedding,
blocks,
ln_post,
+ conv1_span,
+ conv2_span,
+ span,
})
}
+
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = self.conv1.forward(x)?.gelu()?;
- let x = self.conv2.forward(&x)?.gelu()?;
+ let _enter = self.span.enter();
+ let x = {
+ let _enter = self.conv1_span.enter();
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _enter = self.conv2_span.enter();
+ self.conv2.forward(&x)?.gelu()?
+ };
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
@@ -288,10 +301,12 @@ pub struct TextDecoder {
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
+ span: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@@ -314,10 +329,12 @@ impl TextDecoder {
blocks,
ln,
mask,
+ span,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
@@ -354,6 +371,7 @@ impl Whisper {
})
}
+ #[allow(dead_code)]
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 285aee04..2b6009b4 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}
+
+#[cfg(test)]
+mod tests {
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+ .tensor("bert.encoder.layer.0.attention.self.query.weight")
+ .unwrap();
+
+// We're going to load shard with rank 1, within a world_size of 4
+// We're going to split along dimension 0 doing VIEW[start..stop, :]
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+ panic!("The dimension is not divisble by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
+
+// Everything is expressed in tensor dimension
+// bytes offsets is handled automatically for safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+}