195 files changed, 14760 insertions, 1790 deletions
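Two of the API additions called out in the CHANGELOG diff below are tensor based indexing via `.i` (PR 842) and shape-with-holes reshapes (PR 770). As a rough editorial illustration of how they look in user code, assuming the candle 0.2.x API (this snippet is not part of the diff):

```rust
// Illustration only: `.i` indexing (PR 842) and shape-with-holes (PR 770),
// assuming the candle 0.2.x API.
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let t = Tensor::randn(0f32, 1.0, (4, 28, 28), &device)?;
    // Indexing through the `IndexOp` trait: keep the first two samples.
    let first_two = t.i(0..2)?;
    // Shape with holes: `()` is inferred from the remaining elements,
    // flattening each 28x28 image into a 784-long row.
    let flat = t.reshape(((), 784))?;
    println!("{:?} {:?}", first_two.shape(), flat.shape());
    Ok(())
}
```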
@@ -23,6 +23,7 @@ flamegraph.svg
 *.dylib
 *.so
 *.swp
+*.swo
 trace-*.json
 candle-wasm-examples/*/build
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a52429cf..df9574d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,58 @@
 # Changelog
 This documents the main changes to the `candle` crate.

-## v0.2.1 - Unreleased
+## v0.2.3 - Unreleased

 ### Added

 ### Modified
+
+## v0.2.2 - 2023-09-18
+
+### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+  [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+  [839](https://github.com/huggingface/candle/pull/839).
+
+### Modified
+- Bugfix for conv2d
+  [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+  [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+  [674](https://github.com/huggingface/candle/pull/674),
+  [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+  [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+  [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+  [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+  [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment anything example
+  [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+  [770](https://github.com/huggingface/candle/pull/770).
+
+### Modified
 - Dilations are now supported in conv-transpose2d.
   [671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+  [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+  [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+  [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+  [796](https://github.com/huggingface/candle/pull/796).

 ## v0.2.0 - 2023-08-30
diff --git a/Cargo.toml b/Cargo.toml
@@ -8,17 +8,16 @@ members = [
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
+    "candle-wasm-examples/bert",
 ]
-exclude = [
-  "candle-flash-attn",
-  "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"

 [workspace.package]
-version = "0.2.1"
+version = "0.2.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,7 +32,7 @@ byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { version = "0.9.14", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.0", package = "candle-gemm" }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
diff --git a/README.md b/README.md
@@ -8,7 +8,9 @@
 Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
 and ease of use.
 Try our online demos:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 ## Get started

@@ -45,37 +47,54 @@ For more advanced examples, please have a look at the following section.

 ## Check out our examples

-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+  object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
+
+We also provide some command-line examples using state-of-the-art models:

-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
-  using self-supervision (can be used for imagenet classification, depth
-  evaluation, segmentation).
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
+
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
+
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+  image generative model.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
+
 - [yolo-v3](./candle-examples/examples/yolo-v3/) and
   [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
   estimation models.
-Run them using the following commands:
-```
-cargo run --example whisper --release
-cargo run --example llama --release
-cargo run --example falcon --release
-cargo run --example bert --release
-cargo run --example bigcode --release
-cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
-cargo run --example dinov2 --release -- --image path/to/myinput.jpg
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
+- [segment-anything](./candle-examples/examples/segment-anything/): promptable
+  image segmentation model.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
+
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
+
+Run them using commands like:
+```
 cargo run --example quantized --release
-cargo run --example yolo-v3 --release -- myimage.jpg
-cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
 ```

 In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -85,7 +104,8 @@
 There are also some wasm examples for whisper and
 [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
 `trunk` or try them online:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 For LLaMA2, run the following command to retrieve the weight files and start a
 test server:
@@ -98,6 +118,15 @@ trunk serve --release --port 8081
 And then head over to
 [http://localhost:8081/](http://localhost:8081/).

+<!--- ANCHOR: useful_libraries --->
+
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+
+<!--- ANCHOR_END: useful_libraries --->
+
 <!--- ANCHOR: features --->

 ## Features

@@ -110,10 +139,21 @@ And then head over to
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
 - Included models.
-  - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+  - Language Models.
+    - LLaMA v1 and v2.
+    - Falcon.
+    - StarCoder.
+    - T5.
+    - Bert.
   - Whisper (multi-lingual support).
-  - Stable Diffusion.
-  - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+  - Stable Diffusion v1.5, v2.1, XL v1.0.
+  - Wuerstchen v2.
+  - Computer Vision Models.
+    - DINOv2.
+    - EfficientNet.
+    - yolo-v3.
+    - yolo-v8.
+    - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
@@ -243,6 +283,35 @@ authentication token. See issue
 git submodule update --init
 ```

+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
+```
+
+This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a
+different, supported gcc version, for example gcc-10, and specify the path to
+the compiler in the `CANDLE_NVCC_CCBIN` environment variable.
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Linking error on Windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ---- +error: linking with `link.exe` failed: exit code: 1181 +//very long chain of linking + = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib' +``` + +Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run: + +``` +mdbook test candle-book -L .\target\debug\deps\ ` +-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib ` +-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib +``` + #### Tracking down errors You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index 320fb887..8ec92e87 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.2.1" } -candle-nn = { path = "../candle-nn", version = "0.2.1" } -candle-transformers = { path = "../candle-transformers", version = "0.2.1" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.2.3" } +candle-nn = { path = "../candle-nn", version = "0.2.3" } +candle-transformers = { path = "../candle-transformers", version = "0.2.3" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md index 1d05568a..59831af2 100644 --- a/candle-book/src/SUMMARY.md +++ b/candle-book/src/SUMMARY.md @@ -10,10 +10,10 @@ # Reference Guide -- [Running a model](inference/README.md) +- [Running a model](inference/inference.md) - [Using the hub](inference/hub.md) - [Error management](error_manage.md) -- [Training](training/README.md) +- [Training](training/training.md) - [Simplified](training/simplified.md) - [MNIST](training/mnist.md) - [Fine-tuning]() diff --git a/candle-book/src/error_manage.md b/candle-book/src/error_manage.md index c1a16bd9..0623e0e3 100644 --- a/candle-book/src/error_manage.md +++ b/candle-book/src/error_manage.md @@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`: Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: 
"/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] } ``` -Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }` +Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }` Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces diff --git a/candle-book/src/guide/hello_world.md b/candle-book/src/guide/hello_world.md index fc4af0e1..b5b8d7b4 100644 --- a/candle-book/src/guide/hello_world.md +++ b/candle-book/src/guide/hello_world.md @@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content: ```rust # extern crate candle_core; -use candle_core::{DType, Device, Result, Tensor}; +use candle_core::{Device, Result, Tensor}; struct Model { first: Tensor, @@ -25,11 +25,11 @@ fn main() -> Result<()> { // Use Device::new_cuda(0)?; to use the GPU. let device = Device::Cpu; - let first = Tensor::zeros((784, 100), DType::F32, &device)?; - let second = Tensor::zeros((100, 10), DType::F32, &device)?; + let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?; + let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?; let model = Model { first, second }; - let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?; + let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; let digit = model.forward(&dummy_image)?; println!("Digit {digit:?} digit"); @@ -50,7 +50,7 @@ the classical `Linear` layer. 
We can do as such ```rust # extern crate candle_core; -# use candle_core::{DType, Device, Result, Tensor}; +# use candle_core::{Device, Result, Tensor}; struct Linear{ weight: Tensor, bias: Tensor, @@ -80,7 +80,7 @@ This will change the model running code into a new function ```rust # extern crate candle_core; -# use candle_core::{DType, Device, Result, Tensor}; +# use candle_core::{Device, Result, Tensor}; # struct Linear{ # weight: Tensor, # bias: Tensor, @@ -110,15 +110,15 @@ fn main() -> Result<()> { let device = Device::cuda_if_available(0)?; // Creating a dummy model - let weight = Tensor::zeros((784, 100), DType::F32, &device)?; - let bias = Tensor::zeros((100, ), DType::F32, &device)?; + let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?; + let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?; let first = Linear{weight, bias}; - let weight = Tensor::zeros((100, 10), DType::F32, &device)?; - let bias = Tensor::zeros((10, ), DType::F32, &device)?; + let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?; + let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?; let second = Linear{weight, bias}; let model = Model { first, second }; - let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?; + let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; // Inference on the model let digit = model.forward(&dummy_image)?; @@ -146,7 +146,7 @@ And rewrite our examples using it ```rust # extern crate candle_core; # extern crate candle_nn; -use candle_core::{DType, Device, Result, Tensor}; +use candle_core::{Device, Result, Tensor}; use candle_nn::{Linear, Module}; struct Model { @@ -167,15 +167,15 @@ fn main() -> Result<()> { let device = Device::Cpu; // This has changed (784, 100) -> (100, 784) ! - let weight = Tensor::zeros((100, 784), DType::F32, &device)?; - let bias = Tensor::zeros((100, ), DType::F32, &device)?; + let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?; + let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?; let first = Linear::new(weight, Some(bias)); - let weight = Tensor::zeros((10, 100), DType::F32, &device)?; - let bias = Tensor::zeros((10, ), DType::F32, &device)?; + let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?; + let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?; let second = Linear::new(weight, Some(bias)); let model = Model { first, second }; - let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?; + let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; let digit = model.forward(&dummy_image)?; println!("Digit {digit:?} digit"); @@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i Now that we have the running dummy code we can get to more advanced topics: -- [For PyTorch users](./guide/cheatsheet.md) -- [Running existing models](./inference/README.md) -- [Training models](./training/README.md) +- [For PyTorch users](../guide/cheatsheet.md) +- [Running existing models](../inference/inference.md) +- [Training models](../training/training.md) diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/inference.md index 1b75a310..1b75a310 100644 --- a/candle-book/src/inference/README.md +++ b/candle-book/src/inference/inference.md diff --git a/candle-book/src/training/README.md b/candle-book/src/training/training.md index d68a917e..d68a917e 100644 --- a/candle-book/src/training/README.md +++ b/candle-book/src/training/training.md diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index e7213919..7af9b6fa 100644 --- 
a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,7 +12,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true } +candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true } cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs deleted file mode 100644 index 13175ac1..00000000 --- a/candle-core/examples/cpu_benchmarks.rs +++ /dev/null @@ -1,166 +0,0 @@ -/// This example contains some simple benchmarks so that it's easy to run them in perf etc. -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; - -#[cfg(feature = "accelerate")] -extern crate accelerate_src; - -use candle_core::quantized::GgmlType; -use candle_core::{Device, Result, Tensor, D}; -use clap::{Parser, Subcommand}; - -fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> { - let dim = dim.to_index(xs.shape(), "softmax")?; - let max = xs.max_keepdim(dim)?; - let diff = xs.broadcast_sub(&max)?; - let num = diff.exp()?; - let den = num.sum_keepdim(dim)?; - num.broadcast_div(&den) -} - -trait Benchmark { - type PreProcessData; - type RunResult; - - fn preprocess() -> Result<Self::PreProcessData>; - fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>; - - const ITERS: usize; -} - -// Conv1d example as used in whisper. -struct Conv1d; -impl Benchmark for Conv1d { - type PreProcessData = (Tensor, Tensor); - type RunResult = Tensor; - fn preprocess() -> Result<Self::PreProcessData> { - let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?; - let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?; - Ok((inp, w)) - } - - fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { - d.0.conv1d(&d.1, 0, 1, 1, 1) - } - - const ITERS: usize = 5; -} - -// Conv2d example as used in stable-diffusion. 
-struct Conv2d; -impl Benchmark for Conv2d { - type PreProcessData = (Tensor, Tensor); - type RunResult = Tensor; - - fn preprocess() -> Result<Self::PreProcessData> { - let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; - let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; - Ok((inp, w)) - } - - fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { - d.0.conv2d(&d.1, 0, 1, 1, 1) - } - - const ITERS: usize = 1; -} - -struct Matmul; -impl Benchmark for Matmul { - type PreProcessData = (Tensor, Tensor); - type RunResult = Tensor; - fn preprocess() -> Result<Self::PreProcessData> { - let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; - let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; - Ok((lhs, rhs)) - } - - fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { - d.0.matmul(&d.1) - } - - const ITERS: usize = 100; -} - -// This benchmark is similar to: -// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp -struct QMatMul; -impl Benchmark for QMatMul { - type PreProcessData = (candle_core::quantized::QMatMul, Tensor); - type RunResult = Tensor; - fn preprocess() -> Result<Self::PreProcessData> { - let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; - let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?; - let mm = candle_core::quantized::QMatMul::from_qtensor(mm); - let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; - Ok((mm, arg)) - } - - fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { - d.0.forward(&d.1) - } - - const ITERS: usize = 100; -} - -struct Softmax; -impl Benchmark for Softmax { - type PreProcessData = Tensor; - type RunResult = Tensor; - fn preprocess() -> Result<Self::PreProcessData> { - // Typical whisper tiny size. - let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?; - Ok(x) - } - - fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { - softmax(d, D::Minus1) - } - - const ITERS: usize = 100; -} - -fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> { - use std::hint::black_box; - - let iters = iters.unwrap_or(B::ITERS); - let d = B::preprocess()?; - let start = std::time::Instant::now(); - for _iter in 0..iters { - let _res = black_box(B::run_one(black_box(&d))?); - } - println!("{:?}", start.elapsed() / iters as u32); - Ok(()) -} - -#[derive(Subcommand, Debug, Clone)] -enum Task { - Conv1d, - Conv2d, - Matmul, - Qmatmul, - Softmax, -} - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// The benchmark to be run. 
- #[command(subcommand)] - task: Task, - - #[arg(long)] - iters: Option<usize>, -} - -fn main() -> Result<()> { - let args = Args::parse(); - match args.task { - Task::Conv1d => run::<Conv1d>(args.iters)?, - Task::Conv2d => run::<Conv2d>(args.iters)?, - Task::Matmul => run::<Matmul>(args.iters)?, - Task::Softmax => run::<Softmax>(args.iters)?, - Task::Qmatmul => run::<QMatMul>(args.iters)?, - } - Ok(()) -} diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 2bc1fa2e..c3459004 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R Ok(()) } +fn run_quantize_safetensors( + in_file: std::path::PathBuf, + out_file: std::path::PathBuf, + q: Quantization, +) -> Result<()> { + let mut out_file = std::fs::File::create(out_file)?; + let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; + println!("tensors: {}", tensors.len()); + + let quantize_fn = match q { + Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>, + Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>, + Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>, + Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>, + Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>, + Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>, + Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>, + Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>, + Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>, + Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>, + Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>, + Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>, + Quantization::F16 => QTensor::quantize::<half::f16>, + Quantization::F32 => QTensor::quantize::<f32>, + }; + + let qtensors = tensors + .into_par_iter() + .map(|(name, tensor)| { + println!(" quantizing {name} {tensor:?}"); + let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0; + let tensor = if should_quantize { + quantize_fn(&tensor)? + } else { + QTensor::quantize::<f32>(&tensor)? + }; + Ok((name, tensor)) + }) + .collect::<Result<Vec<_>>>()?; + let qtensors = qtensors + .iter() + .map(|(k, v)| (k.as_str(), v)) + .collect::<Vec<_>>(); + gguf_file::write(&mut out_file, &[], &qtensors)?; + Ok(()) +} + fn run_quantize( in_file: std::path::PathBuf, out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, ) -> Result<()> { + if let Some(extension) = in_file.extension() { + if extension == "safetensors" { + return run_quantize_safetensors(in_file, out_file, q); + } + } + // Open the out file early so as to fail directly on missing directories etc. 
let mut out_file = std::fs::File::create(out_file)?; let mut in_ = std::fs::File::open(&in_file)?; diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs index 87e0ee8d..1cb34e19 100644 --- a/candle-core/src/accelerate.rs +++ b/candle-core/src/accelerate.rs @@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) { y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) } +#[inline] +pub fn vs_tanh_inplace(y: &mut [f32]) { + unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vd_tanh_inplace(y: &mut [f64]) { + unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vs_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vd_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + macro_rules! binary_op { ($fn_name:ident, $ty:ty, $accelerate_name:ident) => { #[inline] diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 67a08714..03a07434 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -57,6 +57,7 @@ pub trait BackendStorage: Sized { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index d2099df7..a2548198 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -91,13 +91,14 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest1D(node) | Op::UpsampleNearest2D(node) | Op::AvgPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) - | Op::Reduce(node, _, _) + | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -111,6 +112,7 @@ impl Tensor { track_grad |= tg; nodes } + Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes, } } else { nodes @@ -262,6 +264,9 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; } + Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest1d", + })?, Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, @@ -437,6 +442,10 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, + Op::Unary(_, UnaryOp::GeluErf) => { + Err(Error::BackwardNotSupported { op: "gelu-erf" })? 
+            }
             Op::Unary(arg, UnaryOp::Relu) => {
                 let sum_grad = grads.or_insert(arg)?;
                 let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@@ -517,6 +526,7 @@
     }
 }

+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);

 impl GradStore {
diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs
new file mode 100644
index 00000000..ca6be53f
--- /dev/null
+++ b/candle-core/src/cpu/erf.rs
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+    //! Provides functions that don't have a numerical solution and must
+    //! be solved computationally (e.g. evaluation of a polynomial)
+
+    /// evaluates a polynomial at `z` where `coeff` are the coefficients
+    /// to a polynomial of order `k` where `k` is the length of `coeff` and the
+    /// coefficient
+    /// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
+    /// `2z^2 - z + 3`
+    ///
+    /// # Remarks
+    ///
+    /// Returns 0 for a 0 length coefficient slice
+    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+        let n = coeff.len();
+        if n == 0 {
+            return 0.0;
+        }
+
+        let mut sum = *coeff.last().unwrap();
+        for c in coeff[0..n - 1].iter().rev() {
+            sum = *c + z * sum;
+        }
+        sum
+    }
+}
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+    if x.is_nan() {
+        f64::NAN
+    } else if x >= 0.0 && x.is_infinite() {
+        1.0
+    } else if x <= 0.0 && x.is_infinite() {
+        -1.0
+    } else if x == 0. {
+        0.0
+    } else {
+        erf_impl(x, false)
+    }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+    if x == 0.0 {
+        0.0
+    } else if x >= 1.0 {
+        f64::INFINITY
+    } else if x <= -1.0 {
+        f64::NEG_INFINITY
+    } else if x < 0.0 {
+        erf_inv_impl(-x, 1.0 + x, -1.0)
+    } else {
+        erf_inv_impl(x, 1.0 - x, 1.0)
+    }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+    if x.is_nan() {
+        f64::NAN
+    } else if x == f64::INFINITY {
+        0.0
+    } else if x == f64::NEG_INFINITY {
+        2.0
+    } else {
+        erf_impl(x, true)
+    }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+    if x <= 0.0 {
+        f64::INFINITY
+    } else if x >= 2.0 {
+        f64::NEG_INFINITY
+    } else if x > 1.0 {
+        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+    } else {
+        erf_inv_impl(1.0 - x, x, 1.0)
+    }
+}
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[ + 0.00337916709551257388990745, + -0.00073695653048167948530905, + -0.374732337392919607868241, + 0.0817442448733587196071743, + -0.0421089319936548595203468, + 0.0070165709512095756344528, + -0.00495091255982435110337458, + 0.000871646599037922480317225, +]; + +/// Polynomial coefficients for a denominator of `erf_impl` +/// in the interval [1e-10, 0.5] +const ERF_IMPL_AD: &[f64] = &[ + 1.0, + -0.218088218087924645390535, + 0.412542972725442099083918, + -0.0841891147873106755410271, + 0.0655338856400241519690695, + -0.0120019604454941768171266, + 0.00408165558926174048329689, + -0.000615900721557769691924509, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.5, 0.75]. +const ERF_IMPL_BN: &[f64] = &[ + -0.0361790390718262471360258, + 0.292251883444882683221149, + 0.281447041797604512774415, + 0.125610208862766947294894, + 0.0274135028268930549240776, + 0.00250839672168065762786937, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.5, 0.75]. +const ERF_IMPL_BD: &[f64] = &[ + 1.0, + 1.8545005897903486499845, + 1.43575803037831418074962, + 0.582827658753036572454135, + 0.124810476932949746447682, + 0.0113724176546353285778481, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CN: &[f64] = &[ + -0.0397876892611136856954425, + 0.153165212467878293257683, + 0.191260295600936245503129, + 0.10276327061989304213645, + 0.029637090615738836726027, + 0.0046093486780275489468812, + 0.000307607820348680180548455, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CD: &[f64] = &[ + 1.0, + 1.95520072987627704987886, + 1.64762317199384860109595, + 0.768238607022126250082483, + 0.209793185936509782784315, + 0.0319569316899913392596356, + 0.00213363160895785378615014, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DN: &[f64] = &[ + -0.0300838560557949717328341, + 0.0538578829844454508530552, + 0.0726211541651914182692959, + 0.0367628469888049348429018, + 0.00964629015572527529605267, + 0.00133453480075291076745275, + 0.778087599782504251917881e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DD: &[f64] = &[ + 1.0, + 1.75967098147167528287343, + 1.32883571437961120556307, + 0.552528596508757581287907, + 0.133793056941332861912279, + 0.0179509645176280768640766, + 0.00104712440019937356634038, + -0.106640381820357337177643e-7, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_EN: &[f64] = &[ + -0.0117907570137227847827732, + 0.014262132090538809896674, + 0.0202234435902960820020765, + 0.00930668299990432009042239, + 0.00213357802422065994322516, + 0.00025022987386460102395382, + 0.120534912219588189822126e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_ED: &[f64] = &[ + 1.0, + 1.50376225203620482047419, + 0.965397786204462896346934, + 0.339265230476796681555511, + 0.0689740649541569716897427, + 0.00771060262491768307365526, + 0.000371421101531069302990367, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [3.5, 5.25]. 
+const ERF_IMPL_FN: &[f64] = &[ + -0.00546954795538729307482955, + 0.00404190278731707110245394, + 0.0054963369553161170521356, + 0.00212616472603945399437862, + 0.000394984014495083900689956, + 0.365565477064442377259271e-4, + 0.135485897109932323253786e-5, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [3.5, 5.25]. +const ERF_IMPL_FD: &[f64] = &[ + 1.0, + 1.21019697773630784832251, + 0.620914668221143886601045, + 0.173038430661142762569515, + 0.0276550813773432047594539, + 0.00240625974424309709745382, + 0.891811817251336577241006e-4, + -0.465528836283382684461025e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [5.25, 8]. +const ERF_IMPL_GN: &[f64] = &[ + -0.00270722535905778347999196, + 0.0013187563425029400461378, + 0.00119925933261002333923989, + 0.00027849619811344664248235, + 0.267822988218331849989363e-4, + 0.923043672315028197865066e-6, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [5.25, 8]. +const ERF_IMPL_GD: &[f64] = &[ + 1.0, + 0.814632808543141591118279, + 0.268901665856299542168425, + 0.0449877216103041118694989, + 0.00381759663320248459168994, + 0.000131571897888596914350697, + 0.404815359675764138445257e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HN: &[f64] = &[ + -0.00109946720691742196814323, + 0.000406425442750422675169153, + 0.000274499489416900707787024, + 0.465293770646659383436343e-4, + 0.320955425395767463401993e-5, + 0.778286018145020892261936e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HD: &[f64] = &[ + 1.0, + 0.588173710611846046373373, + 0.139363331289409746077541, + 0.0166329340417083678763028, + 0.00100023921310234908642639, + 0.24254837521587225125068e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_IN: &[f64] = &[ + -0.00056907993601094962855594, + 0.000169498540373762264416984, + 0.518472354581100890120501e-4, + 0.382819312231928859704678e-5, + 0.824989931281894431781794e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_ID: &[f64] = &[ + 1.0, + 0.339637250051139347430323, + 0.043472647870310663055044, + 0.00248549335224637114641629, + 0.535633305337152900549536e-4, + -0.117490944405459578783846e-12, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JN: &[f64] = &[ + -0.000241313599483991337479091, + 0.574224975202501512365975e-4, + 0.115998962927383778460557e-4, + 0.581762134402593739370875e-6, + 0.853971555085673614607418e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JD: &[f64] = &[ + 1.0, + 0.233044138299687841018015, + 0.0204186940546440312625597, + 0.000797185647564398289151125, + 0.117019281670172327758019e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [24, 38]. +const ERF_IMPL_KN: &[f64] = &[ + -0.000146674699277760365803642, + 0.162666552112280519955647e-4, + 0.269116248509165239294897e-5, + 0.979584479468091935086972e-7, + 0.101994647625723465722285e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [24, 38]. 
+const ERF_IMPL_KD: &[f64] = &[ + 1.0, + 0.165907812944847226546036, + 0.0103361716191505884359634, + 0.000286593026373868366935721, + 0.298401570840900340874568e-5, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LN: &[f64] = &[ + -0.583905797629771786720406e-4, + 0.412510325105496173512992e-5, + 0.431790922420250949096906e-6, + 0.993365155590013193345569e-8, + 0.653480510020104699270084e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LD: &[f64] = &[ + 1.0, + 0.105077086072039915406159, + 0.00414278428675475620830226, + 0.726338754644523769144108e-4, + 0.477818471047398785369849e-6, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [60, 85]. +const ERF_IMPL_MN: &[f64] = &[ + -0.196457797609229579459841e-4, + 0.157243887666800692441195e-5, + 0.543902511192700878690335e-7, + 0.317472492369117710852685e-9, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [60, 85]. +const ERF_IMPL_MD: &[f64] = &[ + 1.0, + 0.052803989240957632204885, + 0.000926876069151753290378112, + 0.541011723226630257077328e-5, + 0.535093845803642394908747e-15, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_NN: &[f64] = &[ + -0.789224703978722689089794e-5, + 0.622088451660986955124162e-6, + 0.145728445676882396797184e-7, + 0.603715505542715364529243e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_ND: &[f64] = &[ + 1.0, + 0.0375328846356293715248719, + 0.000467919535974625308126054, + 0.193847039275845656900547e-5, +]; + +// ********************************************************** +// ********** Coefficients for erf_inv_impl polynomial ****** +// ********************************************************** + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AN: &[f64] = &[ + -0.000508781949658280665617, + -0.00836874819741736770379, + 0.0334806625409744615033, + -0.0126926147662974029034, + -0.0365637971411762664006, + 0.0219878681111168899165, + 0.00822687874676915743155, + -0.00538772965071242932965, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AD: &[f64] = &[ + 1.0, + -0.970005043303290640362, + -1.56574558234175846809, + 1.56221558398423026363, + 0.662328840472002992063, + -0.71228902341542847553, + -0.0527396382340099713954, + 0.0795283687341571680018, + -0.00233393759374190016776, + 0.000886216390456424707504, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. +const ERF_INV_IMPL_BN: &[f64] = &[ + -0.202433508355938759655, + 0.105264680699391713268, + 8.37050328343119927838, + 17.6447298408374015486, + -18.8510648058714251895, + -44.6382324441786960818, + 17.445385985570866523, + 21.1294655448340526258, + -3.67192254707729348546, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. 
+const ERF_INV_IMPL_BD: &[f64] = &[ + 1.0, + 6.24264124854247537712, + 3.9713437953343869095, + -28.6608180499800029974, + -20.1432634680485188801, + 48.5609213108739935468, + 10.8268667355460159008, + -22.6436933413139721736, + 1.72114765761200282724, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. +const ERF_INV_IMPL_CN: &[f64] = &[ + -0.131102781679951906451, + -0.163794047193317060787, + 0.117030156341995252019, + 0.387079738972604337464, + 0.337785538912035898924, + 0.142869534408157156766, + 0.0290157910005329060432, + 0.00214558995388805277169, + -0.679465575181126350155e-6, + 0.285225331782217055858e-7, + -0.681149956853776992068e-9, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. +const ERF_INV_IMPL_CD: &[f64] = &[ + 1.0, + 3.46625407242567245975, + 5.38168345707006855425, + 4.77846592945843778382, + 2.59301921623620271374, + 0.848854343457902036425, + 0.152264338295331783612, + 0.01105924229346489121, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DN: &[f64] = &[ + -0.0350353787183177984712, + -0.00222426529213447927281, + 0.0185573306514231072324, + 0.00950804701325919603619, + 0.00187123492819559223345, + 0.000157544617424960554631, + 0.460469890584317994083e-5, + -0.230404776911882601748e-9, + 0.266339227425782031962e-11, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DD: &[f64] = &[ + 1.0, + 1.3653349817554063097, + 0.762059164553623404043, + 0.220091105764131249824, + 0.0341589143670947727934, + 0.00263861676657015992959, + 0.764675292302794483503e-4, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_EN: &[f64] = &[ + -0.0167431005076633737133, + -0.00112951438745580278863, + 0.00105628862152492910091, + 0.000209386317487588078668, + 0.149624783758342370182e-4, + 0.449696789927706453732e-6, + 0.462596163522878599135e-8, + -0.281128735628831791805e-13, + 0.99055709973310326855e-16, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_ED: &[f64] = &[ + 1.0, + 0.591429344886417493481, + 0.138151865749083321638, + 0.0160746087093676504695, + 0.000964011807005165528527, + 0.275335474764726041141e-4, + 0.282243172016108031869e-6, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FN: &[f64] = &[ + -0.0024978212791898131227, + -0.779190719229053954292e-5, + 0.254723037413027451751e-4, + 0.162397777342510920873e-5, + 0.396341011304801168516e-7, + 0.411632831190944208473e-9, + 0.145596286718675035587e-11, + -0.116765012397184275695e-17, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FD: &[f64] = &[ + 1.0, + 0.207123112214422517181, + 0.0169410838120975906478, + 0.000690538265622684595676, + 0.145007359818232637924e-4, + 0.144437756628144157666e-6, + 0.509761276599778486139e-9, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. 
+const ERF_INV_IMPL_GN: &[f64] = &[ + -0.000539042911019078575891, + -0.28398759004727721098e-6, + 0.899465114892291446442e-6, + 0.229345859265920864296e-7, + 0.225561444863500149219e-9, + 0.947846627503022684216e-12, + 0.135880130108924861008e-14, + -0.348890393399948882918e-21, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. +const ERF_INV_IMPL_GD: &[f64] = &[ + 1.0, + 0.0845746234001899436914, + 0.00282092984726264681981, + 0.468292921940894236786e-4, + 0.399968812193862100054e-6, + 0.161809290887904476097e-8, + 0.231558608310259605225e-11, +]; + +/// `erf_impl` computes the error function at `z`. +/// If `inv` is true, `1 - erf` is calculated as opposed to `erf` +fn erf_impl(z: f64, inv: bool) -> f64 { + if z < 0.0 { + if !inv { + return -erf_impl(-z, false); + } + if z < -0.5 { + return 2.0 - erf_impl(-z, true); + } + return 1.0 + erf_impl(-z, false); + } + + let result = if z < 0.5 { + if z < 1e-10 { + z * 1.125 + z * 0.003379167095512573896158903121545171688 + } else { + z * 1.125 + + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD) + } + } else if z < 110.0 { + let (r, b) = if z < 0.75 { + ( + evaluate::polynomial(z - 0.5, ERF_IMPL_BN) + / evaluate::polynomial(z - 0.5, ERF_IMPL_BD), + 0.3440242112, + ) + } else if z < 1.25 { + ( + evaluate::polynomial(z - 0.75, ERF_IMPL_CN) + / evaluate::polynomial(z - 0.75, ERF_IMPL_CD), + 0.419990927, + ) + } else if z < 2.25 { + ( + evaluate::polynomial(z - 1.25, ERF_IMPL_DN) + / evaluate::polynomial(z - 1.25, ERF_IMPL_DD), + 0.4898625016, + ) + } else if z < 3.5 { + ( + evaluate::polynomial(z - 2.25, ERF_IMPL_EN) + / evaluate::polynomial(z - 2.25, ERF_IMPL_ED), + 0.5317370892, + ) + } else if z < 5.25 { + ( + evaluate::polynomial(z - 3.5, ERF_IMPL_FN) + / evaluate::polynomial(z - 3.5, ERF_IMPL_FD), + 0.5489973426, + ) + } else if z < 8.0 { + ( + evaluate::polynomial(z - 5.25, ERF_IMPL_GN) + / evaluate::polynomial(z - 5.25, ERF_IMPL_GD), + 0.5571740866, + ) + } else if z < 11.5 { + ( + evaluate::polynomial(z - 8.0, ERF_IMPL_HN) + / evaluate::polynomial(z - 8.0, ERF_IMPL_HD), + 0.5609807968, + ) + } else if z < 17.0 { + ( + evaluate::polynomial(z - 11.5, ERF_IMPL_IN) + / evaluate::polynomial(z - 11.5, ERF_IMPL_ID), + 0.5626493692, + ) + } else if z < 24.0 { + ( + evaluate::polynomial(z - 17.0, ERF_IMPL_JN) + / evaluate::polynomial(z - 17.0, ERF_IMPL_JD), + 0.5634598136, + ) + } else if z < 38.0 { + ( + evaluate::polynomial(z - 24.0, ERF_IMPL_KN) + / evaluate::polynomial(z - 24.0, ERF_IMPL_KD), + 0.5638477802, + ) + } else if z < 60.0 { + ( + evaluate::polynomial(z - 38.0, ERF_IMPL_LN) + / evaluate::polynomial(z - 38.0, ERF_IMPL_LD), + 0.5640528202, + ) + } else if z < 85.0 { + ( + evaluate::polynomial(z - 60.0, ERF_IMPL_MN) + / evaluate::polynomial(z - 60.0, ERF_IMPL_MD), + 0.5641309023, + ) + } else { + ( + evaluate::polynomial(z - 85.0, ERF_IMPL_NN) + / evaluate::polynomial(z - 85.0, ERF_IMPL_ND), + 0.5641584396, + ) + }; + let g = (-z * z).exp() / z; + g * b + g * r + } else { + 0.0 + }; + + if inv && z >= 0.5 { + result + } else if z >= 0.5 || inv { + 1.0 - result + } else { + result + } +} + +// `erf_inv_impl` computes the inverse error function where +// `p`,`q`, and `s` are the first, second, and third intermediate +// parameters respectively +fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 { + let result = if p <= 0.5 { + let y = 0.0891314744949340820313; + let g = p * (p + 10.0); + let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / 
evaluate::polynomial(p, ERF_INV_IMPL_AD); + g * y + g * r + } else if q >= 0.25 { + let y = 2.249481201171875; + let g = (-2.0 * q.ln()).sqrt(); + let xs = q - 0.25; + let r = + evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD); + g / (y + r) + } else { + let x = (-q.ln()).sqrt(); + if x < 3.0 { + let y = 0.807220458984375; + let xs = x - 1.125; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN) + / evaluate::polynomial(xs, ERF_INV_IMPL_CD); + y * x + r * x + } else if x < 6.0 { + let y = 0.93995571136474609375; + let xs = x - 3.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN) + / evaluate::polynomial(xs, ERF_INV_IMPL_DD); + y * x + r * x + } else if x < 18.0 { + let y = 0.98362827301025390625; + let xs = x - 6.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN) + / evaluate::polynomial(xs, ERF_INV_IMPL_ED); + y * x + r * x + } else if x < 44.0 { + let y = 0.99714565277099609375; + let xs = x - 18.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN) + / evaluate::polynomial(xs, ERF_INV_IMPL_FD); + y * x + r * x + } else { + let y = 0.99941349029541015625; + let xs = x - 44.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN) + / evaluate::polynomial(xs, ERF_INV_IMPL_GD); + y * x + r * x + } + }; + s * result +} diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 97e195ef..527646d6 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -1,4 +1,7 @@ -pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { +pub trait VecOps: num_traits::NumAssign + Copy { + fn min(self, rhs: Self) -> Self; + fn max(self, rhs: Self) -> Self; + /// Dot-product of two vectors. /// /// # Safety @@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) { *res = *xs; for i in 1..len { - let x = *xs.add(i); - if x > *res { - *res = x - } + *res = (*res).max(*xs.add(i)) } } @@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) { *res = *xs; for i in 1..len { - let x = *xs.add(i); - if x < *res { - *res = x - } + *res = (*res).min(*xs.add(i)) } } } impl VecOps for f32 { #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { super::vec_dot_f32(lhs, rhs, res, len) } @@ -76,6 +83,16 @@ impl VecOps for f32 { impl VecOps for half::f16 { #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { let mut res_f32 = 0f32; super::vec_dot_f16(lhs, rhs, &mut res_f32, len); @@ -83,11 +100,61 @@ impl VecOps for half::f16 { } } -impl VecOps for f64 {} -impl VecOps for half::bf16 {} -impl VecOps for u8 {} -impl VecOps for u32 {} -impl VecOps for i64 {} +impl VecOps for f64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } +} +impl VecOps for half::bf16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn 
max(self, other: Self) -> Self { + Self::max(self, other) + } +} +impl VecOps for u8 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} +impl VecOps for u32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} +impl VecOps for i64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + <Self as Ord>::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + <Self as Ord>::max(self, other) + } +} #[inline(always)] pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index 9a8e6317..50afb30f 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -1,3 +1,4 @@ +pub mod erf; pub mod kernels; trait Cpu<const ARR: usize> { diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index ed3dd3fc..4e808b34 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; +use rayon::prelude::*; + +const USE_IM2COL_CONV1D: bool = true; +const USE_IM2COL_CONV2D: bool = true; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // intercept the oom errors to avoid panicking and provide a proper error. @@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U } // This function maps over two strided index sequences. -fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( +pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], @@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>( } // Similar to binary_map but with vectorized variants. -fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>( +pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], @@ -723,6 +727,36 @@ impl Map1 for MaxPool2D { } } +struct UpsampleNearest1D(usize); + +impl Map1 for UpsampleNearest1D { + fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { + // TODO: Specialized implementation for the case 2*sz? 
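+        // Editorial note, not in the original patch: nearest-neighbor 1d
+        // upsampling maps destination index `idx` to source index
+        // `min(src_sz - 1, floor(idx * src_sz / dst_sz))`; the mapping is
+        // precomputed once in `src_idxs` below and reused for every
+        // (batch, channel) pair.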
+ let dst_sz = self.0; + let (b_sz, c, src_sz) = layout.shape().dims3()?; + let stride = layout.stride(); + let stride_sz = stride[2]; + let src_index = layout.start_offset(); + let scale_sz = src_sz as f64 / dst_sz as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_sz]; + let src_idxs = (0..dst_sz) + .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize)) + .collect::<Vec<_>>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_sz..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_sz..]; + let src_index = src_index + c_idx * stride[1]; + for (idx, src_idx) in src_idxs.iter().enumerate() { + dst[idx] = src[src_index + src_idx * stride_sz] + } + } + } + Ok(dst) + } +} + struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { @@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> { } } - let num_threads = crate::utils::get_num_threads(); - for offset in 0..p.k_size { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let dst_idx = dst_c_idx * l_out; let k_cont = (0..p.c_in) .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2]) @@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> { } } +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> { + let &Self { + l_k, + stride, + dilation, + padding, + } = self; + let (b, c, l) = layout.shape().dims3()?; + let l_out = self.l_out(l); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * l_out * c * l_k]; + let (src_s0, src_s1, src_s2) = { + let s = layout.stride(); + (s[0], s[1], s[2]) + }; + // TODO: provide specialized kernels for the common use cases. 
+ // - l_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * l_out * c * l_k; + for l_idx in 0..l_out { + let dst_idx = dst_idx + l_idx * c * l_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * l_k; + let src_idx = c_idx * src_s1 + src_idx; + for l_k_idx in 0..l_k { + let src_l = l_idx * stride + l_k_idx * dilation; + if padding != 0 && (src_l < padding || src_l >= l + padding) { + continue; + } + let src_l = src_l - padding; + let src_idx = src_idx + src_l * src_s2; + let dst_idx = dst_idx + l_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> { + let &Self { + h_k, + w_k, + stride, + dilation, + padding, + } = self; + let (b, c, h, w) = layout.shape().dims4()?; + let (h_out, w_out) = self.hw_out(h, w); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k]; + let (src_s0, src_s1, src_s2, src_s3) = { + let s = layout.stride(); + (s[0], s[1], s[2], s[3]) + }; + // TODO: provide specialized kernels for the common use cases. + // - h_k = w_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * h_out * w_out * c * h_k * w_k; + for h_idx in 0..h_out { + let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k; + for w_idx in 0..w_out { + let dst_idx = dst_idx + w_idx * c * h_k * w_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * h_k * w_k; + let src_idx = c_idx * src_s1 + src_idx; + for h_k_idx in 0..h_k { + let src_h = h_idx * stride + h_k_idx * dilation; + if padding != 0 && (src_h < padding || src_h >= h + padding) { + continue; + } + let src_h = src_h - padding; + let src_idx = src_idx + src_h * src_s2; + let dst_idx = dst_idx + h_k_idx * w_k; + for w_k_idx in 0..w_k { + let src_w = w_idx * stride + w_k_idx * dilation; + if padding != 0 && (src_w < padding || src_w >= w + padding) { + continue; + } + let src_w = src_w - padding; + let src_idx = src_idx + src_w * src_s3; + let dst_idx = dst_idx + w_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + } + } + Ok(dst) + } +} + struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); impl<'a> Map2 for Conv2D<'a> { @@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> { } } - let num_threads = crate::utils::get_num_threads(); - for offset_h in 0..p.k_h { for offset_w in 0..p.k_w { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let dst_idx = dst_c_idx * out_w * out_h; let k_cont = (0..p.c_in) .map(|c_in_idx| { @@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> { } } } - let num_threads = crate::utils::get_num_threads(); for k_y in 0..p.k_h { for k_x in 0..p.k_w { - crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { let k_cont = (0..p.c_in) .map(|c_in_idx| { k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3] @@ -1298,8 +1461,9 @@ impl Map2 
for MatMul { ) -> Result<Vec<T>> { use gemm::{gemm, Parallelism}; - if T::DTYPE == DType::BF16 { - return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?; + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, } let (b, m, n, k) = self.0; @@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage { MaxPool2D(kernel_size, stride).map(self, layout) } + fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> { + UpsampleNearest1D(sz).map(self, layout) + } + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { UpsampleNearest2D(h, w).map(self, layout) } @@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result<Self> { - Conv1D(params).map(self, l, kernel, kernel_l) + if !USE_IM2COL_CONV1D { + return Conv1D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col1D { + l_k: params.k_size, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let l_out = params.l_out(); + let k = op.l_k * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } fn conv2d( @@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result<Self> { - Conv2D(params).map(self, l, kernel, kernel_l) + if !USE_IM2COL_CONV2D { + return Conv2D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let (h_out, w_out) = (params.out_h(), params.out_w()); + let k = op.h_k * op.w_k * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) + .transpose(1, 2)? 
+ .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } fn conv_transpose2d( diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 663f2319..00fd1d04 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; -use candle_kernels as kernels; +pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ @@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice { // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; let slice = match dtype { DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { @@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice { .w()? } DType::F32 => { - let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?; curand .0 .fill_with_normal(&mut data, mean as f32, std as f32) @@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F32(data) } DType::F64 => { - let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?; curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } @@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice { } #[derive(Debug)] -enum CudaStorageSlice { +pub enum CudaStorageSlice { U8(CudaSlice<u8>), U32(CudaSlice<u32>), I64(CudaSlice<i64>), @@ -394,7 +401,7 @@ enum CudaStorageSlice { } type S = CudaStorageSlice; -trait Map1 { +pub trait Map1 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src: &CudaSlice<T>, @@ -416,7 +423,7 @@ trait Map1 { } } -trait Map2 { +pub trait Map2 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -441,7 +448,7 @@ trait Map2 { } } -trait Map2InPlace { +pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, dst: &mut CudaSlice<T>, @@ -472,7 +479,7 @@ trait Map2InPlace { } } -trait Map1Any { +pub trait Map1Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( &self, src: &CudaSlice<T>, @@ -495,7 +502,7 @@ trait Map1Any { } } -trait Map2Any { +pub trait Map2Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -532,7 +539,7 @@ impl Map1 for Clone { } } -fn kernel_name<T: WithDType>(root: &str) -> String { +pub fn kernel_name<T: WithDType>(root: &str) -> String { let dtype = T::DTYPE.as_str(); format!("{root}_{dtype}") } @@ -593,6 +600,105 @@ impl Map1 for Elu { } } +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); 
+ let l_out = self.l_out(dims[2]); + let dst_el = dims[0] * l_out * dims[1] * self.l_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + l_out, + self.l_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); + let (h_out, w_out) = self.hw_out(dims[2], dims[3]); + let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + h_out, + w_out, + self.h_k, + self.w_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + struct Powf(f64); impl Map1 for Powf { fn f<T: DeviceRepr + WithDType>( @@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>( #[derive(Debug)] pub struct CudaStorage { - slice: CudaStorageSlice, - device: CudaDevice, + pub slice: CudaStorageSlice, + pub device: CudaDevice, } pub trait CudaDType: Sized { @@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result<Self> { + const USE_IM2COL_CONV1D: bool = true; + let device = self.device().clone(); - let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV1D { + let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col1D { + l_k: params.k_size, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let l_out = params.l_out(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_size * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. 
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(not(feature = "cudnn"))] @@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result<Self> { + const USE_IM2COL_CONV2D: bool = true; + let device = self.device().clone(); - let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV2D { + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(feature = "cudnn")] @@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> { + crate::bail!("upsample-nearest1d is not supported on cuda") + } + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> { let device = self.device().clone(); let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; @@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage { let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?; diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs index 235ad6e3..dd466ba2 100644 --- a/candle-core/src/cudnn.rs +++ b/candle-core/src/cudnn.rs @@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d< let x_shape = [ params.b_size as i32, params.c_in as i32, - params.i_w as i32, params.i_h as i32, + params.i_w as i32, ]; // Note that `src` already starts at the proper offset. 
let x = if src_l.is_contiguous() { @@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d< [ params.c_out as i32, params.c_in as i32, - params.k_w as i32, params.k_h as i32, + params.k_w as i32, ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); let y = cudnn.create_4d_tensor( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, - [params.b_size as i32, params.c_out as i32, w_out, h_out], + [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; let conv2d = Conv2dForward { conv: &conv, diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index adfc4a3c..c7a1567f 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,15 +1,24 @@ +//! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; use crate::{CpuStorage, Error, Result}; +/// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // Unsigned 8 bits integer. U8, + // Unsigned 32 bits integer. U32, + // Signed 64 bits integer. I64, + // Brain floating-point using half precision (16 bits). BF16, + // Floating-point using half precision (16 bits). F16, + // Floating-point using single precision (32 bits). F32, + // Floating-point using double precision (64 bits). F64, } @@ -33,6 +42,7 @@ impl std::str::FromStr for DType { } impl DType { + /// String representation for dtypes. pub fn as_str(&self) -> &'static str { match self { Self::U8 => "u8", @@ -45,6 +55,7 @@ impl DType { } } + /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`. pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 6c896653..5cc9c6d8 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 1cf20a84..be8f7b07 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -30,7 +30,7 @@ pub enum Error { UnsupportedDTypeForOp(DType, &'static str), // === Dimension Index Errors === - #[error("{op}: dimension index {dim} out of range for {shape:?}")] + #[error("{op}: dimension index {dim} out of range for shape {shape:?}")] DimOutOfRange { shape: Shape, dim: i32, @@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>; impl Error { pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Wrapped(Box::new(err)) + Self::Wrapped(Box::new(err)).bt() } pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Msg(err.to_string()) + Self::Msg(err.to_string()).bt() } pub fn bt(self) -> Self { diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 2b6d694b..7b84d316 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -46,19 +46,31 @@ impl Tensor { current_dim += 1; out } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + crate::bail!("multi-dimensional tensor indexing is not supported") + } + let out = 
x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+                    current_dim += 1;
+                    out
+                }
+                TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
            };
        }
        Ok(x)
    }
}

-#[derive(Debug, Clone)]
+#[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
     /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
+    /// Indexing via a 1d tensor
+    IndexSelect(Tensor),
+    Err(Error),
 }

 impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
     }
 }

+impl From<&[u32]> for TensorIndexer {
+    fn from(index: &[u32]) -> Self {
+        match Tensor::new(index, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+    fn from(index: Vec<u32>) -> Self {
+        let len = index.len();
+        match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<&Tensor> for TensorIndexer {
+    fn from(tensor: &Tensor) -> Self {
+        TensorIndexer::IndexSelect(tensor.clone())
+    }
+}
+
 macro_rules! impl_from_range {
     ($range_type:ty) => {
         impl From<$range_type> for TensorIndexer {
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a0347416..52effdcf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -59,6 +59,7 @@ mod op;
 pub mod pickle;
 pub mod quantized;
 pub mod safetensors;
+pub mod scalar;
 pub mod shape;
 mod storage;
 mod strided_index;
@@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
     }
 }

 // A simple trait defining a module with a forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
-    /// Change the module to use training mode vs eval mode.
-    ///
-    /// The default implementation does nothing as this is only used for a couple modules such as
-    /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {} } impl Module for quantized::QMatMul { @@ -124,3 +119,9 @@ impl Module for quantized::QMatMul { self.forward(xs) } } + +impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self(xs) + } +} diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index fbfc9c1a..4882a205 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -58,6 +58,8 @@ pub enum UnaryOp { Sqr, Sqrt, Gelu, + GeluErf, + Erf, Relu, Tanh, } @@ -116,6 +118,7 @@ pub enum Op { stride: (usize, usize), }, + UpsampleNearest1D(Tensor), UpsampleNearest2D(Tensor), Cat(Vec<Tensor>, usize), @@ -324,6 +327,8 @@ pub(crate) struct Recip; pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; +pub(crate) struct GeluErf; +pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; @@ -600,6 +605,92 @@ impl UnaryOpT for Gelu { fn f64_vec(xs: &[f64], ys: &mut [f64]) { crate::mkl::vd_gelu(xs, ys) } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::accelerate::vs_gelu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::accelerate::vd_gelu(xs, ys) + } +} + +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } +} + +impl UnaryOpT for GeluErf { + const NAME: &'static str = "gelu_erf"; + const KERNEL: &'static str = "ugelu_erf"; + const V: Self = GeluErf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) 
* 0.5 * v + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } } impl UnaryOpT for Relu { diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 65fd6a6e..a0fe455c 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34); pub struct BlockQ8_1 { pub(crate) d: f16, pub(crate) s: f16, - pub(crate) qs: [u8; QK8_1], + pub(crate) qs: [i8; QK8_1], } const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36); @@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 { } sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) } Ok(sumf) } @@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 { } sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) } Ok(sumf) } @@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 { for j in 0..Self::BLCK_SIZE / 2 { let v0 = xs[j] * id; let v1 = xs[j + Self::BLCK_SIZE / 2] * id; - ys.qs[j] = f32::round(v0) as u8; - ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8; + ys.qs[j] = f32::round(v0) as i8; + ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8; sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32; } ys.s = f16::from_f32(sum as f32) * ys.d; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 5c2bb2b2..f627f0f6 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -229,7 +229,7 @@ impl QTensor { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct QMatMul(std::sync::Arc<QTensor>); impl QMatMul { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index f37bb8ef..d588ea67 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -78,11 +78,7 @@ impl st::View for &Tensor { } impl Tensor { - pub fn save_safetensors<P: AsRef<std::path::Path>>( - &self, - name: &str, - filename: P, - ) -> Result<()> { + pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } @@ -267,7 +263,7 @@ impl MmapedFile { /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. 
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let inner = memmap2::MmapOptions::new() diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs new file mode 100644 index 00000000..43e1f4c8 --- /dev/null +++ b/candle-core/src/scalar.rs @@ -0,0 +1,23 @@ +use crate::{Result, Tensor, WithDType}; + +pub enum TensorScalar { + Tensor(Tensor), + Scalar(Tensor), +} + +pub trait TensorOrScalar { + fn to_tensor_scalar(self) -> Result<TensorScalar>; +} + +impl TensorOrScalar for &Tensor { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + Ok(TensorScalar::Tensor(self.clone())) + } +} + +impl<T: WithDType> TensorOrScalar for T { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + let scalar = Tensor::new(self, &crate::Device::Cpu)?; + Ok(TensorScalar::Scalar(scalar)) + } +} diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index aea8b887..4d500e7f 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -1,3 +1,4 @@ +//! The shape of a tensor is a tuple with the size of each of its dimensions. #![allow(clippy::redundant_closure_call)] use crate::{Error, Result}; @@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape { } } +impl From<(usize, usize, usize, usize, usize, usize)> for Shape { + fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { + Self(vec![ + d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, + ]) + } +} + impl From<Vec<usize>> for Shape { fn from(dims: Vec<usize>) -> Self { Self(dims) @@ -119,6 +128,7 @@ impl Shape { Self(dims.to_vec()) } + /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc. pub fn rank(&self) -> usize { self.0.len() } @@ -127,10 +137,12 @@ impl Shape { self.0 } + /// The dimensions as a slice of `usize`. pub fn dims(&self) -> &[usize] { &self.0 } + /// The total number of elements, this is the product of all dimension sizes. pub fn elem_count(&self) -> usize { self.0.iter().product() } @@ -182,6 +194,8 @@ impl Shape { true } + /// Modifies the shape by adding a list of additional dimensions at the end of the existing + /// dimensions. 
pub fn extend(mut self, additional_dims: &[usize]) -> Self { self.0.extend(additional_dims); self @@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) { } } +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4]) + } +} + +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + let d5 = self.5.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4, d5]) + } +} + extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); @@ -457,3 +494,171 @@ mod tests { assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } } + +pub trait ShapeWithOneHole { + fn into_shape(self, el_count: usize) -> Result<Shape>; +} + +impl<S: Into<Shape>> ShapeWithOneHole for S { + fn into_shape(self, _el_count: usize) -> Result<Shape> { + Ok(self.into()) + } +} + +impl ShapeWithOneHole for ((),) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + Ok(el_count.into()) + } +} + +impl ShapeWithOneHole for ((), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((el_count / d1, d1).into()) + } +} + +impl ShapeWithOneHole for (usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, ()) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((d1, el_count / d1).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, ()) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize) { + fn into_shape(self, 
el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, ()) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, (), d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, d4, ()) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, d4, el_count / d).into()) + } +} diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 8bd14ea9..9bd1fed6 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -369,6 +369,19 @@ impl Storage { } } + pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { match self { Storage::Cpu(storage) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e181f240..9dccf2b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,8 +1,10 @@ +//! 
Tensors are N-dimensional arrays of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{
     BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
 };
+use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -103,6 +105,28 @@ macro_rules! binary_op {
     };
 }

+macro_rules! binary_op_scalar {
+    ($fn_name:ident, $op_name:ident) => {
+        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+            let rhs = match rhs.to_tensor_scalar()? {
+                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+                crate::scalar::TensorScalar::Scalar(rhs) => rhs
+                    .to_dtype(self.dtype())?
+                    .to_device(self.device())?
+                    .broadcast_as(self.shape())?,
+            };
+            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+            let storage = self.storage().binary_impl::<crate::op::$op_name>(
+                &*rhs.storage(),
+                self.layout(),
+                rhs.layout(),
+            )?;
+            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
+    };
+}
+
 macro_rules! broadcast_binary_op {
     ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@@ -445,8 +469,8 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
-    binary_op!(maximum, Maximum);
-    binary_op!(minimum, Minimum);
+    binary_op_scalar!(maximum, Maximum);
+    binary_op_scalar!(minimum, Minimum);
     broadcast_binary_op!(broadcast_add, add);
     broadcast_binary_op!(broadcast_mul, mul);
     broadcast_binary_op!(broadcast_sub, sub);
@@ -465,6 +489,8 @@ impl Tensor {
     unary_op!(sqr, Sqr);
     unary_op!(sqrt, Sqrt);
     unary_op!(gelu, Gelu);
+    unary_op!(gelu_erf, GeluErf);
+    unary_op!(erf, Erf);
     unary_op!(relu, Relu);

     /// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
@@ -642,7 +668,12 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+        let op = match op {
+            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+            }
+            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+        };
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
@@ -775,8 +806,15 @@ impl Tensor {
     /// comparison operation is specified by the `op` argument.
     ///
     /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
-    pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "cmp")?;
+    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+        let rhs = match rhs.to_tensor_scalar()? {
+            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+            crate::scalar::TensorScalar::Scalar(rhs) => rhs
+                .to_dtype(self.dtype())?
+                .to_device(self.device())?
+                .broadcast_as(self.shape())?,
+        };
+        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
         let storage = self
             .storage()
             .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -785,45 +823,68 @@ impl Tensor {
     }

     /// Element-wise equality.
-    pub fn eq(&self, rhs: &Self) -> Result<Self> {
+    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
        self.cmp(rhs, CmpOp::Eq)
     }

     /// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> { + pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ne) } /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self < /// rhs` and 0 otherwise. - pub fn lt(&self, rhs: &Self) -> Result<Self> { + pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Lt) } /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self > /// rhs` and 0 otherwise. - pub fn gt(&self, rhs: &Self) -> Result<Self> { + pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Gt) } /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >= /// rhs` and 0 otherwise. - pub fn ge(&self, rhs: &Self) -> Result<Self> { + pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ge) } /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <= /// rhs` and 0 otherwise. - pub fn le(&self, rhs: &Self) -> Result<Self> { + pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Le) } - /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the + /// Clamp the tensor values to be between `min` and `max`. + pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> { + self.maximum(min)?.minimum(max) + } + + /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element. + /// + /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned + /// tensor also has three dimensions, `(batch, channels, target_size)`. + pub fn interpolate1d(&self, target_size: usize) -> Result<Self> { + let (n, c, _l) = self.dims3()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest1D); + let storage = self + .storage() + .upsample_nearest1d(self.layout(), target_size)?; + Ok(from_storage(storage, (n, c, target_size), op, false)) + } + + /// Alias for `interpolate1d`. + pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> { + self.interpolate1d(target_size) + } + + /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the /// nearest element. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); let storage = self @@ -832,6 +893,11 @@ impl Tensor { Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) } + /// Alias for `interpolate2d`. + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + self.interpolate2d(target_h, target_w) + } + /// 2D average pooling over an input tensor with multiple channels. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned @@ -1684,12 +1750,15 @@ impl Tensor { Ok(from_storage(storage, shape, BackpropOp::none(), true)) } - // TODO: Do we want to allow target shape using -1 on some dimensions? /// Reshape returns a tensor with the target shape provided that the number of elements of the /// original tensor is the same. 
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses /// a new storage and copies the data over, the returned tensor is always contiguous. /// + /// The shape can be specified using a tuple of `usize` and at most one `()` in which case + /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so + /// as to match the number of elements in the tensor. + /// /// ```rust /// # use candle_core::{Tensor, DType, Device, D}; /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; @@ -1699,10 +1768,14 @@ impl Tensor { /// /// let c = a.reshape((3, 2))?; /// assert_eq!(c.shape().dims(), &[3, 2]); + /// + /// let c = a.reshape((2, (), 1))?; + /// assert_eq!(c.shape().dims(), &[2, 3, 1]); + /// /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> { - let shape = shape.into(); + pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> { + let shape = s.into_shape(self.elem_count())?; if shape.elem_count() != self.elem_count() { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), @@ -1836,6 +1909,34 @@ impl Tensor { for arg in args { arg.as_ref().check_dim(dim, "cat")?; } + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg0.rank() != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: arg0.rank(), + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? 
+ } + } + } if dim == 0 { Self::cat0(args) } else { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 6af43196..edd0bd79 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1,4 +1,4 @@ -use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor}; +use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor}; fn zeros(device: &Device) -> Result<()> { let tensor = Tensor::zeros((5, 2), DType::F32, device)?; @@ -33,6 +33,44 @@ fn tensor_2d(device: &Device) -> Result<()> { Ok(()) } +fn clamp(device: &Device) -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, device)?; + let tensor = tensor.clamp(1.5, 6.2)?; + assert_eq!( + tensor.to_vec2::<f32>()?, + [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]], + ); + Ok(()) +} + +fn unary_op(device: &Device) -> Result<()> { + let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + test_utils::to_vec2_round(&tensor.gelu()?, 4)?, + [ + [-0.0036, 0.8412, 3.9999, -0.046, 0.3457], + [2.6911, -0.0647, -0.1091, 1.7353, 2.7933] + ] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, + [ + [-0.004, 0.8413, 3.9999, -0.046, 0.3457], + [2.6906, -0.0647, -0.1091, 1.7353, 2.7928] + ] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.erf()?, 4)?, + [ + [-1.0, 0.8427, 1.0, -0.1125, 0.5205], + [0.9999, -0.9891, -0.3079, 0.9891, 0.9999] + ] + ); + Ok(()) +} + fn binary_op(device: &Device) -> Result<()> { let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; let tensor1 = Tensor::new(data, device)?; @@ -877,6 +915,14 @@ fn broadcasting(device: &Device) -> Result<()> { Ok(()) } +fn randn(device: &Device) -> Result<()> { + let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + Ok(()) +} + test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); @@ -889,6 +935,7 @@ test_device!(max, max_cpu, max_gpu); test_device!(argmax, argmax_cpu, argmax_gpu); test_device!(argmin, argmin_cpu, argmin_gpu); test_device!(transpose, transpose_cpu, transpose_gpu); +test_device!(unary_op, unary_op_cpu, unary_op_gpu); test_device!(binary_op, binary_op_cpu, binary_op_gpu); test_device!(embeddings, embeddings_cpu, embeddings_gpu); test_device!(cmp, cmp_cpu, cmp_gpu); @@ -899,6 +946,8 @@ test_device!(index_select, index_select_cpu, index_select_gpu); test_device!(index_add, index_add_cpu, index_add_gpu); test_device!(gather, gather_cpu, gather_gpu); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu); +test_device!(randn, randn_cpu, randn_gpu); +test_device!(clamp, clamp_cpu, clamp_gpu); // There was originally a bug on the CPU implementation for randn // https://github.com/huggingface/candle/issues/381 diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index d69318e1..316f31c5 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.2.1" } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +candle-nn = { path = 
"../candle-nn", version = "0.2.3" } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs index 30b0d01f..2dac883c 100644 --- a/candle-datasets/src/vision/mnist.rs +++ b/candle-datasets/src/vision/mnist.rs @@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader}; use std::fs::File; use std::io::{self, BufReader, Read}; -fn read_u32<T: Read>(reader: &mut T) -> Result<u32> { - let mut b = vec![0u8; 4]; - reader.read_exact(&mut b)?; - let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| { - (s + basis * u64::from(x), basis * 256) - }); - Ok(result as u32) +fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> { + use byteorder::ReadBytesExt; + reader.read_u32::<byteorder::BigEndian>() } fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> { diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 9035eae0..0e2e8093 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,19 +11,19 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.2.1" } -candle-nn = { path = "../candle-nn", version = "0.2.1" } -candle-transformers = { path = "../candle-transformers", version = "0.2.1" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true } -safetensors = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -num-traits = { workspace = true } -intel-mkl-src = { workspace = true, optional = true } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.2.3" } +candle-nn = { path = "../candle-nn", version = "0.2.3" } +candle-transformers = { path = "../candle-transformers", version = "0.2.3" } cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } image = { workspace = true } +intel-mkl-src = { workspace = true, optional = true } +num-traits = { workspace = true } +rayon = { workspace = true } +safetensors = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } [dev-dependencies] anyhow = { workspace = true } @@ -50,7 +50,7 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cudnn = ["candle/cudnn"] -flash-attn = ["cuda", "dep:candle-flash-attn"] +flash-attn = ["cuda", "candle-transformers/flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md new file mode 100644 index 00000000..82ca5f40 --- /dev/null +++ b/candle-examples/examples/bert/README.md @@ -0,0 +1,44 @@ +# candle-bert + +Bert is a general large language model. In this example it can be used for two +different tasks: +- Compute sentence embeddings for a prompt. +- Compute similarities between a set of sentences. + + +## Sentence embeddings + +Bert is used to compute the sentence embeddings for a prompt. The model weights +are downloaded from the hub on the first run. 
+ +```bash +cargo run --example bert --release -- --prompt "Here is a test sentence" + +> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751], +> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908], +> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515], +> ... +> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777], +> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529], +> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]] +> Tensor[[1, 7, 384], f32] +``` + +## Similarities + +In this example, Bert is used to compute the sentence embeddings for a set of +sentences (hardcoded in the examples). Then cosine similarities are computed for +each sentence pair and they are reported by decreasing values, hence the first +reported pair contains the two sentences that have the highest similarity score. +The sentence embeddings are computed using average pooling through all the +sentence tokens, including some potential padding. + +```bash +cargo run --example bert --release + +> score: 0.85 'The new movie is awesome' 'The new movie is so great' +> score: 0.61 'The cat sits outside' 'The cat plays in the garden' +> score: 0.52 'I love pasta' 'Do you like pizza?' +> score: 0.23 'The new movie is awesome' 'Do you like pizza?' +> score: 0.22 'I love pasta' 'The new movie is awesome' +``` diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 6cee66ee..9d0eccdf 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -3,14 +3,13 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -mod model; +use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use anyhow::{anyhow, Error as E, Result}; use candle::Tensor; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; -use model::{BertModel, Config, DTYPE}; use tokenizers::{PaddingParams, Tokenizer}; #[derive(Parser, Debug)] diff --git a/candle-examples/examples/bigcode/README.md b/candle-examples/examples/bigcode/README.md new file mode 100644 index 00000000..cb4e79b1 --- /dev/null +++ b/candle-examples/examples/bigcode/README.md @@ -0,0 +1,19 @@ +# candle-starcoder: code generation model + +[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM +model specialized to code generation. The initial model was trained on 80 +programming languages. 
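+
+The example drives the model with the `LogitsProcessor` from
+`candle-transformers`, which supports both temperature scaling and nucleus
+(`top_p`) sampling. A rough sketch of such a decode loop (the `forward`
+signature and the flag values mirror the example and are assumptions, not a
+stable API):
+
+```rust
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::bigcode::GPTBigCode;
+use tokenizers::Tokenizer;
+
+fn generate(model: &mut GPTBigCode, tokenizer: &Tokenizer, prompt: &str) -> Result<Vec<u32>> {
+    let device = Device::Cpu;
+    // seed, temperature and top_p correspond to the --seed/--temperature/--top-p flags.
+    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.95));
+    let mut tokens = tokenizer.encode(prompt, true).map_err(E::msg)?.get_ids().to_vec();
+    for step in 0..100 {
+        // Once the kv-cache is warm, only the last token needs to be fed back in.
+        let (ctx, past) = if step == 0 { (tokens.len(), 0) } else { (1, tokens.len() - 1) };
+        let input = Tensor::new(&tokens[tokens.len() - ctx..], &device)?.unsqueeze(0)?;
+        let logits = model.forward(&input, past)?.squeeze(0)?.to_dtype(DType::F32)?;
+        tokens.push(processor.sample(&logits)?);
+    }
+    Ok(tokens) // tokenizer.decode(..) turns the ids back into text
+}
+```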
+ +## Running some example + +```bash +cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 " + +> fn fact(n: u64) -> u64 { +> if n == 0 { +> 1 +> } else { +> n * fact(n - 1) +> } +> } +``` diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index 652cd47f..5f17109e 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -7,8 +7,7 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -mod model; -use model::{Config, GPTBigCode}; +use candle_transformers::models::bigcode::{Config, GPTBigCode}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -29,9 +28,10 @@ impl TextGeneration { tokenizer: Tokenizer, seed: u64, temp: Option<f64>, + top_p: Option<f64>, device: &Device, ) -> Self { - let logits_processor = LogitsProcessor::new(seed, temp); + let logits_processor = LogitsProcessor::new(seed, temp, top_p); Self { model, tokenizer, @@ -95,6 +95,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -150,7 +154,14 @@ fn main() -> Result<()> { let model = GPTBigCode::load(vb, config)?; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + &device, + ); pipeline.run(&args.prompt, args.sample_len)?; Ok(()) } diff --git a/candle-examples/examples/dinov2/README.md b/candle-examples/examples/dinov2/README.md new file mode 100644 index 00000000..10d4ac1f --- /dev/null +++ b/candle-examples/examples/dinov2/README.md @@ -0,0 +1,19 @@ +# candle-dinov2 + +[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model. +In this example, it is used as an ImageNet classifier: the model returns the +probability for the image to belong to each of the 1000 ImageNet categories. 
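+
+Classification is a single forward pass followed by a softmax over the 1000
+logits; a rough sketch of the post-processing, assuming `model` and `image`
+are built as in the example's `main.rs`:
+
+```rust
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::Module;
+
+fn print_top5(model: &impl Module, image: &Tensor) -> Result<()> {
+    // Logits of shape (1, 1000), one score per ImageNet category.
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs: Vec<(usize, f32)> = prs.into_iter().enumerate().collect();
+    prs.sort_by(|a, b| b.1.total_cmp(&a.1));
+    for (category, p) in prs.into_iter().take(5) {
+        println!("category {category:4}: {:.2}%", p * 100.);
+    }
+    Ok(())
+}
+```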
+ +## Running some example + +```bash +cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg + +> mountain bike, all-terrain bike, off-roader: 43.67% +> bicycle-built-for-two, tandem bicycle, tandem: 33.20% +> crash helmet : 13.23% +> unicycle, monocycle : 2.44% +> maillot : 2.42% +``` + + diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index e80c81e2..d3adb37c 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -9,285 +9,10 @@ extern crate accelerate_src; use clap::Parser; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2; -const IMG_SIZE: usize = 518; -const PATCH_SIZE: usize = 14; -const NUM_CLASSES: usize = 1000; - -fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - if bias { - candle_nn::linear(in_dim, out_dim, vb) - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb) - } -} - -#[derive(Debug)] -struct Attention { - qkv: Linear, - proj: Linear, - num_heads: usize, - scale: f64, -} - -impl Attention { - fn new( - vb: VarBuilder, - dim: usize, - num_heads: usize, - qkv_bias: bool, - proj_bias: bool, - ) -> Result<Self> { - let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; - let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; - let scale = 1. / ((dim / num_heads) as f64).sqrt(); - Ok(Self { - qkv, - proj, - num_heads, - scale, - }) - } -} - -impl Module for Attention { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (b, n, c) = xs.dims3()?; - let qkv = self - .qkv - .forward(xs)? - .reshape((b, n, 3, self.num_heads, c / self.num_heads))? - .transpose(1, 2)? // 02134 - .transpose(0, 1)? // 20134 - .transpose(2, 3)?; // 20314 - let q = (qkv.i(0)? 
* self.scale)?; - let k = qkv.i(1)?; - let v = qkv.i(2)?; - let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; - let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; - self.proj.forward(&attn) - } -} - -#[derive(Debug)] -struct LayerScale { - gamma: Tensor, -} - -impl LayerScale { - fn new(vb: VarBuilder, dim: usize) -> Result<Self> { - let gamma = vb.get(dim, "gamma")?; - Ok(Self { gamma }) - } -} - -impl Module for LayerScale { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs.broadcast_mul(&self.gamma) - } -} - -#[derive(Debug)] -struct Mlp { - fc1: Linear, - fc2: Linear, -} - -impl Mlp { - fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { - let out_features = in_features; - let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; - let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.fc1.forward(xs)?.gelu()?; - self.fc2.forward(&xs) - } -} - -#[derive(Debug)] -struct Block { - norm1: LayerNorm, - attn: Attention, - ls1: LayerScale, - norm2: LayerNorm, - mlp: Mlp, - ls2: LayerScale, -} - -impl Block { - fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { - let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; - let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; - let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; - let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; - let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; - Ok(Self { - norm1, - attn, - ls1, - norm2, - mlp, - ls2, - }) - } -} - -impl Module for Block { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - let xs = self - .ls1 - .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = self - .ls2 - .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; - xs + residual - } -} - -#[derive(Debug)] -struct PatchEmbed { - proj: candle_nn::Conv2d, - patch_size: (usize, usize), - num_patches: usize, -} - -impl PatchEmbed { - fn new( - vb: VarBuilder, - img_size: usize, - patch_size: usize, - in_chans: usize, - embed_dim: usize, - ) -> Result<Self> { - let config = candle_nn::Conv2dConfig { - stride: patch_size, - ..Default::default() - }; - let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; - let num_patches = (img_size / patch_size) * (img_size / patch_size); - Ok(Self { - proj, - patch_size: (patch_size, patch_size), - num_patches, - }) - } -} - -impl Module for PatchEmbed { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _c, h, w) = xs.dims4()?; - let (patch_h, patch_w) = self.patch_size; - if (h % patch_h) != 0 { - candle::bail!("image height {h} is not a multiple of patch height {patch_h}") - } - if (w % patch_w) != 0 { - candle::bail!("image width {w} is not a multiple of patch width {patch_w}") - } - let xs = self.proj.forward(xs)?; - let (b, c, h, w) = xs.dims4()?; - // flatten embeddings. 
- xs.reshape((b, c, h * w))?.transpose(1, 2) - } -} - -#[derive(Debug)] -pub struct DinoVisionTransformer { - patch_embed: PatchEmbed, - cls_token: Tensor, - pos_embed: Tensor, - blocks: Vec<Block>, - norm: LayerNorm, - head: Linear, -} - -impl DinoVisionTransformer { - pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { - let patch_embed = - PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; - let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; - let num_tokens = 1; - let pos_embed = vb.get( - (1, patch_embed.num_patches + num_tokens, embed_dim), - "pos_embed", - )?; - let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; - let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; - let vb_b = vb.pp("blocks"); - let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) - .collect::<Result<Vec<_>>>()?; - Ok(Self { - patch_embed, - cls_token, - pos_embed, - blocks, - norm, - head, - }) - } - - fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { - let npatch = xs.dim(1)? - 1; - let n = self.pos_embed.dim(1)? - 1; - let sqrt_n = (n as f64).sqrt(); - if npatch == n && w == h { - return Ok(xs.clone()); - } - let class_pos_embed = self.pos_embed.i((.., ..1))?; - let patch_pos_embed = self.pos_embed.i((.., 1..))?; - let dim = xs.dim(D::Minus1)?; - let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); - let patch_pos_embed = patch_pos_embed - .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? - .transpose(2, 3)? - .transpose(1, 2)?; - // This uses bicubic interpolation in the original implementation. - let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; - let el_count = patch_pos_embed.shape().elem_count(); - let patch_pos_embed = - patch_pos_embed - .transpose(1, 2)? - .transpose(2, 3)? - .reshape((1, el_count / dim, dim))?; - Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) - } - - fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _nc, w, h) = xs.dims4()?; - let xs = self.patch_embed.forward(xs)?; - let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; - &xs + &self.interpolate_pos_encoding(&xs, w, h)? - } -} - -impl Module for DinoVisionTransformer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.prepare_tokens_with_mask(xs)?; - for blk in self.blocks.iter() { - xs = blk.forward(&xs)? - } - let xs = self.norm.forward(&xs)?; - let xs_norm_clstoken = xs.i((.., 0))?; - let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; - let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; - self.head.forward(&xs) - } -} - -pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { - DinoVisionTransformer::new(vb, 12, 384, 6) -} #[derive(Parser)] struct Args { #[arg(long)] @@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> { let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let model = vit_small(vb)?; + let model = dinov2::vit_small(vb)?; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; let prs = candle_nn::ops::softmax(&logits, D::Minus1)? 
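The dinov2 example (and, below, the efficientnet one) shrinks to a thin wrapper
because the model implementation moved into `candle-transformers`. Assuming the
moved module keeps the public names shown above, building the classifier
reduces to roughly:

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::dinov2;

fn build_model(model_file: std::path::PathBuf) -> anyhow::Result<dinov2::DinoVisionTransformer> {
    let device = Device::Cpu;
    // Memory-map the safetensors file and lazily build variables from it.
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    Ok(dinov2::vit_small(vb)?)
}
```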
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs index cbe2c90a..1e45e301 100644 --- a/candle-examples/examples/efficientnet/main.rs +++ b/candle-examples/examples/efficientnet/main.rs @@ -8,340 +8,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig}; use clap::{Parser, ValueEnum}; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn as nn; -use nn::{Module, VarBuilder}; - -// Based on the Python version from torchvision. -// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 -#[derive(Debug, Clone, Copy)] -pub struct MBConvConfig { - expand_ratio: f64, - kernel: usize, - stride: usize, - input_channels: usize, - out_channels: usize, - num_layers: usize, -} - -fn make_divisible(v: f64, divisor: usize) -> usize { - let min_value = divisor; - let new_v = usize::max( - min_value, - (v + divisor as f64 * 0.5) as usize / divisor * divisor, - ); - if (new_v as f64) < 0.9 * v { - new_v + divisor - } else { - new_v - } -} - -fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { - let bneck_conf = |e, k, s, i, o, n| { - let input_channels = make_divisible(i as f64 * width_mult, 8); - let out_channels = make_divisible(o as f64 * width_mult, 8); - let num_layers = (n as f64 * depth_mult).ceil() as usize; - MBConvConfig { - expand_ratio: e, - kernel: k, - stride: s, - input_channels, - out_channels, - num_layers, - } - }; - vec![ - bneck_conf(1., 3, 1, 32, 16, 1), - bneck_conf(6., 3, 2, 16, 24, 2), - bneck_conf(6., 5, 2, 24, 40, 2), - bneck_conf(6., 3, 2, 40, 80, 3), - bneck_conf(6., 5, 1, 80, 112, 3), - bneck_conf(6., 5, 2, 112, 192, 4), - bneck_conf(6., 3, 1, 192, 320, 1), - ] -} - -impl MBConvConfig { - fn b0() -> Vec<Self> { - bneck_confs(1.0, 1.0) - } - fn b1() -> Vec<Self> { - bneck_confs(1.0, 1.1) - } - fn b2() -> Vec<Self> { - bneck_confs(1.1, 1.2) - } - fn b3() -> Vec<Self> { - bneck_confs(1.2, 1.4) - } - fn b4() -> Vec<Self> { - bneck_confs(1.4, 1.8) - } - fn b5() -> Vec<Self> { - bneck_confs(1.6, 2.2) - } - fn b6() -> Vec<Self> { - bneck_confs(1.8, 2.6) - } - fn b7() -> Vec<Self> { - bneck_confs(2.0, 3.1) - } -} - -/// Conv2D with same padding. -#[derive(Debug)] -struct Conv2DSame { - conv2d: nn::Conv2d, - s: usize, - k: usize, -} - -impl Conv2DSame { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - bias: bool, - ) -> Result<Self> { - let conv_config = nn::Conv2dConfig { - stride, - groups, - ..Default::default() - }; - let conv2d = if bias { - nn::conv2d(i, o, k, conv_config, vb)? - } else { - nn::conv2d_no_bias(i, o, k, conv_config, vb)? 
- }; - Ok(Self { - conv2d, - s: stride, - k, - }) - } -} - -impl Module for Conv2DSame { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let s = self.s; - let k = self.k; - let (_, _, ih, iw) = xs.dims4()?; - let oh = (ih + s - 1) / s; - let ow = (iw + s - 1) / s; - let pad_h = usize::max((oh - 1) * s + k - ih, 0); - let pad_w = usize::max((ow - 1) * s + k - iw, 0); - if pad_h > 0 || pad_w > 0 { - let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; - let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; - self.conv2d.forward(&xs) - } else { - self.conv2d.forward(xs) - } - } -} - -#[derive(Debug)] -struct ConvNormActivation { - conv2d: Conv2DSame, - bn2d: nn::BatchNorm, - activation: bool, -} - -impl ConvNormActivation { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - ) -> Result<Self> { - let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; - let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; - Ok(Self { - conv2d, - bn2d, - activation: true, - }) - } - - fn no_activation(self) -> Self { - Self { - activation: false, - ..self - } - } -} - -impl Module for ConvNormActivation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.conv2d.forward(xs)?; - let xs = self.bn2d.forward(&xs)?; - if self.activation { - swish(&xs) - } else { - Ok(xs) - } - } -} - -#[derive(Debug)] -struct SqueezeExcitation { - fc1: Conv2DSame, - fc2: Conv2DSame, -} - -impl SqueezeExcitation { - fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { - let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; - let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for SqueezeExcitation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - // equivalent to adaptive_avg_pool2d([1, 1]) - let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; - let xs = self.fc1.forward(&xs)?; - let xs = swish(&xs)?; - let xs = self.fc2.forward(&xs)?; - let xs = nn::ops::sigmoid(&xs)?; - residual.broadcast_mul(&xs) - } -} - -#[derive(Debug)] -struct MBConv { - expand_cna: Option<ConvNormActivation>, - depthwise_cna: ConvNormActivation, - squeeze_excitation: SqueezeExcitation, - project_cna: ConvNormActivation, - config: MBConvConfig, -} - -impl MBConv { - fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { - let vb = vb.pp("block"); - let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); - let expand_cna = if exp != c.input_channels { - Some(ConvNormActivation::new( - vb.pp("0"), - c.input_channels, - exp, - 1, - 1, - 1, - )?) - } else { - None - }; - let start_index = if expand_cna.is_some() { 1 } else { 0 }; - let depthwise_cna = - ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; - let squeeze_channels = usize::max(1, c.input_channels / 4); - let squeeze_excitation = - SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; - let project_cna = - ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? 
- .no_activation(); - Ok(Self { - expand_cna, - depthwise_cna, - squeeze_excitation, - project_cna, - config: c, - }) - } -} - -impl Module for MBConv { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let use_res_connect = - self.config.stride == 1 && self.config.input_channels == self.config.out_channels; - let ys = match &self.expand_cna { - Some(expand_cna) => expand_cna.forward(xs)?, - None => xs.clone(), - }; - let ys = self.depthwise_cna.forward(&ys)?; - let ys = self.squeeze_excitation.forward(&ys)?; - let ys = self.project_cna.forward(&ys)?; - if use_res_connect { - ys + xs - } else { - Ok(ys) - } - } -} - -fn swish(s: &Tensor) -> Result<Tensor> { - s * nn::ops::sigmoid(s)? -} - -#[derive(Debug)] -struct EfficientNet { - init_cna: ConvNormActivation, - blocks: Vec<MBConv>, - final_cna: ConvNormActivation, - classifier: nn::Linear, -} - -impl EfficientNet { - fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { - let f_p = p.pp("features"); - let first_in_c = configs[0].input_channels; - let last_out_c = configs.last().unwrap().out_channels; - let final_out_c = 4 * last_out_c; - let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; - let nconfigs = configs.len(); - let mut blocks = vec![]; - for (index, cnf) in configs.into_iter().enumerate() { - let f_p = f_p.pp(index + 1); - for r_index in 0..cnf.num_layers { - let cnf = if r_index == 0 { - cnf - } else { - MBConvConfig { - input_channels: cnf.out_channels, - stride: 1, - ..cnf - } - }; - blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) - } - } - let final_cna = - ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; - let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; - Ok(Self { - init_cna, - blocks, - final_cna, - classifier, - }) - } -} - -impl Module for EfficientNet { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.init_cna.forward(xs)?; - for block in self.blocks.iter() { - xs = block.forward(&xs)? - } - let xs = self.final_cna.forward(&xs)?; - // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) - let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; - self.classifier.forward(&xs) - } -} - #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { B0, diff --git a/candle-examples/examples/falcon/README.md b/candle-examples/examples/falcon/README.md new file mode 100644 index 00000000..267c78c2 --- /dev/null +++ b/candle-examples/examples/falcon/README.md @@ -0,0 +1,3 @@ +# candle-falcon + +Falcon is a general large language model. 
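+
+A minimal invocation, using the flags defined in the example's `main.rs`
+(generated text will vary):
+
+```bash
+cargo run --example falcon --release -- --prompt "My favorite theorem is"
+```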
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 05507f08..b0973d64 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -14,8 +14,7 @@ use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; -mod model; -use model::{Config, Falcon}; +use candle_transformers::models::falcon::{Config, Falcon}; struct TextGeneration { model: Falcon, @@ -26,17 +25,25 @@ struct TextGeneration { repeat_last_n: usize, } +struct GenerationOptions { + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, +} + impl TextGeneration { fn new( model: Falcon, tokenizer: Tokenizer, + generation_options: GenerationOptions, seed: u64, - temp: Option<f64>, device: &Device, - repeat_penalty: f32, - repeat_last_n: usize, ) -> Self { - let logits_processor = LogitsProcessor::new(seed, temp); + let logits_processor = + LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p); + let repeat_penalty = generation_options.repeat_penalty; + let repeat_last_n = generation_options.repeat_last_n; Self { model, tokenizer, @@ -119,6 +126,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -186,15 +197,14 @@ fn main() -> Result<()> { let model = Falcon::load(vb, config)?; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new( - model, - tokenizer, - args.seed, - args.temperature, - &device, - args.repeat_penalty, - args.repeat_last_n, - ); + let generation_options = GenerationOptions { + temp: args.temperature, + top_p: args.top_p, + repeat_penalty: args.repeat_penalty, + repeat_last_n: args.repeat_last_n, + }; + let mut pipeline = + TextGeneration::new(model, tokenizer, generation_options, args.seed, &device); pipeline.run(&args.prompt, args.sample_len)?; Ok(()) } diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 6f8766d4..b2d7d938 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::io::Write; -mod model; +use candle_transformers::models::llama as model; use model::{Config, Llama, LlamaConfig}; const EOS_TOKEN: &str = "</s>"; -const MAX_SEQ_LEN: usize = 4096; const DEFAULT_PROMPT: &str = "My favorite theorem is "; #[derive(Parser, Debug)] @@ -43,6 +42,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. 
#[arg(long, default_value_t = 299792458)] seed: u64, @@ -194,7 +197,7 @@ fn main() -> Result<()> { println!("starting the inference loop"); print!("{prompt}"); - let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let start_gen = std::time::Instant::now(); let mut index_pos = 0; let mut token_generated = 0; diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index e0ade322..e752a494 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -27,6 +27,10 @@ struct InferenceCmd { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + #[arg(long, default_value = "")] prompt: String, @@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> { None => { let cmd = InferenceCmd { temperature: None, + top_p: None, prompt: "".to_string(), config: None, model_id: "karpathy/tinyllamas".to_string(), @@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let model = Llama::load(vb, &cache, config)?; println!("starting the inference loop"); - let mut logits_processor = LogitsProcessor::new(299792458, args.temperature); + let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p); let mut index_pos = 0; print!("{}", args.prompt); diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs index 17dc90e2..8a13ce6c 100644 --- a/candle-examples/examples/llama_multiprocess/main.rs +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -89,6 +89,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. 
#[arg(long, default_value_t = 299792458)] seed: u64, @@ -222,7 +226,7 @@ fn main() -> Result<()> { .to_vec(); println!("starting the inference loop"); - let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let mut new_tokens = vec![]; let start_gen = std::time::Instant::now(); let mut index_pos = 0; diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index 3794c22d..0fae67b5 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -13,7 +13,6 @@ extern crate accelerate_src; mod encodec_model; mod musicgen_model; mod nn; -mod t5_model; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; @@ -78,7 +77,7 @@ fn main() -> Result<()> { let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); - let model = MusicgenForConditionalGeneration::load(vb, config)?; + let mut model = MusicgenForConditionalGeneration::load(vb, config)?; let tokens = tokenizer .encode(args.prompt.as_str(), true) diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index 7e272fd7..d6d8ae15 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -1,9 +1,10 @@ -use crate::{encodec_model, t5_model}; +use crate::encodec_model; use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{ embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, VarBuilder, }; +use candle_transformers::models::t5; // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 #[derive(Debug, Clone, PartialEq)] @@ -370,7 +371,7 @@ impl MusicgenForCausalLM { #[derive(Debug)] pub struct MusicgenForConditionalGeneration { - pub text_encoder: crate::t5_model::T5EncoderModel, + pub text_encoder: t5::T5EncoderModel, pub audio_encoder: crate::encodec_model::EncodecModel, pub decoder: MusicgenForCausalLM, cfg: GenConfig, @@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration { #[derive(Debug, Clone, PartialEq)] pub struct GenConfig { musicgen: Config, - t5: crate::t5_model::Config, + t5: t5::Config, encodec: crate::encodec_model::Config, } @@ -387,7 +388,7 @@ impl GenConfig { pub fn small() -> Self { Self { musicgen: Config::musicgen_small(), - t5: t5_model::Config::musicgen_small(), + t5: t5::Config::musicgen_small(), encodec: encodec_model::Config::musicgen_small(), } } @@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration { } pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { - let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; + let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; let audio_encoder = encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs deleted file mode 100644 index 607b5c93..00000000 --- a/candle-examples/examples/musicgen/t5_model.rs +++ /dev/null @@ -1,397 +0,0 @@ -// T5 Text Encoder -// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py - -use 
candle::{DType, Result, Tensor, D}; -use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder}; -use std::sync::Arc; - -#[derive(Debug, Clone, PartialEq)] -pub struct Config { - vocab_size: usize, - d_model: usize, - d_kv: usize, - d_ff: usize, - num_layers: usize, - num_decoder_layers: Option<usize>, - num_heads: usize, - relative_attention_num_buckets: usize, - relative_attention_max_distance: usize, - dropout_rate: f64, - layer_norm_epsilon: f64, - initializer_factor: f64, - feed_forward_proj: Activation, - is_decoder: bool, - is_encoder_decoder: bool, - use_cache: bool, - pad_token_id: usize, - eos_token_id: usize, -} - -impl Default for Config { - fn default() -> Self { - Self { - vocab_size: 32128, - d_model: 512, - d_kv: 64, - d_ff: 2048, - num_layers: 6, - num_decoder_layers: None, - num_heads: 8, - relative_attention_num_buckets: 32, - relative_attention_max_distance: 128, - dropout_rate: 0.1, - layer_norm_epsilon: 1e-6, - initializer_factor: 1.0, - feed_forward_proj: Activation::Relu, - is_decoder: false, - is_encoder_decoder: true, - use_cache: true, - pad_token_id: 0, - eos_token_id: 1, - } - } -} - -impl Config { - // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184 - pub fn musicgen_small() -> Self { - Self { - d_ff: 3072, - d_kv: 64, - d_model: 768, - dropout_rate: 0.1, - eos_token_id: 1, - feed_forward_proj: Activation::Relu, - initializer_factor: 1.0, - is_decoder: false, - is_encoder_decoder: true, - layer_norm_epsilon: 1e-6, - num_decoder_layers: Some(12), - num_heads: 12, - num_layers: 12, - pad_token_id: 0, - relative_attention_max_distance: 128, - relative_attention_num_buckets: 32, - use_cache: true, - vocab_size: 32128, - } - } -} - -#[derive(Debug)] -struct T5LayerNorm { - weight: Tensor, - variance_epsilon: f64, -} - -impl T5LayerNorm { - fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let weight = vb.get(h, "weight")?; - Ok(Self { - weight, - variance_epsilon: eps, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let dtype = xs.dtype(); - let xs_f32 = xs.to_dtype(DType::F32)?; - // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; - let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; - let xs = xs.to_dtype(dtype)?; - let xs = xs.broadcast_mul(&self.weight)?; - Ok(xs) - } -} - -#[derive(Debug)] -struct T5DenseActDense { - wi: Linear, - wo: Linear, - act: Activation, -} - -impl T5DenseActDense { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; - let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; - Ok(Self { - wi, - wo, - act: Activation::Relu, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.wi.forward(xs)?; - let xs = self.act.forward(&xs)?; - let xs = self.wo.forward(&xs)?; - Ok(xs) - } -} - -#[derive(Debug)] -struct T5LayerFF { - dense_relu_dense: T5DenseActDense, - layer_norm: T5LayerNorm, -} - -impl T5LayerFF { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - // is_gated_act is not supported. 
- let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?; - let layer_norm = - T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; - Ok(Self { - dense_relu_dense, - layer_norm, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let ys = self.layer_norm.forward(xs)?; - let ys = self.dense_relu_dense.forward(&ys)?; - let xs = (xs + ys)?; - Ok(xs) - } -} - -#[derive(Debug)] -struct T5Attention { - q: Linear, - k: Linear, - v: Linear, - o: Linear, - n_heads: usize, - d_kv: usize, - relative_attention_bias: Option<Embedding>, - relative_attention_num_buckets: usize, - relative_attention_max_distance: usize, - inner_dim: usize, -} - -impl T5Attention { - fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let inner_dim = cfg.num_heads * cfg.d_kv; - let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?; - let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?; - let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?; - let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?; - let relative_attention_bias = if h { - let emb = embedding( - cfg.relative_attention_num_buckets, - cfg.num_heads, - vb.pp("relative_attention_bias"), - )?; - Some(emb) - } else { - None - }; - Ok(Self { - q, - k, - v, - o, - n_heads: cfg.num_heads, - d_kv: cfg.d_kv, - relative_attention_bias, - relative_attention_num_buckets: cfg.relative_attention_num_buckets, - relative_attention_max_distance: cfg.relative_attention_max_distance, - inner_dim, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - // TODO: Apply the mask(s)? - // TODO: kv caching. - let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?); - let q = self.q.forward(xs)?; - let k = self.k.forward(xs)?; - let v = self.v.forward(xs)?; - let q = q - .reshape((b_sz, seq_len, self.n_heads, self.d_kv))? - .transpose(1, 2)? - .contiguous()?; - let k = k - .reshape((b_sz, seq_len, self.n_heads, self.d_kv))? - .transpose(1, 2)? - .contiguous()?; - let v = v - .reshape((b_sz, seq_len, self.n_heads, self.d_kv))? - .transpose(1, 2)? - .contiguous()?; - let scores = q.matmul(&k.t()?)?; - - let scores = match &self.relative_attention_bias { - None => scores, - Some(relative_attention_bias) => { - let query_length = seq_len; - let key_length = seq_len; - // This only handles the bidirectional case. - let num_buckets = self.relative_attention_num_buckets / 2; - let relative_position = (0..query_length as u32) - .map(|i| { - (0..key_length as u32) - .map(|j| { - if i < j { - j - i + num_buckets as u32 - } else { - i - j - } - }) - .collect::<Vec<u32>>() - }) - .collect::<Vec<Vec<_>>>(); - let relative_buckets = Tensor::new(relative_position, q.device())?; - let position_bias = relative_attention_bias - .forward(&relative_buckets)? - .permute((2, 0, 1))? - .unsqueeze(0)?; - (scores + position_bias)? - // TODO: position_bias_masked? - } - }; - - let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?; - let attn_output = attn_weights.matmul(&v)?; - let attn_output = attn_output - .transpose(1, 2)? 
- .reshape((b_sz, seq_len, self.inner_dim))?; - let attn_output = self.o.forward(&attn_output)?; - Ok(attn_output) - } -} - -#[derive(Debug)] -struct T5LayerSelfAttention { - self_attention: T5Attention, - layer_norm: T5LayerNorm, -} - -impl T5LayerSelfAttention { - fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?; - let layer_norm = - T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; - Ok(Self { - self_attention, - layer_norm, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let normed_xs = self.layer_norm.forward(xs)?; - let ys = self.self_attention.forward(&normed_xs)?; - let ys = (xs + ys)?; - Ok(ys) - } -} - -#[derive(Debug)] -struct T5LayerCrossAttention {} - -impl T5LayerCrossAttention { - fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> { - todo!() - } - - fn forward(&self, _xs: &Tensor) -> Result<Tensor> { - todo!() - } -} - -#[derive(Debug)] -struct T5Block { - self_attn: T5LayerSelfAttention, - cross_attn: Option<T5LayerCrossAttention>, - ff: T5LayerFF, -} - -impl T5Block { - fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let vb = vb.pp("layer"); - let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?; - let cross_attn = if cfg.is_decoder { - Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?) - } else { - None - }; - let ff_i = if cross_attn.is_some() { 2 } else { 1 }; - let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?; - Ok(Self { - self_attn, - cross_attn, - ff, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.self_attn.forward(xs)?; - // TODO: clamp for f16? - if let Some(cross_attn) = &self.cross_attn { - xs = cross_attn.forward(&xs)?; - // TODO: clamp for f16? - } - let xs = self.ff.forward(&xs)?; - // TODO: clamp for f16? - Ok(xs) - } -} - -#[derive(Debug)] -struct T5Stack { - block: Vec<T5Block>, - shared: Arc<Embedding>, - final_layer_norm: T5LayerNorm, -} - -impl T5Stack { - fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { - let block = (0..cfg.num_layers) - .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg)) - .collect::<Result<Vec<_>>>()?; - let final_layer_norm = T5LayerNorm::load( - cfg.d_model, - cfg.layer_norm_epsilon, - vb.pp("final_layer_norm"), - )?; - Ok(Self { - block, - shared: shared.clone(), - final_layer_norm, - }) - } - - fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { - let input_embeds = self.shared.as_ref().forward(input_ids)?; - let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?); - - let mut hidden_states = input_embeds; - for block in self.block.iter() { - hidden_states = block.forward(&hidden_states)? 
-        }
-        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        Ok(hidden_states)
-    }
-}
-
-#[derive(Debug)]
-pub struct T5EncoderModel {
-    shared: Arc<Embedding>,
-    encoder: T5Stack,
-}
-
-impl T5EncoderModel {
-    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
-        let shared = Arc::new(shared);
-        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
-        Ok(Self { shared, encoder })
-    }
-
-    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let encoder_outputs = self.encoder.forward(input_ids)?;
-        Ok(encoder_outputs)
-    }
-}
diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md
new file mode 100644
index 00000000..1f6b99eb
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/README.md
@@ -0,0 +1,17 @@
+# candle-quantized-t5
+
+This example uses a quantized version of the T5 model.
+
+```bash
+$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
+...
+ Eine schöne Kerze.
+```
+
+The weight file is automatically retrieved from the hub. It is also possible to
+generate quantized weight files from the original safetensors file by using the
+`tensor-tools` command line utility via:
+
+```bash
+cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+```
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
new file mode 100644
index 00000000..93a86309
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -0,0 +1,214 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
+use candle_transformers::models::quantized_t5 as t5;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    T5Small,
+    FlanT5Small,
+    FlanT5Base,
+    FlanT5Large,
+    FlanT5Xl,
+    FlanT5Xxl,
+}
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model repository to use on the HuggingFace hub.
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+
+    // Disable the key-value cache during decoding.
+    #[arg(long, default_value = "false")]
+    disable_cache: bool,
+
+    /// The prompt to be used.
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+
+    /// The model size to use.
+ #[arg(long, default_value = "t5-small")] + which: Which, +} + +struct T5ModelBuilder { + device: Device, + config: t5::Config, + weights_filename: PathBuf, +} + +impl T5ModelBuilder { + pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { + let device = Device::Cpu; + let default_model = "lmz/candle-quantized-t5".to_string(); + let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_string()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, "main".to_string()), + }; + + let repo = Repo::with_revision(model_id, RepoType::Model, revision); + let api = Api::new()?; + let api = api.repo(repo); + let config_filename = match args.which { + Which::T5Small => api.get("config.json")?, + Which::FlanT5Small => api.get("config-flan-t5-small.json")?, + Which::FlanT5Base => api.get("config-flan-t5-base.json")?, + Which::FlanT5Large => api.get("config-flan-t5-large.json")?, + Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, + Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, + }; + let tokenizer_filename = api.get("tokenizer.json")?; + let weights_filename = match &args.weight_file { + Some(filename) => std::path::PathBuf::from(filename), + None => match args.which { + Which::T5Small => api.get("model.gguf")?, + Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?, + Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?, + Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?, + Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?, + Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?, + }, + }; + let config = std::fs::read_to_string(config_filename)?; + let mut config: t5::Config = serde_json::from_str(&config)?; + config.use_cache = !args.disable_cache; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + Ok(( + Self { + device, + config, + weights_filename, + }, + tokenizer, + )) + } + + pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> { + let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?; + Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) + } +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?; + let device = &builder.device; + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; + let tokens = tokenizer + .encode(args.prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; + let mut model = builder.build_model()?; + let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let temperature = if args.temperature <= 0. { + None + } else { + Some(args.temperature) + }; + let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p); + let encoder_output = model.encode(&input_token_ids)?; + let start = std::time::Instant::now(); + + for index in 0.. 
{ + if output_token_ids.len() > 512 { + break; + } + let decoder_token_ids = if index == 0 || !builder.config.use_cache { + Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? + } else { + let last_token = *output_token_ids.last().unwrap(); + Tensor::new(&[last_token], device)?.unsqueeze(0)? + }; + let logits = model + .decode(&decoder_token_ids, &encoder_output)? + .squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &output_token_ids[start_at..], + )? + }; + + let next_token_id = logits_processor.sample(&logits)?; + if next_token_id as usize == builder.config.eos_token_id { + break; + } + output_token_ids.push(next_token_id); + if let Some(text) = tokenizer.id_to_token(next_token_id) { + let text = text.replace('▁', " ").replace("<0x0A>", "\n"); + print!("{text}"); + std::io::stdout().flush()?; + } + } + let dt = start.elapsed(); + println!( + "\n{} tokens generated ({:.2} token/s)\n", + output_token_ids.len(), + output_token_ids.len() as f64 / dt.as_secs_f64(), + ); + Ok(()) +} diff --git a/candle-examples/examples/quantized/README.md b/candle-examples/examples/quantized/README.md new file mode 100644 index 00000000..bed09243 --- /dev/null +++ b/candle-examples/examples/quantized/README.md @@ -0,0 +1,37 @@ +# candle-quantized-llama: Fast Inference of quantized LLaMA models + +This example provides a quantized LLaMA model similar to +[llama.cpp](https://github.com/ggerganov/llama.cpp). This is based on candle +built-in quantization methods. Supported features include: + +- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support. +- SIMD optimizations on Apple Silicon and x86. +- Support using the `gguf` and `ggml` file formats. + +The weights are automatically downloaded for you from the [HuggingFace +Hub](https://huggingface.co/) on the first run. There are various command line +flags to use local files instead, run with `--help` to learn about them. + + + +## Running some example. + +```bash +cargo run --example quantized --release -- --prompt "The best thing about coding in rust is " + +> avx: true, neon: false, simd128: false, f16c: true +> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64 +> loaded 291 tensors (3.79GB) in 2.17s +> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 } +> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines. +``` + +## Command-line flags + +Run with `--help` to see all options. + +- `--which`: specify the model to use, e.g. `7b`, `13-chat`, `7b-code`. +- `--prompt interactive`: interactive mode where multiple prompts can be + entered. +- `--model mymodelfile.gguf`: use a local model file rather than getting one + from the hub. 
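+
+For instance, interactive mode combined with a local weight file (both flags
+described above) would look like:
+
+```bash
+cargo run --example quantized --release -- \
+  --model mymodelfile.gguf \
+  --prompt interactive
+```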
diff --git a/candle-examples/examples/quantized/assets/aoc.gif b/candle-examples/examples/quantized/assets/aoc.gif
Binary files differ
new file mode 100644
index 00000000..686074af
--- /dev/null
+++ b/candle-examples/examples/quantized/assets/aoc.gif
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a3f98d8e..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
 
-mod model;
+use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
 
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md
new file mode 100644
index 00000000..3c5b034f
--- /dev/null
+++ b/candle-examples/examples/segment-anything/README.md
@@ -0,0 +1,40 @@
+# candle-segment-anything: Segment-Anything Model
+
+This example is based on the Meta AI [Segment-Anything
+Model](https://github.com/facebookresearch/segment-anything). The model
+provides a robust and fast image segmentation pipeline that can be tweaked via
+prompting (requesting some points to be in the target mask, requesting some
+points to be part of the background and so _not_ in the target mask, or
+specifying a bounding box).
+
+The default backbone can be replaced by the smaller and faster TinyViT model
+based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
+
+## Running some example.
+
+```bash
+cargo run --example segment-anything --release -- \
+  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
+  --use-tiny \
+  --point-x 0.4 \
+  --point-y 0.3
+```
+
+Running this command generates a `sam_merged.jpg` file containing the original
+image with a blue overlay of the selected mask. The red dot represents the
+prompt specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be
+part of the target mask.
+
+The values used for `--point-x` and `--point-y` should be between 0 and 1 and
+are proportional to the image dimensions, i.e. use 0.5 for the image center.
+
+
+
+### Command-line flags
+- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the
+  default one.
+- `--point-x`, `--point-y`: specify the location of the target point.
+- `--threshold`: sets the threshold for a pixel to be part of the mask; a
+  negative value results in a larger mask and can be specified via
+  `--threshold=-1.2`.
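+
+The example can also propose masks for the whole image instead of a single
+prompted mask, via the `--generate-masks` flag defined in the example's
+`main.rs`; each candidate is written out as a separate `sam_mask<idx>.png`
+file:
+
+```bash
+cargo run --example segment-anything --release -- \
+  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
+  --generate-masks
+```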
diff --git a/candle-examples/examples/segment-anything/assets/sam_merged.jpg b/candle-examples/examples/segment-anything/assets/sam_merged.jpg Binary files differnew file mode 100644 index 00000000..a5f64e5e --- /dev/null +++ b/candle-examples/examples/segment-anything/assets/sam_merged.jpg diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs new file mode 100644 index 00000000..3d9898b6 --- /dev/null +++ b/candle-examples/examples/segment-anything/main.rs @@ -0,0 +1,164 @@ +//! SAM: Segment Anything Model +//! https://github.com/facebookresearch/segment-anything + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::DType; +use candle_nn::VarBuilder; +use candle_transformers::models::segment_anything::sam; +use clap::Parser; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + generate_masks: bool, + + /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image). + #[arg(long, default_value_t = 0.5)] + point_x: f64, + + /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image). + #[arg(long, default_value_t = 0.5)] + point_y: f64, + + /// The detection threshold for the mask, 0 is the default value, negative values mean a larger + /// mask, positive makes the mask more selective. + #[arg(long, default_value_t = 0.)] + threshold: f32, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Use the TinyViT based models from MobileSAM + #[arg(long)] + use_tiny: bool, +} + +pub fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let device = candle_examples::device(args.cpu)?; + + let (image, initial_h, initial_w) = + candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; + let image = image.to_device(&device)?; + println!("loaded image {image:?}"); + + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-sam".to_string()); + let filename = if args.use_tiny { + "mobile_sam-tiny-vitt.safetensors" + } else { + "sam_vit_b_01ec64.safetensors" + }; + api.get(filename)? + } + }; + let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; + let weights = weights.deserialize()?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let sam = if args.use_tiny { + sam::Sam::new_tiny(vb)? // tiny vit_t + } else { + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + }; + + if args.generate_masks { + // Default options similar to the Python version. + let bboxes = sam.generate_masks( + &image, + /* points_per_side */ 32, + /* crop_n_layer */ 0, + /* crop_overlap_ratio */ 512. / 1500., + /* crop_n_points_downscale_factor */ 1, + )?; + for (idx, bbox) in bboxes.iter().enumerate() { + println!("{idx} {bbox:?}"); + let mask = (&bbox.data.to_dtype(DType::U8)? 
* 255.)?; + let (h, w) = mask.dims2()?; + let mask = mask.broadcast_as((3, h, w))?; + candle_examples::save_image_resize( + &mask, + format!("sam_mask{idx}.png"), + initial_h, + initial_w, + )?; + } + } else { + let point = Some((args.point_x, args.point_y)); + let start_time = std::time::Instant::now(); + let (mask, iou_predictions) = sam.forward(&image, point, false)?; + println!( + "mask generated in {:.2}s", + start_time.elapsed().as_secs_f32() + ); + println!("mask:\n{mask}"); + println!("iou_predictions: {iou_predictions:?}"); + + let mask = (mask.ge(args.threshold)? * 255.)?; + let (_one, h, w) = mask.dims3()?; + let mask = mask.expand((3, h, w))?; + + let mut img = image::io::Reader::open(&args.image)? + .decode() + .map_err(candle::Error::wrap)?; + let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?; + let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) { + Some(image) => image, + None => anyhow::bail!("error saving merged image"), + }; + let mask_img = image::DynamicImage::from(mask_img).resize_to_fill( + img.width(), + img.height(), + image::imageops::FilterType::CatmullRom, + ); + for x in 0..img.width() { + for y in 0..img.height() { + let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y); + if mask_p.0[0] > 100 { + let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y); + img_p.0[2] = 255 - (255 - img_p.0[2]) / 2; + img_p.0[1] /= 2; + img_p.0[0] /= 2; + imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p) + } + } + } + let (x, y) = ( + (args.point_x * img.width() as f64) as i32, + (args.point_y * img.height() as f64) as i32, + ); + imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200])) + .save("sam_merged.jpg")? + } + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md new file mode 100644 index 00000000..ee83b3f9 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/README.md @@ -0,0 +1,63 @@ +# candle-stable-diffusion: A Diffusers API in Rust/Candle + + + +_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion +XL using Rust and [candle](https://github.com/huggingface/candle). + +The `stable-diffusion` example is a conversion of +[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle +rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1, +as well as Stable Diffusion XL 1.0. + +## Getting the weights + +The weights are automatically downloaded for you from the [HuggingFace +Hub](https://huggingface.co/) on the first run. There are various command line +flags to use local files instead, run with `--help` to learn about them. + +## Running some example. + +```bash +cargo run --example stable-diffusion --release --features=cuda,cudnn \ + -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" +``` + +The final image is named `sd_final.png` by default. +The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The +original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). + +### Command-line flags + +- `--prompt`: the prompt to be used to generate the image. +- `--uncond-prompt`: the optional unconditional prompt. +- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or + `xl`. +- `--cpu`: use the cpu rather than the gpu (much slower). 
+- `--height`, `--width`: set the height and width for the generated image.
+- `--n-steps`: the number of steps to be used in the diffusion process.
+- `--num-samples`: the number of samples to generate.
+- `--final-image`: the filename for the generated image(s).
+
+### Using flash-attention
+
+Using flash attention makes image generation a lot faster and uses less
+memory. The downside is some long compilation time. You can set the
+`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
+`/home/user/.candle` to ensure that the compilation artifacts are properly
+cached.
+
+Enabling flash-attention requires both a compile-time feature flag,
+`--features flash-attn`, and the command line flag `--use-flash-attn`.
+
+## Image to Image Pipeline
+...
+
+## FAQ
+
+### Memory Issues
+
+This requires a GPU with more than 8GB of memory; as a fallback, the CPU
+version can be used with the `--cpu` flag but is much slower.
+Alternatively, reducing the height and width with the `--height` and `--width`
+flags is likely to reduce memory usage significantly.
diff --git a/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg b/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Binary files differ
new file mode 100644
index 00000000..a6f7b6c6
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 8372edcd..c8b771a0 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -4,20 +4,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod attention;
-mod clip;
-mod ddim;
-mod embeddings;
-mod resnet;
-mod schedulers;
-mod stable_diffusion;
-mod unet_2d;
-mod unet_2d_blocks;
-mod utils;
-mod vae;
+use candle_transformers::models::stable_diffusion;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
 
@@ -96,6 +86,15 @@ struct Args {
 
     #[arg(long)]
     use_f16: bool,
+
+    #[arg(long, value_name = "FILE")]
+    img2img: Option<String>,
+
+    /// The strength indicates how much to transform the initial image. The
+    /// value must be between 0 and 1; a value of 1 discards the initial image
+    /// information.
+    #[arg(long, default_value_t = 0.8)]
+    img2img_strength: f64,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -306,6 +305,26 @@ fn text_embeddings(
     Ok(text_embeddings)
 }
 
+fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+    let img = image::io::Reader::open(path)?.decode()?;
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let height = height - height % 32;
+    let width = width - width % 32;
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::CatmullRom,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?
+        .unsqueeze(0)?;
+    Ok(img)
+}
+
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -328,9 +347,15 @@ fn run(args: Args) -> Result<()> {
         tracing,
         use_f16,
         use_flash_attn,
+        img2img,
+        img2img_strength,
         ..
     } = args;
 
+    if !(0.
..=1.).contains(&img2img_strength) {
+        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
+    }
+
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -382,25 +407,53 @@ fn run(args: Args) -> Result<()> {
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
     let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let init_latent_dist = match &img2img {
+        None => None,
+        Some(image) => {
+            let image = image_preprocess(image)?.to_device(&device)?;
+            Some(vae.encode(&image)?)
+        }
+    };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
     let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
+    let t_start = if img2img.is_some() {
+        n_steps - (n_steps as f64 * img2img_strength) as usize
+    } else {
+        0
+    };
     let bsize = 1;
     for idx in 0..num_samples {
-        let mut latents = Tensor::randn(
-            0f32,
-            1f32,
-            (bsize, 4, sd_config.height / 8, sd_config.width / 8),
-            &device,
-        )?
-        .to_dtype(dtype)?;
-
-        // scale the initial noise by the standard deviation required by the scheduler
-        latents = (latents * scheduler.init_noise_sigma())?;
+        let timesteps = scheduler.timesteps();
+        let latents = match &init_latent_dist {
+            Some(init_latent_dist) => {
+                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                if t_start < timesteps.len() {
+                    let noise = latents.randn_like(0f64, 1f64)?;
+                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
+                } else {
+                    latents
+                }
+            }
+            None => {
+                let latents = Tensor::randn(
+                    0f32,
+                    1f32,
+                    (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+                    &device,
+                )?;
+                // scale the initial noise by the standard deviation required by the scheduler
+                (latents * scheduler.init_noise_sigma())?
+            }
+        };
+        let mut latents = latents.to_dtype(dtype)?;
 
         println!("starting sampling");
-        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+        for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+            if timestep_index < t_start {
+                continue;
+            }
             let start_time = std::time::Instant::now();
             let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
new file mode 100644
index 00000000..6a406467
--- /dev/null
+++ b/candle-examples/examples/t5/README.md
@@ -0,0 +1,25 @@
+# candle-t5
+
+## Encoder-decoder example:
+
+```bash
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+...
+Running on CPU, to run on GPU, build this example with `--features cuda`
+ Eine schöne Kerze.
+9 tokens generated (2.42 token/s)
+```
+
+## Sentence embedding example:
+
+```bash
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
+...
+[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
+ [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
+ [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
+ [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
+ [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
+Tensor[[1, 5, 512], f32]
+Took 303.766583ms
+```
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
new file mode 100644
index 00000000..55929c33
--- /dev/null
+++ b/candle-examples/examples/t5/main.rs
@@ -0,0 +1,314 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
+use candle_transformers::models::t5;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+const DTYPE: DType = DType::F32;
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model repository to use on the HuggingFace hub.
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// Enable decoding.
+    #[arg(long)]
+    decode: bool,
+
+    /// Disable the use of the key-value cache during decoding.
+    #[arg(long, default_value = "false")]
+    disable_cache: bool,
+
+    /// Use this prompt, otherwise compute sentence similarities.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// If set along with --decode, will use this prompt to initialize the decoder.
+    #[arg(long)]
+    decoder_prompt: Option<String>,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty. 
+ #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +struct T5ModelBuilder { + device: Device, + config: t5::Config, + weights_filename: Vec<PathBuf>, +} + +impl T5ModelBuilder { + pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { + let device = candle_examples::device(args.cpu)?; + let default_model = "t5-small".to_string(); + let default_revision = "refs/pr/15".to_string(); + let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_string()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, default_revision), + }; + + let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision); + let api = Api::new()?; + let api = api.repo(repo); + let config_filename = api.get("config.json")?; + let tokenizer_filename = api.get("tokenizer.json")?; + let weights_filename = if model_id == "google/flan-t5-xxl" { + vec![ + api.get("model-00001-of-00005.safetensors")?, + api.get("model-00002-of-00005.safetensors")?, + api.get("model-00003-of-00005.safetensors")?, + api.get("model-00004-of-00005.safetensors")?, + api.get("model-00005-of-00005.safetensors")?, + ] + } else { + vec![api.get("model.safetensors")?] + }; + let config = std::fs::read_to_string(config_filename)?; + let mut config: t5::Config = serde_json::from_str(&config)?; + config.use_cache = !args.disable_cache; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + Ok(( + Self { + device, + config, + weights_filename, + }, + tokenizer, + )) + } + + pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> { + let weights = self + .weights_filename + .iter() + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) + .collect::<candle::Result<Vec<_>>>()?; + let weights = weights + .iter() + .map(|w| w.deserialize()) + .collect::<candle::Result<Vec<_>>>()?; + let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + Ok(t5::T5EncoderModel::load(vb, &self.config)?) + } + + pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> { + let weights = self + .weights_filename + .iter() + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) + .collect::<candle::Result<Vec<_>>>()?; + let weights = weights + .iter() + .map(|w| w.deserialize()) + .collect::<candle::Result<Vec<_>>>()?; + let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) + } +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?; + let device = &builder.device; + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; + match args.prompt { + Some(prompt) => { + let tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? 
+ .get_ids() + .to_vec(); + let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; + if !args.decode { + let mut model = builder.build_encoder()?; + let start = std::time::Instant::now(); + let ys = model.forward(&input_token_ids)?; + println!("{ys}"); + println!("Took {:?}", start.elapsed()); + } else { + let mut model = builder.build_conditional_generation()?; + let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + if let Some(decoder_prompt) = &args.decoder_prompt { + print!("{decoder_prompt}"); + output_token_ids.extend( + tokenizer + .encode(decoder_prompt.to_string(), false) + .map_err(E::msg)? + .get_ids() + .to_vec(), + ); + } + let temperature = if args.temperature <= 0. { + None + } else { + Some(args.temperature) + }; + let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p); + let encoder_output = model.encode(&input_token_ids)?; + let start = std::time::Instant::now(); + + for index in 0.. { + if output_token_ids.len() > 512 { + break; + } + let decoder_token_ids = if index == 0 || !builder.config.use_cache { + Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? + } else { + let last_token = *output_token_ids.last().unwrap(); + Tensor::new(&[last_token], device)?.unsqueeze(0)? + }; + let logits = model + .decode(&decoder_token_ids, &encoder_output)? + .squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &output_token_ids[start_at..], + )? + }; + + let next_token_id = logits_processor.sample(&logits)?; + if next_token_id as usize == builder.config.eos_token_id { + break; + } + output_token_ids.push(next_token_id); + if let Some(text) = tokenizer.id_to_token(next_token_id) { + let text = text.replace('▁', " ").replace("<0x0A>", "\n"); + print!("{text}"); + std::io::stdout().flush()?; + } + } + let dt = start.elapsed(); + println!( + "\n{} tokens generated ({:.2} token/s)\n", + output_token_ids.len(), + output_token_ids.len() as f64 / dt.as_secs_f64(), + ); + } + } + None => { + let mut model = builder.build_encoder()?; + let sentences = [ + "The cat sits outside", + "A man is playing guitar", + "I love pasta", + "The new movie is awesome", + "The cat plays in the garden", + "A woman watches TV", + "The new movie is so great", + "Do you like pizza?", + ]; + let n_sentences = sentences.len(); + let mut all_embeddings = Vec::with_capacity(n_sentences); + for sentence in sentences { + let tokens = tokenizer + .encode(sentence, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?; + let embeddings = model.forward(&token_ids)?; + println!("generated embeddings {:?}", embeddings.shape()); + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; + let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; + let embeddings = if args.normalize_embeddings { + normalize_l2(&embeddings)? 
+                } else {
+                    embeddings
+                };
+                println!("pooled embeddings {:?}", embeddings.shape());
+                all_embeddings.push(embeddings)
+            }
+
+            let mut similarities = vec![];
+            for (i, e_i) in all_embeddings.iter().enumerate() {
+                for (j, e_j) in all_embeddings
+                    .iter()
+                    .enumerate()
+                    .take(n_sentences)
+                    .skip(i + 1)
+                {
+                    let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
+                    let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
+                    let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
+                    let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                    similarities.push((cosine_similarity, i, j))
+                }
+            }
+            similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+            for &(score, i, j) in similarities[..5].iter() {
+                println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+            }
+        }
+    }
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-examples/examples/whisper/README.md b/candle-examples/examples/whisper/README.md
new file mode 100644
index 00000000..124cd182
--- /dev/null
+++ b/candle-examples/examples/whisper/README.md
@@ -0,0 +1,39 @@
+# candle-whisper: speech recognition
+
+An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
+candle. Whisper is a general-purpose speech recognition model that can be used
+to convert audio files (in the `.wav` format) to text. Supported features
+include language detection as well as multilingual speech recognition.
+
+## Running an example
+
+If no audio file is passed as input, a [sample
+file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
+from the hub.
+
+```bash
+cargo run --example whisper --release
+
+> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
+> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
+> pcm data loaded 176000
+> loaded mel: [1, 80, 3000]
+> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
+```
+
+In order to use the multilingual mode, specify a multilingual model via the
+`--model` flag; see the details below.
+
+## Command line flags
+
+- `--input`: the audio file to be converted to text, in wav format.
+- `--language`: force the language to some specific value rather than being
+  detected, e.g. `en`.
+- `--task`: the task to be performed, can be `transcribe` (return the text data
+  in the original language) or `translate` (translate the text to English).
+- `--timestamps`: enable the timestamp mode where some timestamps are reported
+  for each recognized audio extract.
+- `--model`: the model to be used. Models whose names do not end with `.en` are
+  multilingual models; the others are English-only. The supported models
+  are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
+  `medium.en`, `large`, and `large-v2`.
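+
+Putting these flags together, a hypothetical invocation that transcribes a
+local French recording with a multilingual model and timestamps enabled could
+look as follows (the audio path and the `fr` language code are placeholders;
+the flags themselves are the ones documented above):
+
+```bash
+cargo run --example whisper --release -- \
+  --model medium --input /path/to/recording.wav \
+  --language fr --task transcribe --timestamps
+```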
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5dd8ee20..c71d562a 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -10,41 +10,16 @@ extern crate accelerate_src; extern crate intel_mkl_src; use anyhow::{Error as E, Result}; -use candle::{DType, Device, IndexOp, Tensor}; +use candle::{Device, IndexOp, Tensor}; use candle_nn::{ops::softmax, VarBuilder}; use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; -mod audio; -mod model; -use model::{Config, Whisper}; mod multilingual; - -const DTYPE: DType = DType::F32; - -// Audio parameters. -const SAMPLE_RATE: usize = 16000; -const N_FFT: usize = 400; -const N_MELS: usize = 80; -const HOP_LENGTH: usize = 160; -const CHUNK_LENGTH: usize = 30; -const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk -const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input - -const NO_SPEECH_THRESHOLD: f64 = 0.6; -const LOGPROB_THRESHOLD: f64 = -1.0; -const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; -const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; - -// Tokenizer dependent bits. -const SOT_TOKEN: &str = "<|startoftranscript|>"; -const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; -const TRANSLATE_TOKEN: &str = "<|translate|>"; -const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; -const EOT_TOKEN: &str = "<|endoftext|>"; -const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; +use candle_transformers::models::whisper::{self as m, audio, model}; +use model::{Config, Whisper}; #[allow(dead_code)] #[derive(Debug, Clone)] @@ -94,7 +69,7 @@ impl Decoder { timestamps: bool, verbose: bool, ) -> Result<Self> { - let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; + let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?; // Suppress the notimestamps token when in timestamps mode. // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452 let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) @@ -109,11 +84,11 @@ impl Decoder { }) .collect(); let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; - let sot_token = token_id(&tokenizer, SOT_TOKEN)?; - let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; - let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; - let eot_token = token_id(&tokenizer, EOT_TOKEN)?; - let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; + let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?; + let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?; + let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?; + let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?; + let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?; Ok(Self { model, rng: rand::rngs::StdRng::seed_from_u64(seed), @@ -220,17 +195,17 @@ impl Decoder { } fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> { - for (i, &t) in TEMPERATURES.iter().enumerate() { + for (i, &t) in m::TEMPERATURES.iter().enumerate() { let dr: Result<DecodingResult> = self.decode(segment, t); - if i == TEMPERATURES.len() - 1 { + if i == m::TEMPERATURES.len() - 1 { return dr; } // On errors, we try again with a different temperature. 
match dr { Ok(dr) => { - let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD - || dr.avg_logprob < LOGPROB_THRESHOLD; - if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < m::LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD { return Ok(dr); } } @@ -248,13 +223,13 @@ impl Decoder { let mut segments = vec![]; while seek < content_frames { let start = std::time::Instant::now(); - let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let segment_size = usize::min(content_frames - seek, N_FRAMES); + let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, m::N_FRAMES); let mel_segment = mel.narrow(2, seek, segment_size)?; - let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; let dr = self.decode_with_fallback(&mel_segment)?; seek += segment_size; - if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD { println!("no speech detected, skipping {seek} {dr:?}"); continue; } @@ -431,7 +406,6 @@ fn main() -> Result<()> { let args = Args::parse(); let _guard = if args.tracing { - println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard) @@ -493,8 +467,8 @@ fn main() -> Result<()> { let mut input = std::fs::File::open(input)?; let (header, data) = wav::read(&mut input)?; println!("loaded wav data: {header:?}"); - if header.sampling_rate != SAMPLE_RATE as u32 { - anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) + if header.sampling_rate != m::SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE) } let data = data.as_sixteen().expect("expected 16 bit wav file"); let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] @@ -502,14 +476,14 @@ fn main() -> Result<()> { .map(|v| *v as f32 / 32768.) .collect(); println!("pcm data loaded {}", pcm_data.len()); - let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters); let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; + let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? 
    };
     let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = Whisper::load(&vb, config)?;
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index bc0bae1f..a82b09ef 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
-    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
diff --git a/candle-examples/examples/wuerstchen/README.md b/candle-examples/examples/wuerstchen/README.md
new file mode 100644
index 00000000..1b8accd1
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/README.md
@@ -0,0 +1,27 @@
+# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models
+
+
+
+The `wuerstchen` example is a port of the [diffusers
+implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
+The candle implementation reproduces the same structure/files for models and
+pipelines. Useful resources:
+
+- [Official implementation](https://github.com/dome272/Wuerstchen).
+- [arXiv paper](https://arxiv.org/abs/2306.00637).
+- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).
+
+## Getting the weights
+
+The weights are automatically downloaded for you from the [HuggingFace
+Hub](https://huggingface.co/) on the first run. There are various command line
+flags to use local files instead; run with `--help` to learn about them.
+
+## Running an example
+
+```bash
+cargo run --example wuerstchen --release --features cuda,cudnn -- \
+    --prompt "Anthropomorphic cat dressed as a fire fighter"
+```
+
+The final image is named `sd_final.png` by default.
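+
+As with the stable-diffusion example, generation can be customized via command
+line flags. A sketch of a customized run is shown below; the flag names are
+taken from the example's `Args` struct further down in this changeset, and the
+values are purely illustrative:
+
+```bash
+cargo run --example wuerstchen --release --features cuda,cudnn -- \
+    --prompt "Anthropomorphic cat dressed as a fire fighter" \
+    --height 768 --width 768 --num-samples 2 --final-image cat.png
+```
+
+With `--num-samples 2`, the sample index is spliced into the file name, so the
+outputs would be written to `cat.1.png` and `cat.2.png` (see the
+`output_filename` helper below).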
diff --git a/candle-examples/examples/wuerstchen/assets/cat.jpg b/candle-examples/examples/wuerstchen/assets/cat.jpg
Binary files differ
new file mode 100644
index 00000000..9ff67183
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/assets/cat.jpg
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
new file mode 100644
index 00000000..95f3b8f4
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -0,0 +1,396 @@
+#![allow(unused)]
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use candle_transformers::models::stable_diffusion;
+use candle_transformers::models::wuerstchen;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
+use clap::Parser;
+use tokenizers::Tokenizer;
+
+const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
+const RESOLUTION_MULTIPLE: f64 = 42.67;
+const LATENT_DIM_SCALE: f64 = 10.67;
+const PRIOR_CIN: usize = 16;
+const DECODER_CIN: usize = 4;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+    )]
+    prompt: String,
+
+    #[arg(long, default_value = "")]
+    uncond_prompt: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    /// The height in pixels of the generated image.
+    #[arg(long)]
+    height: Option<usize>,
+
+    /// The width in pixels of the generated image.
+    #[arg(long)]
+    width: Option<usize>,
+
+    /// The decoder weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    decoder_weights: Option<String>,
+
+    /// The CLIP weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    clip_weights: Option<String>,
+
+    /// The CLIP weight file used by the prior model, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_clip_weights: Option<String>,
+
+    /// The prior weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_weights: Option<String>,
+
+    /// The VQGAN weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    vqgan_weights: Option<String>,
+
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to be used for tokenization.
+    tokenizer: Option<String>,
+
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to be used for prior tokenization.
+    prior_tokenizer: Option<String>,
+
+    /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
+    #[arg(long)]
+    sliced_attention_size: Option<usize>,
+
+    /// The number of steps to run the diffusion for.
+    #[arg(long, default_value_t = 30)]
+    n_steps: usize,
+
+    /// The number of samples to generate.
+    #[arg(long, default_value_t = 1)]
+    num_samples: i64,
+
+    /// The name of the final image to generate. 
+ #[arg(long, value_name = "FILE", default_value = "sd_final.png")] + final_image: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + PriorTokenizer, + Clip, + PriorClip, + Decoder, + VqGan, + Prior, +} + +impl ModelFile { + fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> { + use hf_hub::api::sync::Api; + match filename { + Some(filename) => Ok(std::path::PathBuf::from(filename)), + None => { + let repo_main = "warp-ai/wuerstchen"; + let repo_prior = "warp-ai/wuerstchen-prior"; + let (repo, path) = match self { + Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"), + Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"), + Self::Clip => (repo_main, "text_encoder/model.safetensors"), + Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"), + Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"), + Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"), + Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"), + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +fn output_filename( + basename: &str, + sample_idx: i64, + num_samples: i64, + timestep_idx: Option<usize>, +) -> String { + let filename = if num_samples > 1 { + match basename.rsplit_once('.') { + None => format!("{basename}.{sample_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}.{sample_idx}.{extension}") + } + } + } else { + basename.to_string() + }; + match timestep_idx { + None => filename, + Some(timestep_idx) => match filename.rsplit_once('.') { + None => format!("{filename}-{timestep_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}-{timestep_idx}.{extension}") + } + }, + } +} + +fn encode_prompt( + prompt: &str, + uncond_prompt: Option<&str>, + tokenizer: std::path::PathBuf, + clip_weights: std::path::PathBuf, + clip_config: stable_diffusion::clip::Config, + device: &Device, +) -> Result<Tensor> { + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let pad_id = match &clip_config.pad_with { + Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; + println!("Running with prompt \"{prompt}\"."); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let tokens_len = tokens.len(); + while tokens.len() < clip_config.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + println!("Building the clip transformer."); + let text_model = + stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?; + let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?; + match uncond_prompt { + None => Ok(text_embeddings), + Some(uncond_prompt) => { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? 
+            .get_ids()
+            .to_vec();
+            let uncond_tokens_len = uncond_tokens.len();
+            while uncond_tokens.len() < clip_config.max_position_embeddings {
+                uncond_tokens.push(pad_id)
+            }
+            let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+            let uncond_embeddings =
+                text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+            let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
+            Ok(text_embeddings)
+        }
+    }
+}
+
+fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let Args {
+        prompt,
+        uncond_prompt,
+        cpu,
+        height,
+        width,
+        n_steps,
+        tokenizer,
+        final_image,
+        sliced_attention_size,
+        num_samples,
+        clip_weights,
+        prior_weights,
+        vqgan_weights,
+        decoder_weights,
+        tracing,
+        ..
+    } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    let device = candle_examples::device(cpu)?;
+    let height = height.unwrap_or(1024);
+    let width = width.unwrap_or(1024);
+
+    let prior_text_embeddings = {
+        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+        encode_prompt(
+            &prompt,
+            Some(&uncond_prompt),
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen_prior(),
+            &device,
+        )?
+    };
+    println!("generated prior text embeddings {prior_text_embeddings:?}");
+
+    let text_embeddings = {
+        let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
+        let weights = ModelFile::Clip.get(clip_weights)?;
+        encode_prompt(
+            &prompt,
+            None,
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen(),
+            &device,
+        )?
+    };
+    println!("generated text embeddings {text_embeddings:?}");
+
+    println!("Building the prior.");
+    let b_size = 1;
+    let image_embeddings = {
+        // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
+        let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+        let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+        let mut latents = Tensor::randn(
+            0f32,
+            1f32,
+            (b_size, PRIOR_CIN, latent_height, latent_width),
+            &device,
+        )?;
+
+        let prior = {
+            let prior_weights = ModelFile::Prior.get(prior_weights)?;
+            let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
+            let weights = weights.deserialize()?;
+            let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+            wuerstchen::prior::WPrior::new(
+                /* c_in */ PRIOR_CIN,
+                /* c */ 1536,
+                /* c_cond */ 1280,
+                /* c_r */ 64,
+                /* depth */ 32,
+                /* nhead */ 24,
+                args.use_flash_attn,
+                vb,
+            )?
+        };
+        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+        let timesteps = prior_scheduler.timesteps();
+        let timesteps = &timesteps[..timesteps.len() - 1];
+        println!("prior denoising");
+        for (index, &t) in timesteps.iter().enumerate() {
+            let start_time = std::time::Instant::now();
+            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
+            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
+            let noise_pred = noise_pred.chunk(2, 0)?;
+            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
+            let noise_pred = (noise_pred_uncond
+                + ((noise_pred_text - noise_pred_uncond)? 
* PRIOR_GUIDANCE_SCALE)?)?;
+            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+            let dt = start_time.elapsed().as_secs_f32();
+            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+        }
+        ((latents * 42.)? - 1.)?
+    };
+
+    println!("Building the vqgan.");
+    let vqgan = {
+        let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
+        let weights = weights.deserialize()?;
+        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        wuerstchen::paella_vq::PaellaVQ::new(vb)?
+    };
+
+    println!("Building the decoder.");
+
+    // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
+    let decoder = {
+        let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
+        let weights = weights.deserialize()?;
+        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        wuerstchen::diffnext::WDiffNeXt::new(
+            /* c_in */ DECODER_CIN,
+            /* c_out */ DECODER_CIN,
+            /* c_r */ 64,
+            /* c_cond */ 1024,
+            /* clip_embd */ 1024,
+            /* patch_size */ 2,
+            args.use_flash_attn,
+            vb,
+        )?
+    };
+
+    for idx in 0..num_samples {
+        // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
+        let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
+        let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
+
+        let mut latents = Tensor::randn(
+            0f32,
+            1f32,
+            (b_size, DECODER_CIN, latent_height, latent_width),
+            &device,
+        )?;
+
+        println!("diffusion process with prior {image_embeddings:?}");
+        let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
+        let timesteps = scheduler.timesteps();
+        let timesteps = &timesteps[..timesteps.len() - 1];
+        for (index, &t) in timesteps.iter().enumerate() {
+            let start_time = std::time::Instant::now();
+            let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
+            let noise_pred =
+                decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
+            latents = scheduler.step(&noise_pred, t, &latents)?;
+            let dt = start_time.elapsed().as_secs_f32();
+            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+        }
+        println!(
+            "Generating the final image for sample {}/{}.",
+            idx + 1,
+            num_samples
+        );
+        let image = vqgan.decode(&(&latents * 0.3764)?)?;
+        // TODO: Add the clamping between 0 and 1.
+        let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
+        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+        candle_examples::save_image(&image, image_filename)?
+    }
+    Ok(())
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    run(args)
+}
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 5e388921..ecf75bdf 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle_examples::object_detection::{non_maximum_suppression, Bbox};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
 mod darknet;
 
 use anyhow::Result;
@@ -46,7 +46,7 @@ pub fn report(
     let (npreds, pred_size) = pred.dims2()?;
     let nclasses = pred_size - 5;
     // The bounding boxes grouped by (maximum) class index.
-    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+    let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
     // Extract the bounding boxes for which confidence is above the threshold.
     for index in 0..npreds {
         let pred = Vec::<f32>::try_from(pred.get(index)?)?;
@@ -65,7 +65,7 @@
                 xmax: pred[0] + pred[2] / 2.,
                 ymax: pred[1] + pred[3] / 2.,
                 confidence,
-                keypoints: vec![],
+                data: (),
             };
             bboxes[class_index].push(bbox)
         }
diff --git a/candle-examples/examples/yolo-v8/README.md b/candle-examples/examples/yolo-v8/README.md
new file mode 100644
index 00000000..938dea13
--- /dev/null
+++ b/candle-examples/examples/yolo-v8/README.md
@@ -0,0 +1,47 @@
+# candle-yolo-v8: Object Detection and Pose Estimation
+
+This is a port of [Ultralytics
+YOLOv8](https://github.com/ultralytics/ultralytics). The implementation is based
+on the [tinygrad
+version](https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py)
+and on the model architecture described in this
+[issue](https://github.com/ultralytics/ultralytics/issues/189). The supported
+tasks are object detection and pose estimation.
+
+You can try this model online on the [Candle YOLOv8
+Space](https://huggingface.co/spaces/lmz/candle-yolo). The model runs fully in
+your browser using WebAssembly; if you use a custom image, it never leaves
+your phone/computer!
+
+## Running the examples
+
+### Object Detection
+```bash
+cargo run --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
+```
+
+This prints details about the detected objects and generates a `bike.pp.jpg` file.
+
+
+
+Image source:
+[wikimedia](https://commons.wikimedia.org/wiki/File:Leading_group,_Giro_d%27Italia_2021,_Stage_15.jpg).
+
+
+
+### Pose Estimation
+```bash
+cargo run --example yolo-v8 --release -- \
+  candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
+```
+
+
+
+### Command-line flags
+
+- `--which`: select the model variant to be used: `n`, `s`, `m`, `l`, or `x`, in
+  increasing size and quality.
+- `--task`: `detect` for object detection and `pose` for pose estimation.
+- `--legend-size`: the size of the characters to print.
+- `--model`: use a local model file rather than downloading it from the hub.
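+
+Combining these flags, a hypothetical pose estimation run with the largest
+model variant and a local weight file (both file names below are purely
+illustrative) could look like:
+
+```bash
+cargo run --example yolo-v8 --release -- \
+  my_image.jpg --task pose --which x --model ./yolov8x-pose.safetensors
+```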
+ diff --git a/candle-examples/examples/yolo-v8/assets/bike.jpg b/candle-examples/examples/yolo-v8/assets/bike.jpg Binary files differnew file mode 100644 index 00000000..05d1faaf --- /dev/null +++ b/candle-examples/examples/yolo-v8/assets/bike.jpg diff --git a/candle-examples/examples/yolo-v8/assets/bike.od.jpg b/candle-examples/examples/yolo-v8/assets/bike.od.jpg Binary files differnew file mode 100644 index 00000000..111b9286 --- /dev/null +++ b/candle-examples/examples/yolo-v8/assets/bike.od.jpg diff --git a/candle-examples/examples/yolo-v8/assets/bike.pose.jpg b/candle-examples/examples/yolo-v8/assets/bike.pose.jpg Binary files differnew file mode 100644 index 00000000..e660f65b --- /dev/null +++ b/candle-examples/examples/yolo-v8/assets/bike.pose.jpg diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index d5c5ac1c..d48bac35 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -8,8 +8,8 @@ mod model; use model::{Multiples, YoloV8, YoloV8Pose}; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use candle_nn::{Module, VarBuilder}; +use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; use image::DynamicImage; @@ -64,7 +64,7 @@ pub fn report_detect( let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; // The bounding boxes grouped by (maximum) class index. - let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect(); + let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect(); // Extract the bounding boxes for which confidence is above the threshold. for index in 0..npreds { let pred = Vec::<f32>::try_from(pred.i((.., index))?)?; @@ -83,7 +83,7 @@ pub fn report_detect( xmax: pred[0] + pred[2] / 2., ymax: pred[1] + pred[3] / 2., confidence, - keypoints: vec![], + data: vec![], }; bboxes[class_index].push(bbox) } @@ -176,7 +176,7 @@ pub fn report_pose( xmax: pred[0] + pred[2] / 2., ymax: pred[1] + pred[3] / 2., confidence, - keypoints, + data: keypoints, }; bboxes.push(bbox) } @@ -204,7 +204,7 @@ pub fn report_pose( image::Rgb([255, 0, 0]), ); } - for kp in b.keypoints.iter() { + for kp in b.data.iter() { if kp.mask < 0.6 { continue; } @@ -219,8 +219,8 @@ pub fn report_pose( } for &(idx1, idx2) in KP_CONNECTIONS.iter() { - let kp1 = &b.keypoints[idx1]; - let kp2 = &b.keypoints[idx2]; + let kp1 = &b.data[idx1]; + let kp2 = &b.data[idx2]; if kp1.mask < 0.6 || kp2.mask < 0.6 { continue; } diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 395162eb..5e0b44fb 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -1,6 +1,5 @@ pub mod coco_classes; pub mod imagenet; -pub mod object_detection; use candle::{Device, Result, Tensor}; @@ -16,6 +15,36 @@ pub fn device(cpu: bool) -> Result<Device> { } } +pub fn load_image<P: AsRef<std::path::Path>>( + p: P, + resize_longest: Option<usize>, +) -> Result<(Tensor, usize, usize)> { + let img = image::io::Reader::open(p)? 
+ .decode() + .map_err(candle::Error::wrap)?; + let (initial_h, initial_w) = (img.height() as usize, img.width() as usize); + let img = match resize_longest { + None => img, + Some(resize_longest) => { + let (height, width) = (img.height(), img.width()); + let resize_longest = resize_longest as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + img.resize_exact(width, height, image::imageops::FilterType::CatmullRom) + } + }; + let (height, width) = (img.height() as usize, img.width() as usize); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?; + Ok((data, initial_h, initial_w)) +} + pub fn load_image_and_resize<P: AsRef<std::path::Path>>( p: P, width: usize, @@ -35,20 +64,44 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>( } /// Saves an image to disk using the image crate, this expects an input with shape -/// (c, width, height). +/// (c, height, width). pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { let p = p.as_ref(); - let (channel, width, height) = img.dims3()?; + let (channel, height, width) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::<u8>()?; + let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} + +pub fn save_image_resize<P: AsRef<std::path::Path>>( + img: &Tensor, + p: P, + h: usize, + w: usize, +) -> Result<()> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; if channel != 3 { - candle::bail!("save_image expects an input of shape (3, width, height)") + candle::bail!("save_image expects an input of shape (3, height, width)") } - let img = img.transpose(0, 1)?.t()?.flatten_all()?; + let img = img.permute((1, 2, 0))?.flatten_all()?; let pixels = img.to_vec1::<u8>()?; let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { Some(image) => image, None => candle::bail!("error saving image {p:?}"), }; + let image = image::DynamicImage::from(image); + let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom); image.save(p).map_err(candle::Error::wrap)?; Ok(()) } diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 0d130519..808e0070 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.2.1" +version = "0.2.3" edition = "2021" description = "Flash attention layer for the candle ML framework." 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.2.3", package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.2.1", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.2.3", features = ["cuda"] }
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 773c5638..64275fda 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -6,7 +6,7 @@ use rayon::prelude::*;
 use std::path::PathBuf;
 use std::str::FromStr;
 
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
     "flash_api.cu",
     "flash_fwd_hdim128_fp16_sm80.cu",
     "flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
     "flash_fwd_hdim224_fp16_sm80.cu",
     "flash_fwd_hdim256_fp16_sm80.cu",
     "flash_fwd_hdim32_fp16_sm80.cu",
     "flash_fwd_hdim64_fp16_sm80.cu",
     "flash_fwd_hdim96_fp16_sm80.cu",
-    // "flash_fwd_hdim128_bf16_sm80.cu",
-    // "flash_fwd_hdim160_bf16_sm80.cu",
-    // "flash_fwd_hdim192_bf16_sm80.cu",
-    // "flash_fwd_hdim224_bf16_sm80.cu",
-    // "flash_fwd_hdim256_bf16_sm80.cu",
-    // "flash_fwd_hdim32_bf16_sm80.cu",
-    // "flash_fwd_hdim64_bf16_sm80.cu",
-    // "flash_fwd_hdim96_bf16_sm80.cu",
+    "flash_fwd_hdim128_bf16_sm80.cu",
+    "flash_fwd_hdim160_bf16_sm80.cu",
+    "flash_fwd_hdim192_bf16_sm80.cu",
+    "flash_fwd_hdim224_bf16_sm80.cu",
+    "flash_fwd_hdim256_bf16_sm80.cu",
+    "flash_fwd_hdim32_bf16_sm80.cu",
+    "flash_fwd_hdim64_bf16_sm80.cu",
+    "flash_fwd_hdim96_bf16_sm80.cu",
 ];
 
 fn main() -> Result<()> {
@@ -57,9 +57,20 @@ fn main() -> Result<()> {
             #[allow(clippy::redundant_clone)]
             out_dir.clone()
         }
-        Ok(build_dir) => PathBuf::from(build_dir),
+        Ok(build_dir) => {
+            let path = PathBuf::from(build_dir);
+            path.canonicalize().expect(&format!(
+                "Directory doesn't exist: {} (the current directory is {})",
+                &path.display(),
+                std::env::current_dir()?.display()
+            ))
+        }
     };
     set_cuda_include_dir()?;
+
+    let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
+
     let compute_cap = compute_cap()?;
 
     let out_file = build_dir.join("libflashattention.a");
@@ -95,14 +106,21 @@ fn main() -> Result<()> {
             .args(["--default-stream", "per-thread"])
             .arg("-Icutlass/include")
             .arg("--expt-relaxed-constexpr")
-            .arg(cu_file);
+            .arg("--verbose");
+        if let Ok(ccbin_path) = &ccbin_env {
+            command
+                .arg("-allow-unsupported-compiler")
+                .args(["-ccbin", ccbin_path]);
+        }
+        command.arg(cu_file);
         let output = command
             .spawn()
            .context("failed spawning nvcc")? 
.wait_with_output()?;
         if !output.status.success() {
             anyhow::bail!(
-                "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                "nvcc error while compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                &command,
                 String::from_utf8_lossy(&output.stdout),
                 String::from_utf8_lossy(&output.stderr)
             )
@@ -122,7 +140,8 @@ fn main() -> Result<()> {
         .wait_with_output()?;
     if !output.status.success() {
         anyhow::bail!(
-            "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+            "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+            &command,
             String::from_utf8_lossy(&output.stdout),
             String::from_utf8_lossy(&output.stderr)
         )
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index d928bcb6..72991257 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,20 +1,19 @@
 #include "flash_fwd_launch_template.h"
 
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    FWD_HEADDIM_SWITCH(params.d, [&] {
-        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-    });
-}
-
 // void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
-//         FWD_HEADDIM_SWITCH(params.d, [&] {
-//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-//         });
 //     });
 // }
 
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        FWD_HEADDIM_SWITCH(params.d, [&] {
+            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+        });
+    });
+}
+
 extern "C" void run_mha(
     void *q_ptr,
     void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,
 
-    int is_causal
+    int is_causal,
+    int is_bf16
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    params.is_bf16 = 0;
+    params.is_bf16 = is_bf16;
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`. 
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index ae61c405..90f34e43 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -38,6 +38,7 @@ extern "C" { seqlen_k_rounded: u32, is_causal: c_int, + is_bf16: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3c5fd455..61980a58 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -4,7 +4,7 @@ use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::WrapErr; use candle::{CpuStorage, Layout, Result, Shape, Tensor}; -use half::f16; +use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, @@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize { (x + m - 1) / m * m } -impl candle::CustomOp3 for FlashAttn { - fn name(&self) -> &'static str { - "flash-attn" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for flash-attn") - } - - fn cuda_fwd( +impl FlashAttn { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( &self, q: &candle::CudaStorage, q_l: &Layout, @@ -40,15 +26,16 @@ impl candle::CustomOp3 for FlashAttn { k_l: &Layout, v: &candle::CudaStorage, v_l: &Layout, + is_bf16: bool, ) -> Result<(candle::CudaStorage, Shape)> { // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187 let dev = q.device(); let out_shape = q_l.shape().clone(); let out_l = Layout::contiguous(&out_shape); - let q = q.as_cuda_slice::<f16>()?; - let k = k.as_cuda_slice::<f16>()?; - let v = v.as_cuda_slice::<f16>()?; + let q = q.as_cuda_slice::<T>()?; + let k = k.as_cuda_slice::<T>()?; + let v = v.as_cuda_slice::<T>()?; let q = q.slice(q_l.start_offset()..); let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -104,10 +91,11 @@ impl candle::CustomOp3 for FlashAttn { let seqlen_k_rounded = round_multiple(seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?; + let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?; let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?; let causal = if self.causal { 1 } else { 0 }; + let is_bf16 = if is_bf16 { 1 } else { 0 }; unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; @@ -146,6 +134,7 @@ impl candle::CustomOp3 for FlashAttn { /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, /* is_causal */ causal, + /* is_bf16 */ is_bf16, ) } @@ -154,6 +143,40 @@ impl candle::CustomOp3 for FlashAttn { } } +impl candle::CustomOp3 for FlashAttn { + fn name(&self) -> &'static str { + "flash-attn" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false), + candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true), + dt => candle::bail!("flash-attn is 
only supported for f16/bf16 ({dt:?})"), + } + } +} + /// Flash-attention v2 layer. /// /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. @@ -190,24 +213,10 @@ struct FlashAttnVarLen { seqlens_k: Tensor, } -impl candle::CustomOp3 for FlashAttnVarLen { - fn name(&self) -> &'static str { - "flash-attn-varlen" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for flash-attn") - } - - fn cuda_fwd( +impl FlashAttnVarLen { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( &self, q: &candle::CudaStorage, q_l: &Layout, @@ -215,6 +224,7 @@ impl candle::CustomOp3 for FlashAttnVarLen { k_l: &Layout, v: &candle::CudaStorage, v_l: &Layout, + is_bf16: bool, ) -> Result<(candle::CudaStorage, Shape)> { // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327 let dev = q.device(); @@ -314,6 +324,7 @@ impl candle::CustomOp3 for FlashAttnVarLen { .w()?; let causal = if self.causal { 1 } else { 0 }; + let is_bf16 = if is_bf16 { 1 } else { 0 }; unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; @@ -354,6 +365,7 @@ impl candle::CustomOp3 for FlashAttnVarLen { /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, /* is_causal */ causal, + /* is_bf16 */ is_bf16, ) } @@ -362,6 +374,40 @@ impl candle::CustomOp3 for FlashAttnVarLen { } } +impl candle::CustomOp3 for FlashAttnVarLen { + fn name(&self) -> &'static str { + "flash-attn-varlen" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false), + candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true), + dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + #[allow(clippy::too_many_arguments)] /// Flash-attention v2 layer with variable-length batching. 
/// diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index 576c52ea..80b6aaab 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.2.1" +version = "0.2.3" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 3c8e96a9..ad084671 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -164,6 +164,8 @@ mod cuda { println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); let children = kernel_paths .par_iter() .flat_map(|p| { @@ -188,8 +190,13 @@ mod cuda { .args(["--output-directory", &out_dir]) // Flash attention only // .arg("--expt-relaxed-constexpr") - .args(&include_options) - .arg(p); + .args(&include_options); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } + command.arg(p); Some((p, command.spawn() .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output())) }}) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index ab2045a3..ee20fe5f 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -77,20 +77,30 @@ CAST_OP(double, __half, cast_f64_f16) CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) +CAST_OP(uint32_t, int64_t, cast_u32_i64 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) CAST_OP(uint8_t, double, cast_u8_f64) +CAST_OP(int64_t, uint32_t, cast_i64_u32) +CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int64_t, cast_i64_i64 ) +CAST_OP(int64_t, float, cast_i64_f32) +CAST_OP(int64_t, double, cast_i64_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index ba2fa1ad..9c8ce00f 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -51,6 +51,118 @@ __device__ void conv1d( dst[dst_i] = static_cast<T>(d); } +template <typename T> +__device__ void im2col1d( + const size_t dst_numel, + const size_t l_out, + const size_t l_k, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // dst: (b_size, l_out, c_in, l_k) + // src: (b_size, c_in, l_in) + if (dst_i >= dst_numel) { + return; + } + const size_t *src_dims = info; + const size_t *src_s = info + 3; + const size_t b_in = src_dims[0]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + + const size_t dst_s2 = l_k; + const size_t dst_s1 = c_in * dst_s2; + const size_t dst_s0 = l_out * dst_s1; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t l_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= l_idx * 
dst_s1; + const size_t c_idx = tmp_dst_i / dst_s2; + tmp_dst_i -= c_idx * dst_s2; + const size_t l_k_idx = tmp_dst_i; + size_t src_l_idx = l_idx * stride + l_k_idx * dilation; + if (src_l_idx < padding || src_l_idx >= l_in + padding) { + dst[dst_i] = static_cast<T>(0); + } + else { + src_l_idx -= padding; + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2]; + dst[dst_i] = src[src_i]; + } +} + +template <typename T> +__device__ void im2col( + const size_t dst_numel, + const size_t h_out, + const size_t w_out, + const size_t h_k, + const size_t w_k, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // dst: (b_size, h_out, w_out, c_in, h_k, w_k) + // src: (b_size, c_in, h_in, w_in) + if (dst_i >= dst_numel) { + return; + } + const size_t *src_dims = info; + const size_t *src_s = info + 4; + const size_t b_in = src_dims[0]; + const size_t c_in = src_dims[1]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; + + const size_t dst_s4 = w_k; + const size_t dst_s3 = h_k * dst_s4; + const size_t dst_s2 = c_in * dst_s3; + const size_t dst_s1 = w_out * dst_s2; + const size_t dst_s0 = h_out * dst_s1; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t h_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= h_idx * dst_s1; + const size_t w_idx = tmp_dst_i / dst_s2; + tmp_dst_i -= w_idx * dst_s2; + const size_t c_idx = tmp_dst_i / dst_s3; + tmp_dst_i -= c_idx * dst_s3; + const size_t h_k_idx = tmp_dst_i / dst_s4; + tmp_dst_i -= h_k_idx * dst_s4; + const size_t w_k_idx = tmp_dst_i; + size_t src_h_idx = h_idx * stride + h_k_idx * dilation; + size_t src_w_idx = w_idx * stride + w_k_idx * dilation; + if (src_h_idx < padding || src_h_idx >= h_in + padding) { + dst[dst_i] = static_cast<T>(0); + } + else if (src_w_idx < padding || src_w_idx >= w_in + padding) { + dst[dst_i] = static_cast<T>(0); + } + else { + src_h_idx -= padding; + src_w_idx -= padding; + const size_t src_i = + b_idx * src_s[0] + + c_idx * src_s[1] + + src_h_idx * src_s[2] + + src_w_idx * src_s[3]; + dst[dst_i] = src[src_i]; + } +} + // Naive implementation of conv2d. 
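The `im2col1d`/`im2col` kernels above materialize each convolution window as one row of a `(b * h_out * w_out, c_in * h_k * w_k)` matrix, which turns the convolution itself into a single matmul against the flattened filters. The output spatial sizes follow the usual convolution arithmetic; a small Rust helper mirroring the `hw_out` computation that the CPU benchmark further down this diff also uses:

```rust
/// Spatial output size of a 2d convolution; the 1d case is the same formula
/// applied once to (l_in, l_k).
fn hw_out(
    (h, w): (usize, usize),
    (h_k, w_k): (usize, usize),
    stride: usize,
    padding: usize,
    dilation: usize,
) -> (usize, usize) {
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
    (h_out, w_out)
}

fn main() {
    // A 3x3 kernel over a 96x96 input, stride 1, no padding, no dilation gap.
    assert_eq!(hw_out((96, 96), (3, 3), 1, 0, 1), (94, 94));
}
```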
template <typename T, typename A> __device__ void conv2d( @@ -363,6 +475,38 @@ extern "C" __global__ void FN_NAME( \ conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \ } \ +#define IM2COL1D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_numel, \ + const size_t l_out, \ + const size_t l_k, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \ +} \ + +#define IM2COL_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_numel, \ + const size_t h_out, \ + const size_t w_out, \ + const size_t h_k, \ + const size_t w_k, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \ +} \ + #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const size_t src_numel, \ @@ -428,6 +572,8 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16) AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) +IM2COL_OP(__nv_bfloat16, im2col_bf16) +IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) #endif #if __CUDA_ARCH__ >= 530 @@ -437,6 +583,8 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16) AVG_POOL2D_OP(__half, float, avg_pool2d_f16) MAX_POOL2D_OP(__half, max_pool2d_f16) UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16) +IM2COL_OP(__half, im2col_f16) +IM2COL1D_OP(__half, im2col1d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -468,3 +616,13 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64) UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) + +IM2COL_OP(float, im2col_f32) +IM2COL_OP(double, im2col_f64) +IM2COL_OP(uint8_t, im2col_u8) +IM2COL_OP(uint32_t, im2col_u32) + +IM2COL1D_OP(float, im2col1d_f32) +IM2COL1D_OP(double, im2col1d_f64) +IM2COL1D_OP(uint8_t, im2col1d_u8) +IM2COL1D_OP(uint32_t, im2col1d_u32) diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 4096d2d1..8e46a07c 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); } __device__ __forceinline__ double powg(double a, double b) { return pow(a, b); } __device__ __forceinline__ float tanhg(float a) { return tanhf(a); } __device__ __forceinline__ double tanhg(double a) { return tanh(a); } +__device__ __forceinline__ float erfg(float a) { return erff(a); } +__device__ __forceinline__ double erfg(double a) { return erf(a); } +__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); } +__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); } __device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); } __device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); } __device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); } @@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); } __device__ __forceinline__ __half 
recipg(__half a) { __half one = 1.0; return one / a; } __device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); } __device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); } +__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); } +__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); } __device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); } __device__ __forceinline__ __half logg(__half a) { return hlog(a); } __device__ __forceinline__ __half expg(__half a) { return hexp(a); } @@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); __device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; } __device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); } __device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); } __device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); } __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); } __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 271502c5..fca6865e 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block, dst[dst_id] = shr[0]; } +// Softmax implementation adapted from ggml. +// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159 +template <typename T, typename ACC> +__device__ void softmax(const T * x, T * dst, const int ncols) { + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int block_size = blockDim.y; + const int tid = threadIdx.y; + + T max_val = -INFINITY; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + max_val = maxg(max_val, x[i]); + } + + // find the max value in the block +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + } + + ACC tmp = 0.; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + const T val = expg(x[i] - max_val); + tmp += static_cast<ACC>(val); + dst[i] = val; + } + + // sum up partial sums +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + const ACC inv_tmp = 1. 
/ tmp; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + dst[i] *= inv_tmp; + } +} + template <typename T> __device__ void fast_max(const size_t src_numel, const size_t el_to_sum_per_block, @@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, } \ } +#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, \ + const int n_cols) { \ + softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \ + } \ + #if __CUDA_ARCH__ >= 800 +SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) #endif #if __CUDA_ARCH__ >= 530 +SOFTMAX_OP(__half, float, softmax_f16) SUM_OP(__half, sum_f16) FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) #endif @@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa SUM_OP(float, sum_f32) SUM_OP(double, sum_f64) SUM_OP(uint32_t, sum_u32) +SOFTMAX_OP(float, float, softmax_f32) +SOFTMAX_OP(double, double, softmax_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c6142a03..105d8c3a 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -29,6 +29,11 @@ extern "C" __global__ void FN_NAME( \ } \ template<typename T> +__device__ __forceinline__ T gelu_erf_fwd(T x) { + return x * normcdfg(x); +} + +template<typename T> __device__ __forceinline__ T gelu_fwd(T x) { T x_sq = x * x; T x_cube = x_sq * x; @@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x)) UNARY_OP(__nv_bfloat16, usin_bf16, sing(x)) UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x)) UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x)) +UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x)) +UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x)) UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x)) UNARY_OP(__nv_bfloat16, usqr_bf16, x*x) UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x)) UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x)) +UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x)) UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) @@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x)) UNARY_OP(__half, usin_f16, sing(x)) UNARY_OP(__half, ucos_f16, cosg(x)) UNARY_OP(__half, utanh_f16, tanhg(x)) +UNARY_OP(__half, uerf_f16, erfg(x)) +UNARY_OP(__half, unormcdf_f16, normcdfg(x)) UNARY_OP(__half, uabs_f16, absg(x)) UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqrt_f16, sqrtg(x)) UNARY_OP(__half, ugelu_f16, gelu_fwd(x)) +UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x)) UNARY_OP(__half, urelu_f16, relu_fwd(x)) UNARY_OP1(__half, uelu_f16, elu_fwd(x, param)) UNARY_OP1(__half, upowf_f16, powg(x, param)) @@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x)) UNARY_OP(double, ucos_f64, cosg(x)) UNARY_OP(float, utanh_f32, tanhg(x)) UNARY_OP(double, utanh_f64, tanhg(x)) +UNARY_OP(float, uerf_f32, erfg(x)) +UNARY_OP(double, uerf_f64, erfg(x)) +UNARY_OP(float, unormcdf_f32, normcdfg(x)) +UNARY_OP(double, unormcdf_f64, normcdfg(x)) UNARY_OP(float, uabs_f32, absg(x)) UNARY_OP(double, uabs_f64, absg(x)) UNARY_OP(float, usqr_f32, x*x) @@ -139,6 +154,8 @@ 
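The softmax kernel adapted from ggml uses the standard numerically stable three-pass scheme per row: take the row maximum, accumulate `exp(x - max)` into the `ACC` type (`float` even for f16/bf16 inputs), then multiply by the reciprocal of the sum; the `__shfl_xor_sync` butterfly reductions over masks 16 down to 1 combine the per-thread partials within a warp. Stripped of the warp shuffles, the per-row logic is just the following (a scalar Rust rendering for reference, not the vectorized CPU path added to `candle-nn` below):

```rust
/// Numerically stable softmax over one row: max, exp-and-sum, then normalize.
fn softmax_row(src: &[f32], dst: &mut [f32]) {
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0f32;
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d = (s - max).exp();
        sum += *d;
    }
    let inv = 1f32 / sum;
    for d in dst.iter_mut() {
        *d *= inv;
    }
}
```

In the same vein, the `unary.cu` additions expose the exact GELU: `gelu_erf_fwd(x) = x * normcdf(x)`, i.e. `x * Phi(x)` with `Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))`, while the existing `gelu_fwd` keeps the faster tanh approximation.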
UNARY_OP(float, usqrt_f32, sqrtg(x)) UNARY_OP(double, usqrt_f64, sqrtg(x)) UNARY_OP(float, ugelu_f32, gelu_fwd(x)) UNARY_OP(double, ugelu_f64, gelu_fwd(x)) +UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x)) +UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x)) UNARY_OP(float, urelu_f32, relu_fwd(x)) UNARY_OP(double, urelu_f64, relu_fwd(x)) UNARY_OP1(float, uelu_f32, elu_fwd(x, param)) diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index aa055583..a6629d33 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,13 +11,18 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +half = { workspace = true } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } +num-traits = { workspace = true } +rayon = { workspace = true } safetensors = { workspace = true } +serde = { workspace = true } [dev-dependencies] anyhow = { workspace = true } +clap = { workspace = true } [features] default = [] diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs new file mode 100644 index 00000000..204a7109 --- /dev/null +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -0,0 +1,302 @@ +/// This example contains some simple benchmarks so that it's easy to run them in perf etc. +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::quantized::GgmlType; +use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D}; +use clap::{Parser, Subcommand}; + +const CHECK_CONV2D: bool = false; + +trait Benchmark { + type PreProcessData; + type RunResult; + + fn preprocess() -> Result<Self::PreProcessData>; + fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>; + + const ITERS: usize; +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl candle::CustomOp1 for Im2Col { + fn name(&self) -> &'static str { + "im2col" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + let &Self { + h_k, + w_k, + stride, + dilation, + padding, + } = self; + let (b, c, h, w) = layout.shape().dims4()?; + let (h_out, w_out) = self.hw_out(h, w); + let slice = storage.as_slice::<f32>()?; + let src = &slice[layout.start_offset()..]; + let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k]; + let (src_s0, src_s1, src_s2, src_s3) = { + let s = layout.stride(); + (s[0], s[1], s[2], s[3]) + }; + // TODO: provide specialized kernels for the common use cases. 
+ // - h_k = w_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * h_out * w_out * c * h_k * w_k; + for h_idx in 0..h_out { + let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k; + for w_idx in 0..w_out { + let dst_idx = dst_idx + w_idx * c * h_k * w_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * h_k * w_k; + let src_idx = c_idx * src_s1 + src_idx; + for h_k_idx in 0..h_k { + let src_h = h_idx * stride + h_k_idx * dilation; + if padding != 0 && (src_h < padding || src_h >= h + padding) { + continue; + } + let src_h = src_h - padding; + let src_idx = src_idx + src_h * src_s2; + let dst_idx = dst_idx + h_k_idx * w_k; + for w_k_idx in 0..w_k { + let src_w = w_idx * stride + w_k_idx * dilation; + if padding != 0 && (src_w < padding || src_w >= w + padding) { + continue; + } + let src_w = src_w - padding; + let src_idx = src_idx + src_w * src_s3; + let dst_idx = dst_idx + w_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + } + } + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b * h_out * w_out, c * h_k * w_k).into())) + } +} + +// Conv1d example as used in whisper. +struct Conv1d; +impl Benchmark for Conv1d { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + d.0.conv1d(&d.1, 0, 1, 1, 1) + } + + const ITERS: usize = 5; +} + +// Conv2d example as used in stable-diffusion. +struct Conv2d; +impl Benchmark for Conv2d { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + + fn preprocess() -> Result<Self::PreProcessData> { + let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + d.0.conv2d(&d.1, 0, 1, 1, 1) + } + + const ITERS: usize = 5; +} + +// Conv2d example as used in stable-diffusion, im2col implementation. +struct Conv2dIm2Col; +impl Benchmark for Conv2dIm2Col { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + + fn preprocess() -> Result<Self::PreProcessData> { + let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + // d.0.conv2d(&d.1, 0, 1, 1, 1) + let (b, _, h, w) = d.0.dims4()?; + let (_, _, h_k, w_k) = d.1.dims4()?; + let op = Im2Col { + h_k, + w_k, + stride: 1, + dilation: 1, + padding: 0, + }; + let (h_out, w_out) = op.hw_out(h, w); + let col = d.0.apply_op1_no_bwd(&op)?; + let res = col.matmul(&d.1.flatten_from(1)?.t()?)?; + let res = res + .reshape((b, h_out, w_out, ()))? + .permute((0, 3, 1, 2))? 
+ .contiguous()?; + if CHECK_CONV2D { + let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1); + let diff = (&res - res2)?.sqr()?.mean_all()?; + println!("{diff}"); + } + Ok(res) + } + + const ITERS: usize = 5; +} + +struct Matmul; +impl Benchmark for Matmul { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + d.0.matmul(&d.1) + } + + const ITERS: usize = 100; +} + +// This benchmark is similar to: +// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp +struct QMatMul; +impl Benchmark for QMatMul { + type PreProcessData = (candle::quantized::QMatMul, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; + let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?; + let mm = candle::quantized::QMatMul::from_qtensor(mm); + let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; + Ok((mm, arg)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + d.0.forward(&d.1) + } + + const ITERS: usize = 100; +} + +struct Softmax; +impl Benchmark for Softmax { + type PreProcessData = Tensor; + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + // Typical whisper tiny size. + let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?; + Ok(x) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + candle_nn::ops::softmax(d, D::Minus1) + } + + const ITERS: usize = 100; +} + +struct SoftmaxLastDim; +impl Benchmark for SoftmaxLastDim { + type PreProcessData = Tensor; + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + // Typical whisper tiny size. + let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?; + Ok(x) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + candle_nn::ops::softmax_last_dim(d) + } + + const ITERS: usize = 100; +} + +fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> { + use std::hint::black_box; + + let iters = iters.unwrap_or(B::ITERS); + let d = B::preprocess()?; + let start = std::time::Instant::now(); + for _iter in 0..iters { + let _res = black_box(B::run_one(black_box(&d))?); + } + println!("{:?}", start.elapsed() / iters as u32); + Ok(()) +} + +#[derive(Subcommand, Debug, Clone)] +enum Task { + Conv1d, + Conv2d, + Conv2dIm2Col, + Matmul, + Qmatmul, + Softmax, + SoftmaxLastDim, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The benchmark to be run. 
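The harness times `ITERS` calls of `run_one` over data built once in `preprocess`, wrapping both in `black_box` so the optimizer cannot elide the work. Adding a benchmark is one more trait impl plus a `Task` variant; as a sketch, a hypothetical entry exercising the `leaky_relu` op that this same change set adds to `candle_nn::ops` (`Benchmark` is the local trait defined at the top of this example file):

```rust
use candle::{Device, Result, Tensor};

// Hypothetical extra benchmark, following the pattern of the structs above.
struct LeakyRelu;
impl Benchmark for LeakyRelu {
    type PreProcessData = Tensor;
    type RunResult = Tensor;

    fn preprocess() -> Result<Self::PreProcessData> {
        Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        candle_nn::ops::leaky_relu(d, 0.01)
    }

    const ITERS: usize = 100;
}
```

The clap subcommand defined just below selects the benchmark at run time, e.g. `cargo run --release --example cpu_benchmarks -- softmax-last-dim`.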
+ #[command(subcommand)] + task: Task, + + #[arg(long)] + iters: Option<usize>, +} + +fn main() -> Result<()> { + let args = Args::parse(); + match args.task { + Task::Conv1d => run::<Conv1d>(args.iters)?, + Task::Conv2d => run::<Conv2d>(args.iters)?, + Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?, + Task::Matmul => run::<Matmul>(args.iters)?, + Task::Softmax => run::<Softmax>(args.iters)?, + Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?, + Task::Qmatmul => run::<QMatMul>(args.iters)?, + } + Ok(()) +} diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 0db3edc9..17467b31 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -1,18 +1,29 @@ use candle::Tensor; +use serde::Deserialize; -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] pub enum Activation { + #[default] Gelu, + #[serde(rename = "gated-gelu")] + NewGelu, Relu, Elu(f64), + LeakyRelu(f64), } impl super::Module for Activation { fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { match self { Self::Gelu => xs.gelu(), + // TODO: This is "gelu_new", not the original "gelu". + // There's some small numerical difference: + // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 + Self::NewGelu => xs.gelu(), Self::Relu => xs.relu(), &Self::Elu(alpha) => xs.elu(alpha), + &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope), } } } diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 2dac0758..27ef15f7 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct BatchNorm { running_mean: Tensor, running_var: Tensor, diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index dbf23aa5..89e9f42d 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -20,7 +20,7 @@ impl Default for Conv1dConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Conv1d { weight: Tensor, bias: Option<Tensor>, @@ -39,6 +39,14 @@ impl Conv1d { pub fn config(&self) -> &Conv1dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv1d { @@ -80,8 +88,7 @@ impl Default for Conv2dConfig { } } -#[allow(dead_code)] -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Conv2d { weight: Tensor, bias: Option<Tensor>, @@ -100,6 +107,14 @@ impl Conv2d { pub fn config(&self) -> &Conv2dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv2d { @@ -122,15 +137,76 @@ impl crate::Module for Conv2d { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConvTranspose2dConfig { + pub padding: usize, + pub output_padding: usize, + pub stride: usize, + pub dilation: usize, + // TODO: support groups. 
+} + +impl Default for ConvTranspose2dConfig { + fn default() -> Self { + Self { + padding: 0, + output_padding: 0, + stride: 1, + dilation: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConvTranspose2d { + weight: Tensor, + bias: Option<Tensor>, + config: ConvTranspose2dConfig, +} + +impl ConvTranspose2d { + pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &ConvTranspose2dConfig { + &self.config + } +} + +impl crate::Module for ConvTranspose2d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = x.conv_transpose2d( + &self.weight, + self.config.padding, + self.config.output_padding, + self.config.stride, + self.config.dilation, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + pub fn conv1d( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: Conv1dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv1d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( (out_channels, in_channels / cfg.groups, kernel_size), "weight", init_ws, @@ -140,7 +216,7 @@ pub fn conv1d( lo: -bound, up: bound, }; - let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv1d::new(ws, Some(bs), cfg)) } @@ -149,10 +225,10 @@ pub fn conv2d( out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv2d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, @@ -167,7 +243,7 @@ pub fn conv2d( lo: -bound, up: bound, }; - let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv2d::new(ws, Some(bs), cfg)) } @@ -176,10 +252,10 @@ pub fn conv2d_no_bias( out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv2d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, @@ -191,3 +267,44 @@ pub fn conv2d_no_bias( )?; Ok(Conv2d::new(ws, None, cfg)) } + +pub fn conv_transpose2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose2dConfig, + vb: crate::VarBuilder, +) -> Result<ConvTranspose2d> { + let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels, kernel_size, kernel_size), + "weight", + init, + )?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; + Ok(ConvTranspose2d::new(ws, Some(bs), cfg)) +} + +pub fn conv_transpose2d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose2dConfig, + vb: crate::VarBuilder, +) -> Result<ConvTranspose2d> { + let bound = 1. 
/ (out_channels as f64).sqrt() / kernel_size as f64; + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels, kernel_size, kernel_size), + "weight", + init, + )?; + Ok(ConvTranspose2d::new(ws, None, cfg)) +} diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index d84f9f53..52968bc2 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -1,7 +1,7 @@ //! Embedding Layer. use candle::{Result, Tensor}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Embedding { embeddings: Tensor, hidden_size: usize, @@ -18,6 +18,11 @@ impl Embedding { pub fn embeddings(&self) -> &Tensor { &self.embeddings } + + /// Get the hidden size of the embedding matrix + pub fn hidden_size(&self) -> usize { + self.hidden_size + } } impl crate::Module for Embedding { diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index eb1b889f..5b80b970 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -4,7 +4,7 @@ use candle::{DType, Result, Tensor}; // This group norm version handles both weight and bias so removes the mean. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct GroupNorm { weight: Tensor, bias: Tensor, diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 08e2f628..7617fc6c 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -28,7 +28,7 @@ //! ``` //! //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450 -use candle::{DType, Result, Tensor}; +use candle::{DType, Result, Tensor, D}; #[derive(Debug, Clone, Copy, PartialEq)] pub struct LayerNormConfig { @@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig { } // This layer norm version handles both weight and bias so removes the mean. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LayerNorm { weight: Tensor, bias: Option<Tensor>, @@ -104,15 +104,15 @@ impl crate::Module for LayerNorm { DType::F16 | DType::BF16 => DType::F32, d => d, }; - let (_bsize, _seq_len, hidden_size) = x.dims3()?; + let hidden_size = x.dim(D::Minus1)?; let x = x.to_dtype(internal_dtype)?; let x = if self.remove_mean { - let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; x.broadcast_sub(&mean_x)? } else { x }; - let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; match &self.bias { @@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>( } /// RmsNorm is a specialized version of the LayerNorm module. 
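Switching `LayerNorm` from `dims3()` to `D::Minus1` makes it normalize over the last dimension of an input of any rank rather than requiring a `(batch, seq, hidden)` shape. A quick shape-level check of the relaxed behaviour (a sketch assuming the module's `LayerNorm::new(weight, bias, eps)` constructor):

```rust
use candle::{DType, Device, Tensor};
use candle_nn::{LayerNorm, Module};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let ln = LayerNorm::new(
        Tensor::ones(8, DType::F32, &dev)?,
        Tensor::zeros(8, DType::F32, &dev)?,
        1e-5,
    );
    // A plain 2d input would previously have failed the dims3() check.
    let xs = Tensor::randn(0f32, 1., (4, 8), &dev)?;
    let ys = ln.forward(&xs)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}
```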
-#[derive(Debug)] +#[derive(Clone, Debug)] pub struct RmsNorm(LayerNorm); impl RmsNorm { diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 6e268f4e..8e5580df 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -16,7 +16,10 @@ pub mod var_map; pub use activation::Activation; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; -pub use conv::{conv1d, conv2d, conv2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; +pub use conv::{ + conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d, + Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig, +}; pub use embedding::{embedding, Embedding}; pub use func::{func, Func}; pub use group_norm::{group_norm, GroupNorm}; diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index 7028f68c..94632296 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -19,7 +19,7 @@ //! ``` use candle::{Result, Tensor}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Linear { weight: Tensor, bias: Option<Tensor>, @@ -41,8 +41,9 @@ impl Linear { impl super::Module for Linear { fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + let w = match *x.dims() { + [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, _ => self.weight.t()?, }; let x = x.matmul(&w)?; diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index cddf278e..72451f83 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -1,6 +1,6 @@ use candle::{Result, Tensor}; -/// The negative loss likelihodd loss. +/// The negative log likelihood loss. /// /// Arguments /// diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index c3b6ffa2..32de1af9 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,5 @@ -use candle::{Result, Tensor}; +use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. @@ -43,6 +44,11 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { (xs.neg()?.exp()? + 1.0)?.recip() } +pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> { + let zeros = xs.zeros_like()?; + xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope +} + pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> { // This implementation is inefficient as it stores the full mask for the backward pass. 
// Instead we could just store the seed and have a specialized kernel that would both @@ -77,3 +83,149 @@ impl Dropout { } } } + +struct SoftmaxLastDim; + +impl candle::CustomOp1 for SoftmaxLastDim { + fn name(&self) -> &'static str { + "softmax-last-dim" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + fn softmax<T: candle::WithDType + num_traits::Float>( + src: &[T], + layout: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let mut max = T::neg_infinity(); + unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) }; + for (s, d) in src.iter().zip(dst.iter_mut()) { + *d = (*s - max).exp(); + } + let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) }; + for d in dst.iter_mut() { + *d /= sum_exp + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + match storage { + CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout), + CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout), + CpuStorage::F32(slice) => softmax::<f32>(slice, layout), + CpuStorage::F64(slice) => softmax::<f64>(slice, layout), + _ => candle::bail!("unsupported dtype for softmax {:?}", storage), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &candle::CudaStorage, + layout: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr}; + use candle::{CudaDevice, WithDType}; + + struct S; + impl Map1 for S { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (1, 32, 1), + shared_mem_bytes: 0, + }; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(el) }.w()?; + let params = (src, &dst, n_cols as i32); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use candle::backend::BackendStorage; + let dev = storage.device(); + let slice = S.map(&storage.slice, dev, layout)?; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, layout.shape().clone())) + } +} + +pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> { + xs.apply_op1_no_bwd(&SoftmaxLastDim) +} + +// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html +pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c / upscale_factor / upscale_factor; + xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))? + .permute((0, 1, 4, 2, 5, 3))? + .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor)) +} + +pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c * downscale_factor * downscale_factor; + xs.reshape(( + b_size, + c, + h / downscale_factor, + downscale_factor, + w / downscale_factor, + downscale_factor, + ))? + .permute((0, 1, 3, 5, 2, 4))? + .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) +} + +// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html +pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> { + match pad { + 0 => Ok(xs.clone()), + 1 => { + let (_b_size, _c, h, w) = xs.dims4()?; + let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?); + let xs = Tensor::cat(&[&first, xs, &last], 3)?; + let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?); + Tensor::cat(&[&first, &xs, &last], 2) + } + n => candle::bail!("replication-pad with a size of {n} is not supported"), + } +} diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index d52a9082..18a4a71c 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -85,7 +85,7 @@ impl LSTMConfig { /// /// <https://en.wikipedia.org/wiki/Long_short-term_memory> #[allow(clippy::upper_case_acronyms, unused)] -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LSTM { w_ih: Tensor, w_hh: Tensor, @@ -235,7 +235,7 @@ impl GRUConfig { /// /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit> #[allow(clippy::upper_case_acronyms, unused)] -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct GRU { w_ih: Tensor, w_hh: Tensor, diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index bf5d5b43..4ccbaf17 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -5,14 +5,14 @@ use crate::VarMap; use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; -use std::rc::Rc; +use std::sync::Arc; /// A structure used to retrieve variables, these variables can either come from storage or be /// generated via some form of initialization. /// /// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`. pub struct VarBuilderArgs<'a, B: Backend> { - data: Rc<TensorData<B>>, + data: Arc<TensorData<B>>, path: Vec<String>, _phantom: std::marker::PhantomData<&'a B>, } @@ -43,7 +43,7 @@ struct TensorData<B: Backend> { /// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most /// of the time. The main restriction is that it doesn't allow for specific args (besides /// initialization hints). 
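`pixel_shuffle` is the standard sub-pixel upsampling rearrangement, `(b, c * r * r, h, w) -> (b, c, h * r, w * r)`, built purely from `reshape`/`permute`, with `pixel_unshuffle` as its inverse; `replication_pad2d` grows each spatial border by replicating the edge row/column. A quick shape check of the three new ops:

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let xs = Tensor::randn(0f32, 1., (1, 16, 4, 4), &Device::Cpu)?;

    // (1, 16, 4, 4) -> (1, 4, 8, 8) with an upscale factor of 2.
    let ys = candle_nn::ops::pixel_shuffle(&xs, 2)?;
    assert_eq!(ys.dims(), &[1, 4, 8, 8]);

    // pixel_unshuffle undoes the rearrangement.
    let zs = candle_nn::ops::pixel_unshuffle(&ys, 2)?;
    assert_eq!(zs.dims(), xs.dims());

    // One pixel of replication padding on every side.
    let padded = candle_nn::ops::replication_pad2d(&xs, 1)?;
    assert_eq!(padded.dims(), &[1, 16, 6, 6]);
    Ok(())
}
```

Separately, the switch from `Rc` to `Arc`, together with the new `Send + Sync` bounds on the `Backend`/`SimpleBackend` traits just below, makes `VarBuilder` shareable across threads, e.g. for loading weights from several worker threads.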
-pub trait Backend { +pub trait Backend: Send + Sync { type Hints: Default; /// Retrieve a tensor with some target shape. @@ -59,7 +59,7 @@ pub trait Backend { fn contains_tensor(&self, name: &str) -> bool; } -pub trait SimpleBackend { +pub trait SimpleBackend: Send + Sync { /// Retrieve a tensor based on a target name and shape. fn get( &self, @@ -99,7 +99,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { device: dev.clone(), }; Self { - data: Rc::new(data), + data: Arc::new(data), path: vec![], _phantom: std::marker::PhantomData, } @@ -333,7 +333,7 @@ impl<'a> VarBuilder<'a> { device, }; Self { - data: Rc::new(data), + data: Arc::new(data), path: vec![], _phantom: std::marker::PhantomData, } diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs index 209fc10a..5bbaf238 100644 --- a/candle-nn/tests/batch_norm.rs +++ b/candle-nn/tests/batch_norm.rs @@ -59,8 +59,8 @@ fn batch_norm() -> Result<()> { ); let bn2 = BatchNorm::new( 5, - running_mean.clone(), - running_var.clone(), + running_mean, + running_var, Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?, Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?, 1e-8, diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index 4ba8cfcc..5ca01b37 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -41,6 +41,16 @@ fn softmax() -> Result<()> { [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]] ] ); + let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + // (3, 1, 4) / 8, (1, 5, 9) / 15 + [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]], + // (2, 1, 7) / 10, (8, 2, 8) / 18 + [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]] + ] + ); Ok(()) } diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore new file mode 100644 index 00000000..68bc17f9 --- /dev/null +++ b/candle-pyo3/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 97631b0a..7fd0ac28 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -12,11 +12,10 @@ readme = "README.md" [lib] name = "candle" crate-type = ["cdylib"] -doc = false [dependencies] -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.2.1" } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.2.3" } half = { workspace = true } pyo3 = { version = "0.19.0", features = ["extension-module"] } diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md index 07dff468..be6d4f68 100644 --- a/candle-pyo3/README.md +++ b/candle-pyo3/README.md @@ -1,7 +1,26 @@ +## Installation + From the `candle-pyo3` directory, enable a virtual env where you will want the candle package to be installed then run. ```bash -maturin develop +maturin develop -r python test.py ``` + +## Generating Stub Files for Type Hinting + +For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script. + +### Steps: +1. Install the package using `maturin`. +2. Generate the stub files by running: + ``` + python stub.py + ``` + +### Validation: +To ensure that the stub files match the current implementation, execute: +``` +python stub.py --check +``` diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py new file mode 100644 index 00000000..951609cc --- /dev/null +++ b/candle-pyo3/py_src/candle/__init__.py @@ -0,0 +1,5 @@ +from .candle import * + +__doc__ = candle.__doc__ +if hasattr(candle, "__all__"): + __all__ = candle.__all__
\ No newline at end of file diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi new file mode 100644 index 00000000..414f0bc4 --- /dev/null +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -0,0 +1,375 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device + +class bf16(DType): + pass + +@staticmethod +def cat(tensors: List[Tensor], dim: int) -> Tensor: + """ + Concatenate the tensors across one axis. + """ + pass + +class f16(DType): + pass + +class f32(DType): + pass + +class f64(DType): + pass + +class i64(DType): + pass + +@staticmethod +def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor filled with ones. + """ + pass + +@staticmethod +def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor with random values. + """ + pass + +@staticmethod +def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor with random values from a normal distribution. + """ + pass + +@staticmethod +def stack(tensors: List[Tensor], dim: int) -> Tensor: + """ + Stack the tensors along a new axis. + """ + pass + +@staticmethod +def tensor(data: _ArrayLike) -> Tensor: + """ + Creates a new tensor from a Python value. The value can be a scalar or array-like object. + """ + pass + +class u32(DType): + pass + +class u8(DType): + pass + +@staticmethod +def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor filled with zeros. + """ + pass + +class DType: + """ + A `candle` dtype. + """ + +class QTensor: + """ + A quantized tensor. + """ + + def dequantize(self) -> Tensor: + """ + Dequantizes the tensor. + """ + pass + @property + def ggml_dtype(self) -> str: + """ + Gets the tensor's quantized dtype. + """ + pass + def matmul_t(self, lhs: Tensor) -> Tensor: + """ + Performs a quantized matrix multiplication, with the quantized tensor as the right-hand side. + """ + pass + @property + def rank(self) -> int: + """ + Gets the rank of the tensor. + """ + pass + @property + def shape(self) -> Tuple[int]: + """ + Gets the shape of the tensor. + """ + pass + +class Tensor: + """ + A `candle` tensor. + """ + + def __init__(self, data: _ArrayLike): + pass + def argmax_keepdim(self, dim: int) -> Tensor: + """ + Returns the indices of the maximum value(s) across the selected dimension. + """ + pass + def argmin_keepdim(self, dim: int) -> Tensor: + """ + Returns the indices of the minimum value(s) across the selected dimension. + """ + pass + def broadcast_add(self, rhs: Tensor) -> Tensor: + """ + Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ + pass + def broadcast_as(self, shape: Sequence[int]) -> Tensor: + """ + Broadcasts the tensor to the given shape. + """ + pass + def broadcast_div(self, rhs: Tensor) -> Tensor: + """ + Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ + pass + def broadcast_left(self, shape: Sequence[int]) -> Tensor: + """ + Broadcasts the tensor to the given shape, adding new dimensions on the left.
+ """ + pass + def broadcast_mul(self, rhs: Tensor) -> Tensor: + """ + Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ + pass + def broadcast_sub(self, rhs: Tensor) -> Tensor: + """ + Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ + pass + def contiguous(self) -> Tensor: + """ + Makes the tensor contiguous in memory. + """ + pass + def copy(self) -> Tensor: + """ + Returns a copy of the tensor. + """ + pass + def cos(self) -> Tensor: + """ + Performs the `cos` operation on the tensor. + """ + pass + def detach(self) -> Tensor: + """ + Detach the tensor from the computation graph. + """ + pass + @property + def device(self) -> Device: + """ + Gets the tensor's device. + """ + pass + @property + def dtype(self) -> DType: + """ + Gets the tensor's dtype. + """ + pass + def exp(self) -> Tensor: + """ + Performs the `exp` operation on the tensor. + """ + pass + def flatten_all(self) -> Tensor: + """ + Flattens the tensor into a 1D tensor. + """ + pass + def flatten_from(self, dim: int) -> Tensor: + """ + Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. + """ + pass + def flatten_to(self, dim: int) -> Tensor: + """ + Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). + """ + pass + def get(self, index: int) -> Tensor: + """ + Gets the value at the specified index. + """ + pass + def index_select(self, rhs: Tensor, dim: int) -> Tensor: + """ + Select values for the input tensor at the target indexes across the specified dimension. + + The `indexes` is argument is an int tensor with a single dimension. + The output has the same number of dimension as the `self` input. The target dimension of + the output has length the length of `indexes` and the values are taken from `self` using + the index from `indexes`. Other dimensions have the same number of elements as the input + tensor. + """ + pass + def is_contiguous(self) -> bool: + """ + Returns true if the tensor is contiguous in C order. + """ + pass + def is_fortran_contiguous(self) -> bool: + """ + Returns true if the tensor is contiguous in Fortran order. + """ + pass + def log(self) -> Tensor: + """ + Performs the `log` operation on the tensor. + """ + pass + def matmul(self, rhs: Tensor) -> Tensor: + """ + Performs a matrix multiplication between the two tensors. + """ + pass + def max_keepdim(self, dim: int) -> Tensor: + """ + Gathers the maximum value across the selected dimension. + """ + pass + def mean_all(self) -> Tensor: + """ + Returns the mean of the tensor. + """ + pass + def min_keepdim(self, dim: int) -> Tensor: + """ + Gathers the minimum value across the selected dimension. + """ + pass + def narrow(self, dim: int, start: int, len: int) -> Tensor: + """ + Returns a new tensor that is a narrowed version of the input, the dimension `dim` + ranges from `start` to `start + len`. + """ + pass + def powf(self, p: float) -> Tensor: + """ + Performs the `pow` operation on the tensor with the given exponent. + """ + pass + def quantize(self, quantized_dtype: str) -> QTensor: + """ + Quantize the tensor. + """ + pass + @property + def rank(self) -> int: + """ + Gets the tensor's rank. + """ + pass + def recip(self) -> Tensor: + """ + Get the `recip` of the tensor. + """ + pass + def reshape(self, shape: Sequence[int]) -> Tensor: + """ + Reshapes the tensor to the given shape. 
+ """ + pass + @property + def shape(self) -> Tuple[int]: + """ + Gets the tensor's shape. + """ + pass + def sin(self) -> Tensor: + """ + Performs the `sin` operation on the tensor. + """ + pass + def sqr(self) -> Tensor: + """ + Squares the tensor. + """ + pass + def sqrt(self) -> Tensor: + """ + Calculates the square root of the tensor. + """ + pass + def squeeze(self, dim: int) -> Tensor: + """ + Creates a new tensor with the specified dimension removed if its size was one. + """ + pass + @property + def stride(self) -> Tuple[int]: + """ + Gets the tensor's strides. + """ + pass + def sum_all(self) -> Tensor: + """ + Returns the sum of the tensor. + """ + pass + def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor: + """ + Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. + """ + pass + def t(self) -> Tensor: + """ + Transposes the tensor. + """ + pass + def to_device(self, device: Union[str, Device]) -> Tensor: + """ + Move the tensor to a new device. + """ + pass + def to_dtype(self, dtype: Union[str, DType]) -> Tensor: + """ + Convert the tensor to a new dtype. + """ + pass + def transpose(self, dim1: int, dim2: int) -> Tensor: + """ + Returns a tensor that is a transposed version of the input, the given dimensions are swapped. + """ + pass + def unsqueeze(self, dim: int) -> Tensor: + """ + Creates a new tensor with a dimension of size one inserted at the specified position. + """ + pass + def values(self) -> _ArrayLike: + """ + Gets the tensor's data as a Python scalar or array-like object. + """ + pass + def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor: + """ + Returns a tensor with the same shape as the input tensor, the values are taken from + `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + input tensor is equal to zero. + """ + pass diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py new file mode 100644 index 00000000..b8c5cfb7 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/__init__.py @@ -0,0 +1,5 @@ +# Generated content DO NOT EDIT +from .. import nn + +silu = nn.silu +softmax = nn.softmax diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi new file mode 100644 index 00000000..01b30fce --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/__init__.pyi @@ -0,0 +1,19 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device +from candle import Tensor, DType, QTensor + +@staticmethod +def silu(tensor: Tensor) -> Tensor: + """ + Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. + """ + pass + +@staticmethod +def softmax(tensor: Tensor, dim: int) -> Tensor: + """ + Applies the Softmax function to a given tensor.# + """ + pass diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py new file mode 100644 index 00000000..ea85d2a3 --- /dev/null +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -0,0 +1,16 @@ +from typing import TypeVar, Union, Sequence + +_T = TypeVar("_T") + +_ArrayLike = Union[ + _T, + Sequence[_T], + Sequence[Sequence[_T]], + Sequence[Sequence[Sequence[_T]]], + Sequence[Sequence[Sequence[Sequence[_T]]]], +] + +CPU:str = "cpu" +CUDA:str = "cuda" + +Device = TypeVar("Device", CPU, CUDA)
\ No newline at end of file diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py new file mode 100644 index 00000000..62d85dc9 --- /dev/null +++ b/candle-pyo3/py_src/candle/utils/__init__.py @@ -0,0 +1,12 @@ +# Generated content DO NOT EDIT +from .. import utils + +cuda_is_available = utils.cuda_is_available +get_num_threads = utils.get_num_threads +has_accelerate = utils.has_accelerate +has_mkl = utils.has_mkl +load_ggml = utils.load_ggml +load_gguf = utils.load_gguf +load_safetensors = utils.load_safetensors +save_gguf = utils.save_gguf +save_safetensors = utils.save_safetensors diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi new file mode 100644 index 00000000..61964ffc --- /dev/null +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -0,0 +1,70 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device +from candle import Tensor, DType, QTensor + +@staticmethod +def cuda_is_available() -> bool: + """ + Returns true if the 'cuda' backend is available. + """ + pass + +@staticmethod +def get_num_threads() -> int: + """ + Returns the number of threads used by candle. + """ + pass + +@staticmethod +def has_accelerate() -> bool: + """ + Returns true if candle was compiled with 'accelerate' support. + """ + pass + +@staticmethod +def has_mkl() -> bool: + """ + Returns true if candle was compiled with MKL support. + """ + pass + +@staticmethod +def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: + """ + Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, + a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. + """ + pass + +@staticmethod +def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: + """ + Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, + and the second maps metadata keys to metadata values. + """ + pass + +@staticmethod +def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]: + """ + Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. + """ + pass + +@staticmethod +def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]): + """ + Save quantized tensors and metadata to a GGUF file. + """ + pass + +@staticmethod +def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None: + """ + Saves a dictionary of tensors to a safetensors file.
+ """ + pass diff --git a/candle-pyo3/pyproject.toml b/candle-pyo3/pyproject.toml new file mode 100644 index 00000000..88793493 --- /dev/null +++ b/candle-pyo3/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = 'candle-nn' +requires-python = '>=3.7' +authors = [ + {name = 'The Candle Team'}, +] + +dynamic = [ + 'description', + 'license', + 'readme', +] + +[project.urls] +Homepage = 'https://github.com/huggingface/candle' +Source = 'https://github.com/huggingface/candle' + +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[tool.maturin] +python-source = "py_src" +module-name = "candle.candle" +bindings = 'pyo3' +features = ["pyo3/extension-module"] + +[tool.black] +line-length = 119 +target-version = ['py35'] diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 7d74c25e..46d9ff62 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -1,26 +1,28 @@ # This example shows how the candle Python api can be used to replicate llama.cpp. import sys +from typing import Dict, Tuple, Any import candle +from candle import Tensor, QTensor, utils, nn MAX_SEQ_LEN = 4096 -def masked_fill(on_false, mask, on_true): +def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor): shape = mask.shape on_true = candle.tensor(on_true).broadcast_as(shape) return mask.where_cond(on_true, on_false) class RmsNorm: - def __init__(self, qtensor): + def __init__(self, qtensor:QTensor): self.weight = qtensor.dequantize() - def __call__(self, x): + def __call__(self, x:Tensor): b_size, seq_len, hidden_size = x.shape norm_x = x.sqr().sum_keepdim(2) / hidden_size x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) return x_normed.broadcast_mul(self.weight) class QuantizedLayer: - def __init__(self, layer_idx, hparams, all_tensors, cos_sin): + def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]): p = f"layers.{layer_idx}" self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] @@ -40,7 +42,7 @@ class QuantizedLayer: self.cos = cos_sin[0] self.sin = cos_sin[1] - def __call__(self, x, mask, index_pos): + def __call__(self, x:Tensor, mask:Tensor, index_pos:int): residual = x x = self.attn_norm(x) attn = self.forward_attn(x, mask, index_pos) @@ -50,11 +52,11 @@ class QuantizedLayer: x = self.ffn_norm(x) w1 = self.ffw1.matmul_t(x) w3 = self.ffw3.matmul_t(x) - mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3) + mlp = self.ffw2.matmul_t(nn.silu(w1) * w3) return mlp + residual - def forward_attn(self, x, mask, index_pos): + def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int): b_size, seq_len, n_embd = x.shape q = self.attention_wq.matmul_t(x) k = self.attention_wk.matmul_t(x) @@ -79,12 +81,12 @@ class QuantizedLayer: att = q.matmul(k.t()) / self.head_dim**0.5 mask = mask.broadcast_as(att.shape) att = masked_fill(att, mask, float("-inf")) - att = candle.nn.softmax(att, -1) + att = nn.softmax(att, -1) y = att.matmul(v.contiguous()) y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) return self.attention_wo.matmul_t(y) - def apply_rotary_emb(self, x, index_pos): + def apply_rotary_emb(self, x:Tensor, index_pos:int): (b_size, n_head, seq_len, n_embd) = x.shape cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) @@ -106,17 +108,18 @@ def precompute_freqs_cis(hparams, freq_base): return (m.cos(), m.sin()) class 
QuantizedLlama: - def __init__(self, hparams, all_tensors): + def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]): self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() self.norm = RmsNorm(all_tensors["norm.weight"]) self.output = all_tensors["output.weight"] self.layers = [] - cos_sin = precompute_freqs_cis(hparams, 10000.) + rope_freq = hparams.get("rope_freq", 10000.) + cos_sin = precompute_freqs_cis(hparams, rope_freq) for layer_idx in range(hparams["n_layer"]): layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) self.layers.append(layer) - def __call__(self, token, index_pos): + def __call__(self, token:Tensor, index_pos:int): b_size, seq_len = token.shape vocab_size, hidden_size = self.tok_embeddings.shape token = token.reshape((b_size * seq_len,)) @@ -133,17 +136,47 @@ class QuantizedLlama: x = self.output.matmul_t(x) return x +def gguf_rename(tensor_name:str): + if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight' + if tensor_name == 'output_norm.weight': return 'norm.weight' + tensor_name = tensor_name.replace('blk.', 'layers.') + tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.') + tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.') + tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.') + tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.') + tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.') + tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.') + tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.') + tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.') + return tensor_name + def main(): if len(sys.argv) < 2: raise ValueError("missing weight file argument") filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors = candle.load_gguf(sys.argv[1]) - hparams = None - vocab = None + all_tensors, metadata = utils.load_gguf(sys.argv[1]) + vocab = metadata["tokenizer.ggml.tokens"] + for i, v in enumerate(vocab): + vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') + hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")} + print(hparams) + hparams = { + 'n_vocab': len(vocab), + 'n_embd': metadata['llama.embedding_length'], + 'n_mult': 256, + 'n_head': metadata['llama.attention.head_count'], + 'n_head_kv': metadata['llama.attention.head_count_kv'], + 'n_layer': metadata['llama.block_count'], + 'n_rot': metadata['llama.rope.dimension_count'], + 'rope_freq': metadata.get('llama.rope.freq_base', 10000.), + 'ftype': metadata['general.file_type'], + } + all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } + else: - all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1]) print(hparams) model = QuantizedLlama(hparams, all_tensors) print("model built, starting inference") diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 79f86479..55b7a888 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,8 +1,7 @@ #![allow(clippy::redundant_closure_call)] -// TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyTuple}; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; use std::sync::Arc; @@ -32,6 +31,7 @@ impl From<PyShape> for ::candle::Shape { #[derive(Clone, Debug)] #[pyclass(name = "Tensor")] +/// A `candle` tensor. 
struct PyTensor(Tensor); impl std::ops::Deref for PyTensor { @@ -44,6 +44,7 @@ impl std::ops::Deref for PyTensor { #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[pyclass(name = "DType")] +/// A `candle` dtype. struct PyDType(DType); #[pymethods] @@ -198,38 +199,40 @@ trait MapDType { #[pymethods] impl PyTensor { #[new] + #[pyo3(text_signature = "(self, data:_ArrayLike)")] // TODO: Handle arbitrary input dtype and shape. - fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> { + /// Creates a new tensor from a Python value. The value can be a scalar or array-like object. + fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> { use Device::Cpu; - let tensor = if let Ok(vs) = vs.extract::<u32>(py) { + let tensor = if let Ok(vs) = data.extract::<u32>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<i64>(py) { + } else if let Ok(vs) = data.extract::<i64>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<f32>(py) { + } else if let Ok(vs) = data.extract::<f32>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { + } else if let Ok(vs) = data.extract::<Vec<u32>>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) { + } else if let Ok(vs) = data.extract::<Vec<i64>>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) { + } else if let Ok(vs) = data.extract::<Vec<f32>>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) { + } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { - let ty = vs.as_ref(py).get_type(); + let ty = data.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( "incorrect type {ty} for tensor" )))? @@ -237,7 +240,8 @@ impl PyTensor { Ok(Self(tensor)) } - /// Gets the tensor data as a Python value/array/array of array/... + /// Gets the tensor's data as a Python scalar or array-like object. + /// &RETURNS&: _ArrayLike fn values(&self, py: Python<'_>) -> PyResult<PyObject> { struct M<'a>(Python<'a>); impl<'a> MapDType for M<'a> { @@ -281,26 +285,36 @@ impl PyTensor { } #[getter] + /// Gets the tensor's shape. + /// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.dims()).to_object(py) } #[getter] + /// Gets the tensor's strides. + /// &RETURNS&: Tuple[int] fn stride(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.stride()).to_object(py) } #[getter] + /// Gets the tensor's dtype. 
+ /// &RETURNS&: DType fn dtype(&self) -> PyDType { PyDType(self.0.dtype()) } #[getter] + /// Gets the tensor's device. + /// &RETURNS&: Device fn device(&self, py: Python<'_>) -> PyObject { PyDevice::from_device(self.0.device()).to_object(py) } #[getter] + /// Gets the tensor's rank. + /// &RETURNS&: int fn rank(&self) -> usize { self.0.rank() } @@ -313,69 +327,117 @@ impl PyTensor { self.__repr__() } + /// Performs the `sin` operation on the tensor. + /// &RETURNS&: Tensor fn sin(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sin().map_err(wrap_err)?)) } + /// Performs the `cos` operation on the tensor. + /// &RETURNS&: Tensor fn cos(&self) -> PyResult<Self> { Ok(PyTensor(self.0.cos().map_err(wrap_err)?)) } + /// Performs the `log` operation on the tensor. + /// &RETURNS&: Tensor fn log(&self) -> PyResult<Self> { Ok(PyTensor(self.0.log().map_err(wrap_err)?)) } + /// Squares the tensor. + /// &RETURNS&: Tensor fn sqr(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sqr().map_err(wrap_err)?)) } + /// Calculates the square root of the tensor. + /// &RETURNS&: Tensor fn sqrt(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?)) } + /// Get the `recip` of the tensor. + /// &RETURNS&: Tensor fn recip(&self) -> PyResult<Self> { Ok(PyTensor(self.0.recip().map_err(wrap_err)?)) } + /// Performs the `exp` operation on the tensor. + /// &RETURNS&: Tensor fn exp(&self) -> PyResult<Self> { Ok(PyTensor(self.0.exp().map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, p:float)")] + /// Performs the `pow` operation on the tensor with the given exponent. + /// &RETURNS&: Tensor fn powf(&self, p: f64) -> PyResult<Self> { Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")] + /// Select values for the input tensor at the target indexes across the specified dimension. + /// + /// The `indexes` argument is an int tensor with a single dimension. + /// The output has the same number of dimensions as the `self` input. The target dimension of + /// the output has the same length as `indexes` and the values are taken from `self` using + /// the index from `indexes`. Other dimensions have the same number of elements as the input + /// tensor. + /// &RETURNS&: Tensor fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Performs a matrix multiplication between the two tensors. + /// &RETURNS&: Tensor fn matmul(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")] + /// Returns a tensor with the same shape as the input tensor, the values are taken from + /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + /// input tensor is equal to zero. + /// &RETURNS&: Tensor fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> { Ok(PyTensor( self.0.where_cond(on_true, on_false).map_err(wrap_err)?, )) } + /// Add two tensors. + /// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 + &rhs.0).map_err(wrap_err)? @@ -391,6 +453,8 @@ impl PyTensor { self.__add__(rhs) } + /// Multiply two tensors. + /// &RETURNS&: Tensor fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 * &rhs.0).map_err(wrap_err)? @@ -406,6 +470,8 @@ impl PyTensor { self.__mul__(rhs) } + /// Subtract two tensors. + /// &RETURNS&: Tensor fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 - &rhs.0).map_err(wrap_err)? @@ -417,6 +483,8 @@ impl PyTensor { Ok(Self(tensor)) } + /// Divide two tensors. + /// &RETURNS&: Tensor fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 / &rhs.0).map_err(wrap_err)? @@ -428,62 +496,102 @@ impl PyTensor { Ok(Self(tensor)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Reshapes the tensor to the given shape. + /// &RETURNS&: Tensor fn reshape(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Broadcasts the tensor to the given shape. + /// &RETURNS&: Tensor fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Broadcasts the tensor to the given shape, adding new dimensions on the left. + /// &RETURNS&: Tensor fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Creates a new tensor with the specified dimension removed if its size was one. + /// &RETURNS&: Tensor fn squeeze(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Creates a new tensor with a dimension of size one inserted at the specified position. + /// &RETURNS&: Tensor fn unsqueeze(&self, dim: usize) -> PyResult<Self> { Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, index:int)")] + /// Gets the value at the specified index. 
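`where_cond` together with the broadcast helpers above is enough to express masked fills; the `masked_fill` helper in `quant-llama.py` earlier in this diff is built from exactly these two calls. A standalone restatement of it, with made-up values:

```python
import candle

def masked_fill(on_false, mask, on_true):
    # Broadcast the scalar fill value to the mask's shape, then select per
    # element: non-zero mask positions take `on_true`, the rest keep `on_false`.
    on_true = candle.tensor(on_true).broadcast_as(mask.shape)
    return mask.where_cond(on_true, on_false)

scores = candle.tensor([[0.1, 0.2], [0.3, 0.4]])
mask = candle.tensor([[0, 1], [1, 0]])
print(masked_fill(scores, mask, float("-inf")).values())
```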
+ /// &RETURNS&: Tensor fn get(&self, index: i64) -> PyResult<Self> { let index = actual_index(self, 0, index).map_err(wrap_err)?; Ok(PyTensor(self.0.get(index).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim1:int, dim2:int)")] + /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped. + /// &RETURNS&: Tensor fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")] + /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` + /// ranges from `start` to `start + len`. + /// &RETURNS&: Tensor fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; let start = actual_index(self, dim, start).map_err(wrap_err)?; Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Returns the indices of the maximum value(s) across the selected dimension. + /// &RETURNS&: Tensor fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Returns the indices of the minimum value(s) across the selected dimension. + /// &RETURNS&: Tensor fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Gathers the maximum value across the selected dimension. + /// &RETURNS&: Tensor fn max_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Gathers the minimum value across the selected dimension. + /// &RETURNS&: Tensor fn min_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")] + /// Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. + /// &RETURNS&: Tensor fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> { let dims = if let Ok(dim) = dims.extract::<usize>(py) { vec![dim] @@ -495,10 +603,14 @@ impl PyTensor { )) } + /// Returns the sum of the tensor. + /// &RETURNS&: Tensor fn sum_all(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?)) } + /// Returns the mean of the tensor. + /// &RETURNS&: Tensor fn mean_all(&self) -> PyResult<Self> { let elements = self.0.elem_count(); let sum = self.0.sum_all().map_err(wrap_err)?; @@ -506,54 +618,83 @@ impl PyTensor { Ok(PyTensor(mean)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. + /// &RETURNS&: Tensor fn flatten_from(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). 
+ /// &RETURNS&: Tensor fn flatten_to(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?)) } + /// Flattens the tensor into a 1D tensor. + /// &RETURNS&: Tensor fn flatten_all(&self) -> PyResult<Self> { Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?)) } + /// Transposes the tensor. + /// &RETURNS&: Tensor fn t(&self) -> PyResult<Self> { Ok(PyTensor(self.0.t().map_err(wrap_err)?)) } + /// Makes the tensor contiguous in memory. + /// &RETURNS&: Tensor fn contiguous(&self) -> PyResult<Self> { Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?)) } + /// Returns true if the tensor is contiguous in C order. + /// &RETURNS&: bool fn is_contiguous(&self) -> bool { self.0.is_contiguous() } + /// Returns true if the tensor is contiguous in Fortran order. + /// &RETURNS&: bool fn is_fortran_contiguous(&self) -> bool { self.0.is_fortran_contiguous() } + /// Detach the tensor from the computation graph. + /// &RETURNS&: Tensor fn detach(&self) -> PyResult<Self> { Ok(PyTensor(self.0.detach().map_err(wrap_err)?)) } + /// Returns a copy of the tensor. + /// &RETURNS&: Tensor fn copy(&self) -> PyResult<Self> { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dtype:Union[str,DType])")] + /// Convert the tensor to a new dtype. + /// &RETURNS&: Tensor fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> { let dtype = PyDType::from_pyobject(dtype, py)?; Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, device:Union[str,Device])")] + /// Move the tensor to a new device. + /// &RETURNS&: Tensor fn to_device(&self, device: PyDevice) -> PyResult<Self> { let device = device.as_device()?; Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, quantized_dtype:str)")] + /// Quantize the tensor. + /// &RETURNS&: QTensor fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> { use ::candle::quantized; let res = match quantized_dtype { @@ -581,8 +722,10 @@ impl PyTensor { } } -/// Concatenate the tensors across one axis. #[pyfunction] +#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")] +/// Concatenate the tensors across one axis. +/// &RETURNS&: Tensor fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> { if tensors.is_empty() { return Err(PyErr::new::<PyValueError, _>("empty input to cat")); @@ -594,6 +737,9 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> { } #[pyfunction] +#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")] +/// Stack the tensors along a new axis. +/// &RETURNS&: Tensor fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>(); let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?; @@ -601,12 +747,17 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { } #[pyfunction] -fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> { - PyTensor::new(py, vs) +#[pyo3(text_signature = "(data:_ArrayLike)")] +/// Creates a new tensor from a Python value. The value can be a scalar or array-like object. 
+/// &RETURNS&: Tensor +fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> { + PyTensor::new(py, data) } #[pyfunction] -#[pyo3(signature = (shape, *, device=None))] +#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +/// Creates a new tensor with random values. +/// &RETURNS&: Tensor fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -614,7 +765,9 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P } #[pyfunction] -#[pyo3(signature = (shape, *, device=None))] +#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +/// Creates a new tensor with random values from a normal distribution. +/// &RETURNS&: Tensor fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -622,7 +775,9 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult< } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None))] +#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +/// Creates a new tensor filled with ones. +/// &RETURNS&: Tensor fn ones( py: Python<'_>, shape: PyShape, @@ -639,7 +794,9 @@ fn ones( } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None))] +#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +/// Creates a new tensor filled with zeros. +/// &RETURNS&: Tensor fn zeros( py: Python<'_>, shape: PyShape, @@ -655,8 +812,9 @@ fn zeros( Ok(PyTensor(tensor)) } -#[derive(Debug)] +#[derive(Debug, Clone)] #[pyclass(name = "QTensor")] +/// A quantized tensor. struct PyQTensor(Arc<QTensor>); impl std::ops::Deref for PyQTensor { @@ -670,16 +828,22 @@ impl std::ops::Deref for PyQTensor { #[pymethods] impl PyQTensor { #[getter] + /// Gets the tensor's quantized dtype. + /// &RETURNS&: str fn ggml_dtype(&self) -> String { format!("{:?}", self.0.dtype()) } #[getter] + /// Gets the rank of the tensor. + /// &RETURNS&: int fn rank(&self) -> usize { self.0.rank() } #[getter] + /// Gets the shape of the tensor. + /// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.shape().dims()).to_object(py) } @@ -692,11 +856,16 @@ impl PyQTensor { self.__repr__() } + /// Dequantizes the tensor. + /// &RETURNS&: Tensor fn dequantize(&self) -> PyResult<PyTensor> { let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?; Ok(PyTensor(tensor)) } + #[pyo3(text_signature = "(self, lhs:Tensor)")] + /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side. + /// &RETURNS&: Tensor fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> { let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()); let res = qmatmul.forward(lhs).map_err(wrap_err)?; @@ -705,6 +874,9 @@ impl PyQTensor { } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
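The `QTensor` surface above is small but covers the whole quantized inference loop: quantize a float tensor, use it as the right-hand side of `matmul_t`, and `dequantize` when full precision is needed. A sketch, assuming `"q8_0"` is among the ggml dtype strings accepted by `quantize` (the match arms are elided from this diff):

```python
import candle

weight = candle.randn((64, 64))
qweight = weight.quantize("q8_0")  # assumed ggml dtype string

x = candle.randn((4, 64))
y = qweight.matmul_t(x)          # computes x @ weight.T on the quantized data
print(y.shape)                   # (4, 64)
print(qweight.ggml_dtype)        # storage format of the quantized blocks
restored = qweight.dequantize()  # back to a plain float tensor
```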
+/// &RETURNS&: Dict[str,Tensor] fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?; let res = res @@ -715,6 +887,25 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")] +/// Saves a dictionary of tensors to a safetensors file. +/// &RETURNS&: None +fn save_safetensors( + path: &str, + tensors: std::collections::HashMap<String, PyTensor>, +) -> PyResult<()> { + let tensors = tensors + .into_iter() + .map(|(s, t)| (s, t.0)) + .collect::<std::collections::HashMap<_, _>>(); + ::candle::safetensors::save(&tensors, path).map_err(wrap_err) +} + +#[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, +/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. +/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; @@ -746,10 +937,39 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] -fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, +/// and the second maps metadata keys to metadata values. +/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] +fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { + use ::candle::quantized::gguf_file; + fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> { + let v: PyObject = match v { + gguf_file::Value::U8(x) => x.into_py(py), + gguf_file::Value::I8(x) => x.into_py(py), + gguf_file::Value::U16(x) => x.into_py(py), + gguf_file::Value::I16(x) => x.into_py(py), + gguf_file::Value::U32(x) => x.into_py(py), + gguf_file::Value::I32(x) => x.into_py(py), + gguf_file::Value::U64(x) => x.into_py(py), + gguf_file::Value::I64(x) => x.into_py(py), + gguf_file::Value::F32(x) => x.into_py(py), + gguf_file::Value::F64(x) => x.into_py(py), + gguf_file::Value::Bool(x) => x.into_py(py), + gguf_file::Value::String(x) => x.into_py(py), + gguf_file::Value::Array(x) => { + let list = pyo3::types::PyList::empty(py); + for elem in x.iter() { + list.append(gguf_value_to_pyobject(elem, py)?)?; + } + list.into() + } + }; + Ok(v) + } let mut file = std::fs::File::open(path)?; - let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?; - let res = gguf + let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?; + let tensors = gguf .tensor_infos .keys() .map(|key| { @@ -758,25 +978,129 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { }) .collect::<::candle::Result<Vec<_>>>() .map_err(wrap_err)?; - Ok(res.into_py_dict(py).to_object(py)) + let tensors = tensors.into_py_dict(py).to_object(py); + let metadata = gguf + .metadata + .iter() + .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?))) + .collect::<PyResult<Vec<_>>>()? 
+ .into_py_dict(py) + .to_object(py); + Ok((tensors, metadata)) } #[pyfunction] +#[pyo3( + text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])" +)] +/// Save quantized tensors and metadata to a GGUF file. +fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> { + use ::candle::quantized::gguf_file; + + fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> { + let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() { + gguf_file::Value::U8(x) + } else if let Ok(x) = v.extract::<i8>() { + gguf_file::Value::I8(x) + } else if let Ok(x) = v.extract::<u16>() { + gguf_file::Value::U16(x) + } else if let Ok(x) = v.extract::<i16>() { + gguf_file::Value::I16(x) + } else if let Ok(x) = v.extract::<u32>() { + gguf_file::Value::U32(x) + } else if let Ok(x) = v.extract::<i32>() { + gguf_file::Value::I32(x) + } else if let Ok(x) = v.extract::<u64>() { + gguf_file::Value::U64(x) + } else if let Ok(x) = v.extract::<i64>() { + gguf_file::Value::I64(x) + } else if let Ok(x) = v.extract::<f32>() { + gguf_file::Value::F32(x) + } else if let Ok(x) = v.extract::<f64>() { + gguf_file::Value::F64(x) + } else if let Ok(x) = v.extract::<bool>() { + gguf_file::Value::Bool(x) + } else if let Ok(x) = v.extract::<String>() { + gguf_file::Value::String(x) + } else if let Ok(x) = v.extract::<Vec<PyObject>>() { + let x = x + .into_iter() + .map(|f| pyobject_to_gguf_value(f.as_ref(py), py)) + .collect::<PyResult<Vec<_>>>()?; + gguf_file::Value::Array(x) + } else { + return Err(PyErr::new::<PyValueError, _>(format!( + "unsupported type {:?}", + v + ))); + }; + Ok(v) + } + let tensors = tensors + .extract::<&PyDict>(py) + .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))? + .iter() + .map(|(key, value)| { + Ok(( + key.extract::<String>() + .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?, + value.extract::<PyQTensor>()?.0, + )) + }) + .collect::<PyResult<Vec<_>>>()?; + + let metadata = metadata + .extract::<&PyDict>(py) + .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))? + .iter() + .map(|(key, value)| { + Ok(( + key.extract::<String>() + .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?, + pyobject_to_gguf_value(value, py)?, + )) + }) + .collect::<PyResult<Vec<_>>>()?; + + let converted_metadata: Vec<_> = metadata + .iter() + .map(|(name, value)| (name.as_str(), value)) + .collect(); + + let converted_tensors: Vec<_> = tensors + .iter() + .map(|(name, tensor)| (name.as_str(), tensor.as_ref())) + .collect(); + + let mut file = std::fs::File::create(path)?; + + gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err) +} + +#[pyfunction] +/// Returns true if the 'cuda' backend is available. +/// &RETURNS&: bool fn cuda_is_available() -> bool { ::candle::utils::cuda_is_available() } #[pyfunction] +/// Returns true if candle was compiled with 'accelerate' support. +/// &RETURNS&: bool fn has_accelerate() -> bool { ::candle::utils::has_accelerate() } #[pyfunction] +/// Returns true if candle was compiled with MKL support. +/// &RETURNS&: bool fn has_mkl() -> bool { ::candle::utils::has_mkl() } #[pyfunction] +/// Returns the number of threads used by candle.
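Taken together, `save_gguf` and `load_gguf` give a full round trip for quantized checkpoints. A sketch of the flow (file name, tensor name, and metadata key are made up; `"q8_0"` is again an assumed quantization string):

```python
import candle
from candle import utils

qweight = candle.randn((32, 32)).quantize("q8_0")
utils.save_gguf("demo.gguf", {"w": qweight}, {"demo.version": 1})

tensors, metadata = utils.load_gguf("demo.gguf")
print(metadata["demo.version"])  # 1
print(tensors["w"].shape)        # (32, 32)
```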
+/// &RETURNS&: int fn get_num_threads() -> usize { ::candle::utils::get_num_threads() } @@ -786,19 +1110,30 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_num_threads, m)?)?; m.add_function(wrap_pyfunction!(has_accelerate, m)?)?; m.add_function(wrap_pyfunction!(has_mkl, m)?)?; + m.add_function(wrap_pyfunction!(load_ggml, m)?)?; + m.add_function(wrap_pyfunction!(load_gguf, m)?)?; + m.add_function(wrap_pyfunction!(save_gguf, m)?)?; + m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; + m.add_function(wrap_pyfunction!(save_safetensors, m)?)?; Ok(()) } #[pyfunction] -fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> { - let dim = actual_dim(&t, dim).map_err(wrap_err)?; - let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?; +#[pyo3(text_signature = "(tensor:Tensor, dim:int)")] +/// Applies the Softmax function to a given tensor. +/// &RETURNS&: Tensor +fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> { + let dim = actual_dim(&tensor, dim).map_err(wrap_err)?; + let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?; Ok(PyTensor(sm)) } #[pyfunction] -fn silu(t: PyTensor) -> PyResult<PyTensor> { - let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?; +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. +/// &RETURNS&: Tensor +fn silu(tensor: PyTensor) -> PyResult<PyTensor> { + let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?; Ok(PyTensor(s)) } @@ -827,9 +1162,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add("f32", PyDType(DType::F32))?; m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; - m.add_function(wrap_pyfunction!(load_ggml, m)?)?; - m.add_function(wrap_pyfunction!(load_gguf, m)?)?; - m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; m.add_function(wrap_pyfunction!(ones, m)?)?; m.add_function(wrap_pyfunction!(rand, m)?)?; m.add_function(wrap_pyfunction!(randn, m)?)?; diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py new file mode 100644 index 00000000..149715c2 --- /dev/null +++ b/candle-pyo3/stub.py @@ -0,0 +1,232 @@ +# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py +import argparse +import inspect +import os +from typing import Optional +import black +from pathlib import Path + + +INDENT = " " * 4 +GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" +TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +""" +CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" +CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" +RETURN_TYPE_MARKER = "&RETURNS&: " + + +def do_indent(text: Optional[str], indent: str): + if text is None: + return "" + return text.replace("\n", f"\n{indent}") + + +def function(obj, indent:str, text_signature:str=None): + if text_signature is None: + text_signature = obj.__text_signature__ + + text_signature = text_signature.replace("$self", "self").lstrip().rstrip() + doc_string = obj.__doc__ + if doc_string is None: + doc_string = "" + + # Check if we have a return type annotation in the docstring + return_type = None + doc_lines = doc_string.split("\n") + if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER): + # Extract the return type and remove it from the docstring + return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip() + doc_string = "\n".join(doc_lines[:-1]) + +
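The `softmax` and `silu` bindings above are the same ops `quant-llama.py` uses for attention and the feed-forward gate; `softmax` accepts negative dimensions thanks to the `actual_dim` translation. A quick usage sketch with made-up logits:

```python
import candle
from candle import nn

logits = candle.tensor([[1.0, 2.0, 3.0]])
probs = nn.softmax(logits, -1)  # -1 addresses the last dimension
gated = nn.silu(logits)         # x * sigmoid(x)
print(probs.values())
```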
string = "" + if return_type: + string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n" + else: + string += f"{indent}def {obj.__name__}{text_signature}:\n" + indent += INDENT + string += f'{indent}"""\n' + string += f"{indent}{do_indent(doc_string, indent)}\n" + string += f'{indent}"""\n' + string += f"{indent}pass\n" + string += "\n" + string += "\n" + return string + + +def member_sort(member): + if inspect.isclass(member): + value = 10 + len(inspect.getmro(member)) + else: + value = 1 + return value + + +def fn_predicate(obj): + value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj) + if value: + return obj.__text_signature__ and not obj.__name__.startswith("_") + if inspect.isgetsetdescriptor(obj): + return not obj.__name__.startswith("_") + return False + + +def get_module_members(module): + members = [ + member + for name, member in inspect.getmembers(module) + if not name.startswith("_") and not inspect.ismodule(member) + ] + members.sort(key=member_sort) + return members + + +def pyi_file(obj, indent=""): + string = "" + if inspect.ismodule(obj): + string += GENERATED_COMMENT + string += TYPING + string += CANDLE_SPECIFIC_TYPING + if obj.__name__ != "candle.candle": + string += CANDLE_TENSOR_IMPORTS + members = get_module_members(obj) + for member in members: + string += pyi_file(member, indent) + + elif inspect.isclass(obj): + indent += INDENT + mro = inspect.getmro(obj) + if len(mro) > 2: + inherit = f"({mro[1].__name__})" + else: + inherit = "" + string += f"class {obj.__name__}{inherit}:\n" + + body = "" + if obj.__doc__: + body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' + + fns = inspect.getmembers(obj, fn_predicate) + + # Init + if obj.__text_signature__: + body += f"{indent}def __init__{obj.__text_signature__}:\n" + body += f"{indent+INDENT}pass\n" + body += "\n" + + for (name, fn) in fns: + body += pyi_file(fn, indent=indent) + + if not body: + body += f"{indent}pass\n" + + string += body + string += "\n\n" + + elif inspect.isbuiltin(obj): + string += f"{indent}@staticmethod\n" + string += function(obj, indent) + + elif inspect.ismethoddescriptor(obj): + string += function(obj, indent) + + elif inspect.isgetsetdescriptor(obj): + # TODO it would be interesing to add the setter maybe ? + string += f"{indent}@property\n" + string += function(obj, indent, text_signature="(self)") + + elif obj.__class__.__name__ == "DType": + string += f"class {str(obj).lower()}(DType):\n" + string += f"{indent+INDENT}pass\n" + else: + raise Exception(f"Object {obj} is not supported") + return string + + +def py_file(module, origin): + members = get_module_members(module) + + string = GENERATED_COMMENT + string += f"from .. 
import {origin}\n" + string += "\n" + for member in members: + if hasattr(member, "__name__"): + name = member.__name__ + else: + name = str(member) + string += f"{name} = {origin}.{name}\n" + return string + + +def do_black(content, is_pyi): + mode = black.Mode( + target_versions={black.TargetVersion.PY35}, + line_length=119, + is_pyi=is_pyi, + string_normalization=True, + experimental_string_processing=False, + ) + try: + return black.format_file_contents(content, fast=True, mode=mode) + except black.NothingChanged: + return content + + +def write(module, directory, origin, check=False): + submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)] + + filename = os.path.join(directory, "__init__.pyi") + pyi_content = pyi_file(module) + pyi_content = do_black(pyi_content, is_pyi=True) + os.makedirs(directory, exist_ok=True) + if check: + with open(filename, "r") as f: + data = f.read() + assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`" + else: + with open(filename, "w") as f: + f.write(pyi_content) + + filename = os.path.join(directory, "__init__.py") + py_content = py_file(module, origin) + py_content = do_black(py_content, is_pyi=False) + os.makedirs(directory, exist_ok=True) + + is_auto = False + if not os.path.exists(filename): + is_auto = True + else: + with open(filename, "r") as f: + line = f.readline() + if line == GENERATED_COMMENT: + is_auto = True + + if is_auto: + if check: + with open(filename, "r") as f: + data = f.read() + assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`" + else: + with open(filename, "w") as f: + f.write(py_content) + + for name, submodule in submodules: + write(submodule, os.path.join(directory, name), f"{name}", check=check) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--check", action="store_true") + + args = parser.parse_args() + + #Enable execution from the candle and candle-pyo3 directories + cwd = Path.cwd() + directory = "py_src/candle/" + if cwd.name != "candle-pyo3": + directory = f"candle-pyo3/{directory}" + + import candle + + write(candle.candle, directory, "candle", check=args.check) diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index a05b9bb7..a3115c2b 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -11,14 +11,21 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.2.1" } +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true } +candle-nn = { path = "../candle-nn", version = "0.2.3" } intel-mkl-src = { workspace = true, optional = true } +num-traits = { workspace = true } rand = { workspace = true } +rayon = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } wav = { workspace = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] +flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs 
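The `&RETURNS&: ` marker is the one piece of glue between the Rust doc comments and the generated stubs: the `function()` helper above peels it off the last docstring line and turns it into a Python return annotation. The parsing step distilled into a standalone sketch:

```python
RETURN_TYPE_MARKER = "&RETURNS&: "

doc = "Squares the tensor.\n&RETURNS&: Tensor"
doc_lines = doc.split("\n")
return_type = None
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
    # Strip the marker line and keep the type that follows it.
    return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
    doc = "\n".join(doc_lines[:-1])

print(return_type)  # Tensor
print(doc)          # Squares the tensor.
```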
index b1d20168..b1a567c3 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -1,35 +1,82 @@ -use candle::{DType, Error, Result, Tensor, D}; +use candle::{DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; pub struct LogitsProcessor { rng: rand::rngs::StdRng, temperature: Option<f64>, + top_p: Option<f64>, } impl LogitsProcessor { - pub fn new(seed: u64, temperature: Option<f64>) -> Self { + pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self { + let temperature = if temperature.map_or(true, |v| v < 1e-7) { + None + } else { + temperature + }; Self { rng: rand::rngs::StdRng::seed_from_u64(seed), temperature, + top_p, + } + } + + fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> { + let logits_v: Vec<f32> = logits.to_vec1()?; + let next_token = logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap(); + Ok(next_token) + } + + fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> { + let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; + let next_token = distr.sample(&mut self.rng) as u32; + Ok(next_token) + } + + fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> { + // top-p sampling (or "nucleus sampling") samples from the smallest set of + // tokens that exceed probability top_p. This way we never sample tokens that + // have very low probabilities and are less likely to go "off the rails". + let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>(); + + // Sort by descending probability. + argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap()); + + // Clamp smaller probabilities to zero. + let mut cumsum = 0.; + for index in &argsort_indices { + if cumsum >= top_p { + prs[*index] = 0.0; + } else { + cumsum += prs[*index]; + } } + // Sample with clamped probabilities. + self.sample_multinomial(prs) } pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { let logits = logits.to_dtype(DType::F32)?; - let temperature = self.temperature.unwrap_or(0.); - let next_token = if temperature > 0. { - let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?; - let prs: Vec<f32> = prs.to_vec1()?; - let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; - distr.sample(&mut self.rng) as u32 - } else { - let logits_v: Vec<f32> = logits.to_vec1()?; - logits_v - .iter() - .enumerate() - .max_by(|(_, u), (_, v)| u.total_cmp(v)) - .map(|(i, _)| i as u32) - .unwrap() + let next_token = match self.temperature { + None => self.sample_argmax(logits)?, + Some(temperature) => { + let logits = &(&logits / temperature)?; + let prs = candle_nn::ops::softmax_last_dim(logits)?; + let mut prs: Vec<f32> = prs.to_vec1()?; + let top_p = self.top_p.unwrap_or(1.); + if top_p <= 0.0 || top_p >= 1.0 { + // simply sample from the predicted probability distribution + self.sample_multinomial(&prs)? + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + self.sample_topp(&mut prs, top_p as f32)? 
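The nucleus-sampling step above reads more easily outside the tensor plumbing; here is the same logic as a plain-Python sketch over a probability list (the probabilities are illustrative):

```python
import random

def sample_topp(prs, top_p):
    # Visit tokens in order of decreasing probability; once the running mass
    # reaches top_p, clamp every remaining probability to zero, then sample
    # from whatever survives.
    order = sorted(range(len(prs)), key=lambda i: prs[i], reverse=True)
    cumsum = 0.0
    for i in order:
        if cumsum >= top_p:
            prs[i] = 0.0
        else:
            cumsum += prs[i]
    return random.choices(range(len(prs)), weights=prs, k=1)[0]

print(sample_topp([0.5, 0.3, 0.1, 0.1], top_p=0.8))  # only tokens 0 and 1 survive
```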
+ } + } }; Ok(next_token) } diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index a8890dc8..b83e5056 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,4 +1,5 @@ pub mod generation; pub mod models; +pub mod object_detection; pub mod pipelines; pub mod utils; diff --git a/candle-examples/examples/bert/model.rs b/candle-transformers/src/models/bert.rs index 3f164a3a..3f164a3a 100644 --- a/candle-examples/examples/bert/model.rs +++ b/candle-transformers/src/models/bert.rs diff --git a/candle-examples/examples/bigcode/model.rs b/candle-transformers/src/models/bigcode.rs index 1e63956b..1e63956b 100644 --- a/candle-examples/examples/bigcode/model.rs +++ b/candle-transformers/src/models/bigcode.rs diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs new file mode 100644 index 00000000..0edc8494 --- /dev/null +++ b/candle-transformers/src/models/dinov2.rs @@ -0,0 +1,279 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 1000; + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result<Self> { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? 
* self.scale)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result<Self> { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result<Self> { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. 
+ xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + pos_embed: Tensor, + blocks: Vec<Block>, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let num_tokens = 1; + let pos_embed = vb.get( + (1, patch_embed.num_patches + num_tokens, embed_dim), + "pos_embed", + )?; + let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + patch_embed, + cls_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(xs.clone()); + } + let class_pos_embed = self.pos_embed.i((.., ..1))?; + let patch_pos_embed = self.pos_embed.i((.., 1..))?; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + let patch_pos_embed = + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim))?; + Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _nc, w, h) = xs.dims4()?; + let xs = self.patch_embed.forward(xs)?; + let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; + &xs + &self.interpolate_pos_encoding(&xs, w, h)? + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; + let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; + self.head.forward(&xs) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { + DinoVisionTransformer::new(vb, 12, 384, 6) +} diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs new file mode 100644 index 00000000..ab51c76d --- /dev/null +++ b/candle-transformers/src/models/efficientnet.rs @@ -0,0 +1,331 @@ +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use nn::{Module, VarBuilder}; + +// Based on the Python version from torchvision. 
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 +#[derive(Debug, Clone, Copy)] +pub struct MBConvConfig { + expand_ratio: f64, + kernel: usize, + stride: usize, + input_channels: usize, + out_channels: usize, + num_layers: usize, +} + +fn make_divisible(v: f64, divisor: usize) -> usize { + let min_value = divisor; + let new_v = usize::max( + min_value, + (v + divisor as f64 * 0.5) as usize / divisor * divisor, + ); + if (new_v as f64) < 0.9 * v { + new_v + divisor + } else { + new_v + } +} + +fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { + let bneck_conf = |e, k, s, i, o, n| { + let input_channels = make_divisible(i as f64 * width_mult, 8); + let out_channels = make_divisible(o as f64 * width_mult, 8); + let num_layers = (n as f64 * depth_mult).ceil() as usize; + MBConvConfig { + expand_ratio: e, + kernel: k, + stride: s, + input_channels, + out_channels, + num_layers, + } + }; + vec![ + bneck_conf(1., 3, 1, 32, 16, 1), + bneck_conf(6., 3, 2, 16, 24, 2), + bneck_conf(6., 5, 2, 24, 40, 2), + bneck_conf(6., 3, 2, 40, 80, 3), + bneck_conf(6., 5, 1, 80, 112, 3), + bneck_conf(6., 5, 2, 112, 192, 4), + bneck_conf(6., 3, 1, 192, 320, 1), + ] +} + +impl MBConvConfig { + pub fn b0() -> Vec<Self> { + bneck_confs(1.0, 1.0) + } + pub fn b1() -> Vec<Self> { + bneck_confs(1.0, 1.1) + } + pub fn b2() -> Vec<Self> { + bneck_confs(1.1, 1.2) + } + pub fn b3() -> Vec<Self> { + bneck_confs(1.2, 1.4) + } + pub fn b4() -> Vec<Self> { + bneck_confs(1.4, 1.8) + } + pub fn b5() -> Vec<Self> { + bneck_confs(1.6, 2.2) + } + pub fn b6() -> Vec<Self> { + bneck_confs(1.8, 2.6) + } + pub fn b7() -> Vec<Self> { + bneck_confs(2.0, 3.1) + } +} + +/// Conv2D with same padding. +#[derive(Debug)] +struct Conv2DSame { + conv2d: nn::Conv2d, + s: usize, + k: usize, +} + +impl Conv2DSame { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + bias: bool, + ) -> Result<Self> { + let conv_config = nn::Conv2dConfig { + stride, + groups, + ..Default::default() + }; + let conv2d = if bias { + nn::conv2d(i, o, k, conv_config, vb)? + } else { + nn::conv2d_no_bias(i, o, k, conv_config, vb)? 
+ }; + Ok(Self { + conv2d, + s: stride, + k, + }) + } +} + +impl Module for Conv2DSame { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let s = self.s; + let k = self.k; + let (_, _, ih, iw) = xs.dims4()?; + let oh = (ih + s - 1) / s; + let ow = (iw + s - 1) / s; + let pad_h = usize::max((oh - 1) * s + k - ih, 0); + let pad_w = usize::max((ow - 1) * s + k - iw, 0); + if pad_h > 0 || pad_w > 0 { + let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; + let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; + self.conv2d.forward(&xs) + } else { + self.conv2d.forward(xs) + } + } +} + +#[derive(Debug)] +struct ConvNormActivation { + conv2d: Conv2DSame, + bn2d: nn::BatchNorm, + activation: bool, +} + +impl ConvNormActivation { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; + let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; + Ok(Self { + conv2d, + bn2d, + activation: true, + }) + } + + fn no_activation(self) -> Self { + Self { + activation: false, + ..self + } + } +} + +impl Module for ConvNormActivation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.conv2d.forward(xs)?; + let xs = self.bn2d.forward(&xs)?; + if self.activation { + swish(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +struct SqueezeExcitation { + fc1: Conv2DSame, + fc2: Conv2DSame, +} + +impl SqueezeExcitation { + fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { + let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; + let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for SqueezeExcitation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + // equivalent to adaptive_avg_pool2d([1, 1]) + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = self.fc1.forward(&xs)?; + let xs = swish(&xs)?; + let xs = self.fc2.forward(&xs)?; + let xs = nn::ops::sigmoid(&xs)?; + residual.broadcast_mul(&xs) + } +} + +#[derive(Debug)] +struct MBConv { + expand_cna: Option<ConvNormActivation>, + depthwise_cna: ConvNormActivation, + squeeze_excitation: SqueezeExcitation, + project_cna: ConvNormActivation, + config: MBConvConfig, +} + +impl MBConv { + fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { + let vb = vb.pp("block"); + let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); + let expand_cna = if exp != c.input_channels { + Some(ConvNormActivation::new( + vb.pp("0"), + c.input_channels, + exp, + 1, + 1, + 1, + )?) + } else { + None + }; + let start_index = if expand_cna.is_some() { 1 } else { 0 }; + let depthwise_cna = + ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; + let squeeze_channels = usize::max(1, c.input_channels / 4); + let squeeze_excitation = + SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; + let project_cna = + ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? 
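+ // The projection stage is deliberately linear: `no_activation` drops the
+ // swish, so only the expansion and depthwise convolutions above it are
+ // followed by a non-linearity.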
+ .no_activation(); + Ok(Self { + expand_cna, + depthwise_cna, + squeeze_excitation, + project_cna, + config: c, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let use_res_connect = + self.config.stride == 1 && self.config.input_channels == self.config.out_channels; + let ys = match &self.expand_cna { + Some(expand_cna) => expand_cna.forward(xs)?, + None => xs.clone(), + }; + let ys = self.depthwise_cna.forward(&ys)?; + let ys = self.squeeze_excitation.forward(&ys)?; + let ys = self.project_cna.forward(&ys)?; + if use_res_connect { + ys + xs + } else { + Ok(ys) + } + } +} + +fn swish(s: &Tensor) -> Result<Tensor> { + s * nn::ops::sigmoid(s)? +} + +#[derive(Debug)] +pub struct EfficientNet { + init_cna: ConvNormActivation, + blocks: Vec<MBConv>, + final_cna: ConvNormActivation, + classifier: nn::Linear, +} + +impl EfficientNet { + pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { + let f_p = p.pp("features"); + let first_in_c = configs[0].input_channels; + let last_out_c = configs.last().unwrap().out_channels; + let final_out_c = 4 * last_out_c; + let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; + let nconfigs = configs.len(); + let mut blocks = vec![]; + for (index, cnf) in configs.into_iter().enumerate() { + let f_p = f_p.pp(index + 1); + for r_index in 0..cnf.num_layers { + let cnf = if r_index == 0 { + cnf + } else { + MBConvConfig { + input_channels: cnf.out_channels, + stride: 1, + ..cnf + } + }; + blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) + } + } + let final_cna = + ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; + let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; + Ok(Self { + init_cna, + blocks, + final_cna, + classifier, + }) + } +} + +impl Module for EfficientNet { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.init_cna.forward(xs)?; + for block in self.blocks.iter() { + xs = block.forward(&xs)? 
+ } + let xs = self.final_cna.forward(&xs)?; + // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) + let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; + self.classifier.forward(&xs) + } +} diff --git a/candle-examples/examples/falcon/model.rs b/candle-transformers/src/models/falcon.rs index b638dd51..6ede136a 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,5 +1,4 @@ -use anyhow::Result; -use candle::{DType, Device, Tensor, D}; +use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; @@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { (weight, bias) } else { - return Err(err.into()); + return Err(err); } } }; @@ -82,13 +81,13 @@ impl Default for Config { impl Config { pub fn validate(&self) -> Result<()> { if self.alibi { - anyhow::bail!("alibi is not supported"); + candle::bail!("alibi is not supported"); } if self.new_decoder_architecture { - anyhow::bail!("new_decoder_architecture is not supported"); + candle::bail!("new_decoder_architecture is not supported"); } if self.n_head_kv.is_some() { - anyhow::bail!("n_head_kv is not supported"); + candle::bail!("n_head_kv is not supported"); } Ok(()) } diff --git a/candle-examples/examples/llama/model.rs b/candle-transformers/src/models/llama.rs index 275856e0..eed4df5e 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-transformers/src/models/llama.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use super::MAX_SEQ_LEN; +pub const MAX_SEQ_LEN: usize = 4096; #[derive(Deserialize)] pub struct LlamaConfig { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8b137891..d783a2c6 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1 +1,13 @@ - +pub mod bert; +pub mod bigcode; +pub mod dinov2; +pub mod efficientnet; +pub mod falcon; +pub mod llama; +pub mod quantized_llama; +pub mod quantized_t5; +pub mod segment_anything; +pub mod stable_diffusion; +pub mod t5; +pub mod whisper; +pub mod wuerstchen; diff --git a/candle-examples/examples/quantized/model.rs b/candle-transformers/src/models/quantized_llama.rs index da0bd0b0..2988b0fb 100644 --- a/candle-examples/examples/quantized/model.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -144,7 +144,7 @@ impl LayerWeights { let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let mask = mask.broadcast_as(att.shape())?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; + let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
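 // (the attention weights are contiguous after the softmax; v typically is
 // not, e.g. after an upstream transpose, hence the explicit .contiguous())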
let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs new file mode 100644 index 00000000..a10c3b80 --- /dev/null +++ b/candle-transformers/src/models/quantized_t5.rs @@ -0,0 +1,884 @@ +// T5 Text Model, quantized version +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + +use candle::quantized::QTensor; +use candle::{DType, Device, Module, Result, Shape, Tensor, D}; +use candle_nn::Activation; +use serde::Deserialize; +use std::sync::Arc; + +// VarBuilder specialized for QTensors +pub struct VarBuilder { + data: Arc<std::collections::HashMap<String, Arc<QTensor>>>, + path: Vec<String>, + device: Device, +} + +impl VarBuilder { + pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + let mut file = std::fs::File::open(p)?; + let content = candle::quantized::gguf_file::Content::read(&mut file)?; + let mut data = std::collections::HashMap::new(); + for tensor_name in content.tensor_infos.keys() { + let tensor = content.tensor(&mut file, tensor_name)?; + data.insert(tensor_name.to_string(), Arc::new(tensor)); + } + Ok(Self { + data: Arc::new(data), + path: Vec::new(), + device: Device::Cpu, + }) + } + + fn pp<S: ToString>(&self, s: S) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + device: self.device.clone(), + } + } + + fn path(&self, tensor_name: &str) -> String { + if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + } + } + + fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> { + let path = self.path(name); + match self.data.get(&path) { + None => { + candle::bail!("cannot find tensor {name}") + } + Some(qtensor) => { + let shape = s.into(); + if qtensor.shape() != &shape { + candle::bail!( + "shape mismatch for {name}, got {:?}, expected {shape:?}", + qtensor.shape() + ) + } + Ok(qtensor.clone()) + } + } + } +} + +#[derive(Debug)] +struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?; + let inner = candle_nn::Embedding::new(embeddings, d2); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +// QMatMul wrapper adding some tracing. 
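+// The weights stay in their quantized gguf representation; dequantization
+// happens on the fly inside the quantized matmul kernel, which keeps the
+// memory footprint close to the size of the gguf file itself.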
+struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> { + let ws = vb.get((in_dim, out_dim), "weight")?; + let inner = candle::quantized::QMatMul::from_arc(ws); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } +} + +impl Module for QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +impl std::fmt::Debug for QMatMul { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QMatMul") + } +} + +fn default_relative_attention_max_distance() -> usize { + 128 +} + +fn default_is_decoder() -> bool { + false +} + +fn default_use_cache() -> bool { + true +} + +fn default_tie_word_embeddings() -> bool { + true +} + +fn get_mask(size: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + d_model: usize, + d_kv: usize, + d_ff: usize, + num_layers: usize, + num_decoder_layers: Option<usize>, + num_heads: usize, + relative_attention_num_buckets: usize, + #[serde(default = "default_relative_attention_max_distance")] + relative_attention_max_distance: usize, + dropout_rate: f64, + layer_norm_epsilon: f64, + initializer_factor: f64, + #[serde(default)] + feed_forward_proj: Activation, + #[serde(default = "default_tie_word_embeddings")] + tie_word_embeddings: bool, + #[serde(default = "default_is_decoder")] + is_decoder: bool, + is_encoder_decoder: bool, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 32128, + d_model: 512, + d_kv: 64, + d_ff: 2048, + num_layers: 6, + num_decoder_layers: None, + num_heads: 8, + relative_attention_num_buckets: 32, + relative_attention_max_distance: 128, + dropout_rate: 0.1, + layer_norm_epsilon: 1e-6, + initializer_factor: 1.0, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + is_decoder: false, + is_encoder_decoder: true, + use_cache: true, + pad_token_id: 0, + eos_token_id: 1, + } + } +} + +#[derive(Debug)] +struct T5LayerNorm { + weight: Tensor, + variance_epsilon: f64, + span: tracing::Span, +} + +impl T5LayerNorm { + fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(h, "weight")?.dequantize(&vb.device)?; + Ok(Self { + weight, + variance_epsilon: eps, + span: tracing::span!(tracing::Level::TRACE, "layer-norm"), + }) + } +} + +impl Module for T5LayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let dtype = xs.dtype(); + let xs_f32 = xs.to_dtype(DType::F32)?; + // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; + let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs.to_dtype(dtype)?; + let xs = xs.broadcast_mul(&self.weight)?; + Ok(xs) + } +} + 
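The normalization above is T5's RMS-style layer norm: each row is scaled by its root-mean-square, with no mean subtraction and no bias term. A small standalone sketch of the arithmetic using the same candle ops (illustrative only; the function name is made up):

```rust
use candle::{Device, Result, Tensor, D};

fn rms_norm_sketch() -> Result<()> {
    let xs = Tensor::new(&[[3f32, 4f32]], &Device::Cpu)?;
    // mean(x^2) = (9 + 16) / 2 = 12.5, so the row gets divided by sqrt(12.5).
    let variance = xs.sqr()?.mean_keepdim(D::Minus1)?;
    let ys = xs.broadcast_div(&(variance + 1e-6)?.sqrt()?)?;
    // ys ~= [[0.8485, 1.1314]]: the row now has unit RMS, but its mean is left
    // untouched, unlike a standard LayerNorm.
    println!("{ys}");
    Ok(())
}
```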
+#[derive(Debug)] +struct T5DenseActDense { + wi: QMatMul, + wo: QMatMul, + act: Activation, + span: tracing::Span, +} + +impl T5DenseActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi, + wo, + act: Activation::Relu, + span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"), + }) + } +} + +impl Module for T5DenseActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.wi.forward(xs)?; + let xs = self.act.forward(&xs)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseGatedActDense { + wi_0: QMatMul, + wi_1: QMatMul, + wo: QMatMul, + act: Activation, + span: tracing::Span, +} + +impl T5DenseGatedActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; + let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; + let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi_0, + wi_1, + wo, + act: Activation::NewGelu, + span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), + }) + } +} + +impl Module for T5DenseGatedActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?; + let hidden_linear = self.wi_1.forward(xs)?; + let xs = hidden_gelu.broadcast_mul(&hidden_linear)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5LayerFF { + dense_act: Option<T5DenseActDense>, + gated_dense_act: Option<T5DenseGatedActDense>, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerFF { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu { + ( + None, + Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), + ) + } else { + ( + Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?), + None, + ) + }; + Ok(Self { + dense_act, + gated_dense_act, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer-ff"), + }) + } +} + +impl Module for T5LayerFF { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let ys = self.layer_norm.forward(xs)?; + let ys = match &self.dense_act { + Some(dense_act) => dense_act.forward(&ys)?, + None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?, + }; + let xs = (xs + ys)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5Attention { + q: QMatMul, + k: QMatMul, + v: QMatMul, + o: QMatMul, + n_heads: usize, + d_kv: usize, + relative_attention_bias: Option<Embedding>, + relative_attention_num_buckets: usize, + relative_attention_max_distance: usize, + inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_cache: tracing::Span, + span_mm: tracing::Span, + span_sm: tracing::Span, +} + +impl T5Attention { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let inner_dim = cfg.num_heads * cfg.d_kv; + let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = 
QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?; + let relative_attention_bias = if has_relative_attention_bias { + let emb = Embedding::new( + cfg.relative_attention_num_buckets, + cfg.num_heads, + vb.pp("relative_attention_bias"), + )?; + Some(emb) + } else { + None + }; + Ok(Self { + q, + k, + v, + o, + n_heads: cfg.num_heads, + d_kv: cfg.d_kv, + relative_attention_bias, + relative_attention_num_buckets: cfg.relative_attention_num_buckets, + relative_attention_max_distance: cfg.relative_attention_max_distance, + inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "attention"), + span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"), + span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"), + span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + // Performs Self-attention (if key_value_states is None) or attention + // over source sentence (provided by key_value_states). + let _enter = self.span.enter(); + let kv_input = match key_value_states { + None => xs, + Some(key_value_states) => key_value_states, + }; + let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?); + let kv_len = kv_input.dim(1)?; + let q = self.q.forward(xs)?; + let k = self.k.forward(kv_input)?; + let v = self.v.forward(kv_input)?; + let q = q + .reshape((b_sz, q_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut k = k + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + + if self.use_cache { + let _enter = self.span_cache.enter(); + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; + // TODO: Use flash_attn. + let scores = { + let _enter = self.span_mm.enter(); + q.matmul(&k.t()?)? + }; + let scores = match mask { + None => scores, + Some(mask) => masked_fill( + &scores, + &mask + .unsqueeze(0)? + .unsqueeze(0)? + .repeat((b_sz, self.n_heads))?, + f32::NEG_INFINITY, + )?, + }; + + let (scores, position_bias) = match position_bias { + Some(position_bias) => ( + scores.broadcast_add(position_bias)?, + Some(position_bias.clone()), + ), + None => match &self.relative_attention_bias { + None => (scores, None), + Some(relative_attention_bias) => { + // This only handles the bidirectional case. 
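+ // Buckets are split evenly between j > i and j <= i; within each half,
+ // distances below max_exact keep their own bucket and larger distances
+ // are binned logarithmically up to relative_attention_max_distance.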
+ let kv_len = k.dim(2)?; + let (q_start, q_end) = match self.use_cache { + true => ((kv_len - q_len) as u32, kv_len as u32), + false => (0_u32, kv_len as u32), + }; + let num_buckets = self.relative_attention_num_buckets as u32 / 2; + let max_exact = num_buckets / 2; + let relative_position = (q_start..q_end) + .map(|i| { + (0..kv_len as u32) + .map(|j| { + if i < j { + if j - i < max_exact { + j - i + num_buckets + } else { + let b = f32::log( + (j - i) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + u32::min( + max_exact + num_buckets + b as u32, + self.relative_attention_num_buckets as u32 - 1, + ) + } + } else if i - j < max_exact { + i - j + } else { + let b = f32::log( + (i - j) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + max_exact + b as u32 + } + }) + .collect::<Vec<u32>>() + }) + .collect::<Vec<Vec<_>>>(); + let relative_buckets = Tensor::new(relative_position, q.device())?; + let position_bias = relative_attention_bias + .forward(&relative_buckets)? + .permute((2, 0, 1))? + .unsqueeze(0)?; + (scores.broadcast_add(&position_bias)?, Some(position_bias)) + // TODO: position_bias_masked? + } + }, + }; + + let attn_weights = { + let _enter = self.span_sm.enter(); + candle_nn::ops::softmax(&scores, D::Minus1)? + }; + let attn_output = attn_weights.matmul(&v)?; + let attn_output = attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.inner_dim))?; + let attn_output = self.o.forward(&attn_output)?; + Ok((attn_output, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Debug)] +struct T5LayerSelfAttention { + self_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerSelfAttention { + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + self_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_xs = self.layer_norm.forward(xs)?; + let (ys, position_bias) = + self.self_attention + .forward(&normed_xs, position_bias, None, mask)?; + let ys = (xs + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5LayerCrossAttention { + cross_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerCrossAttention { + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + cross_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "cross-attn"), + }) + } + + fn forward( + &mut self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: &Tensor, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_hidden_states = self.layer_norm.forward(hidden_states)?; + let (ys, position_bias) = 
self.cross_attention.forward( + &normed_hidden_states, + position_bias, + Some(key_value_states), + None, + )?; + let ys = (hidden_states + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.cross_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5Block { + self_attn: T5LayerSelfAttention, + cross_attn: Option<T5LayerCrossAttention>, + ff: T5LayerFF, + span: tracing::Span, +} + +impl T5Block { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let vb = vb.pp("layer"); + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; + let cross_attn = if cfg.is_decoder { + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) + } else { + None + }; + let ff_i = if cross_attn.is_some() { 2 } else { 1 }; + let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?; + Ok(Self { + self_attn, + cross_attn, + ff, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + // TODO: Cache masks + let mask = match self.cross_attn.is_some() { + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. + if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } + false => None, + }; + let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; + // TODO: clamp for f16? + if let Some(cross_attn) = &mut self.cross_attn { + (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; + // TODO: clamp for f16? + } + let xs = self.ff.forward(&xs)?; + // TODO: clamp for f16? + Ok((xs, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache()); + } +} + +#[derive(Debug)] +struct T5Stack { + block: Vec<T5Block>, + shared: Arc<Embedding>, + final_layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5Stack { + fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { + let block = (0..cfg.num_layers) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg)) + .collect::<Result<Vec<_>>>()?; + let final_layer_norm = T5LayerNorm::load( + cfg.d_model, + cfg.layer_norm_epsilon, + vb.pp("final_layer_norm"), + )?; + Ok(Self { + block, + shared: shared.clone(), + final_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "stack"), + }) + } + + fn forward( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let input_embeds = self.shared.as_ref().forward(input_ids)?; + let mut hidden_states = input_embeds; + let mut position_bias = None; + for block in self.block.iter_mut() { + (hidden_states, position_bias) = block.forward( + &hidden_states, + position_bias.as_ref(), + encoder_hidden_states, + )? 
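+ // Only block 0 owns a relative_attention_bias table; the position_bias
+ // it returns on the first iteration is fed back in and reused by all of
+ // the later blocks.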
+ } + self.final_layer_norm.forward(&hidden_states) + } + + fn clear_kv_cache(&mut self) { + self.block.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} + +#[derive(Debug)] +pub struct T5EncoderModel { + encoder: T5Stack, + device: Device, + span: tracing::Span, +} + +impl T5EncoderModel { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; + Ok(Self { + encoder, + device: vb.device.clone(), + span: tracing::span!(tracing::Level::TRACE, "encoder"), + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.encoder.forward(input_ids, None) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache() + } +} + +#[derive(Debug)] +pub struct T5ForConditionalGeneration { + encoder: T5Stack, + decoder: T5Stack, + d_model: usize, + tie_word_embeddings: bool, + lm_head: Option<QMatMul>, + shared: Arc<Embedding>, + device: Device, + span_decode: tracing::Span, + span_decode_head: tracing::Span, +} + +impl T5ForConditionalGeneration { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + assert!(cfg.is_encoder_decoder); + let d_model = cfg.d_model; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + + let mut encoder_cfg = cfg.clone(); + encoder_cfg.is_decoder = false; + encoder_cfg.use_cache = false; + encoder_cfg.is_encoder_decoder = false; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?; + + let mut decoder_cfg = cfg.clone(); + decoder_cfg.is_decoder = true; + decoder_cfg.is_encoder_decoder = false; + decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers); + let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?; + + let tie_word_embeddings = cfg.tie_word_embeddings; + let lm_head = if tie_word_embeddings { + None + } else { + Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?) + }; + + Ok(Self { + encoder, + decoder, + d_model, + tie_word_embeddings, + lm_head, + shared, + device: vb.device.clone(), + span_decode: tracing::span!(tracing::Level::TRACE, "decode"), + span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"), + }) + } + + pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> { + self.encoder.forward(input_ids, None) + } + + pub fn decode( + &mut self, + decoder_input_ids: &Tensor, + encoder_output: &Tensor, + ) -> Result<Tensor> { + let _enter = self.span_decode.enter(); + let decoder_output = self + .decoder + .forward(decoder_input_ids, Some(encoder_output))?; + + let scaling_factor = if self.tie_word_embeddings { + // Rescale output before projecting on vocab + // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + (self.d_model as f64).sqrt() + } else { + 1.0 + }; + let sequence_output = ((decoder_output + .narrow(1, decoder_output.dim(1)? - 1, 1)? + .squeeze(1)?) + * scaling_factor)?; + let output = { + let _enter = self.span_decode_head.enter(); + match self.lm_head { + None => sequence_output.matmul(&self.shared.embeddings().t()?)?, + Some(ref lm_head) => lm_head.forward(&sequence_output)?, + } + }; + + // TODO: Rescale output before projecting on vocab? 
* (self.model_dim**-0.5) + Ok(output) + } + + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + let encoder_output = self.encode(input_ids)?; + self.decode(decoder_input_ids, &encoder_output) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache(); + self.decoder.clear_kv_cache(); + } +} diff --git a/candle-transformers/src/models/segment_anything/image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs new file mode 100644 index 00000000..0b313830 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/image_encoder.rs @@ -0,0 +1,483 @@ +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + span: tracing::Span, +} + +impl PatchEmbed { + fn new( + in_chans: usize, + embed_dim: usize, + k_size: usize, + stride: usize, + padding: usize, + vb: VarBuilder, + ) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + stride, + padding, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { proj, span }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.proj)?.permute((0, 2, 3, 1)) + } +} + +// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final +// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096 +// (attn.reshape((b, q_h, q_w, k_h, k_w))? +// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? +// .reshape((b, q_h * q_w, k_h * k_w)) +// Ideally we would perform this operation in place but this is not supported in candle at the +// moment. We should also investigate using f16 rather than f32. 
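+// The CPU kernel below walks the output in parallel chunks of k_h * k_w
+// elements: for each (batch, query) position it adds rel_h (broadcast along
+// the k_w axis) and rel_w (broadcast along the k_h axis) directly into the
+// attention scores, so the 5-d broadcast intermediate is never materialized.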
+struct Add3(usize, usize, usize, usize, usize); +impl candle::CustomOp3 for Add3 { + fn name(&self) -> &'static str { + "add3" + } + + fn cpu_fwd( + &self, + s1: &candle::CpuStorage, + l1: &candle::Layout, + s2: &candle::CpuStorage, + l2: &candle::Layout, + s3: &candle::CpuStorage, + l3: &candle::Layout, + ) -> Result<(candle::CpuStorage, candle::Shape)> { + use rayon::prelude::*; + + let Add3(b, q_h, q_w, k_h, k_w) = *self; + let s1 = s1.as_slice::<f32>()?; + let s1 = match l1.contiguous_offsets() { + None => candle::bail!("input1 has to be contiguous"), + Some((o1, o2)) => &s1[o1..o2], + }; + let s2 = s2.as_slice::<f32>()?; + let s2 = match l2.contiguous_offsets() { + None => candle::bail!("input2 has to be contiguous"), + Some((o1, o2)) => &s2[o1..o2], + }; + let s3 = s3.as_slice::<f32>()?; + let s3 = match l3.contiguous_offsets() { + None => candle::bail!("input3 has to be contiguous"), + Some((o1, o2)) => &s3[o1..o2], + }; + let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w]; + dst.par_chunks_exact_mut(k_h * k_w) + .enumerate() + .for_each(|(b_idx, dst)| { + let s1_idx = b_idx * k_h * k_w; + let s2_idx = b_idx * k_h; + let s3_idx = b_idx * k_w; + for h_idx in 0..k_h { + let s1_idx = s1_idx + h_idx * k_w; + let s2_idx = s2_idx + h_idx; + let dst_idx = h_idx * k_w; + for w_idx in 0..k_w { + let s1_idx = s1_idx + w_idx; + let s3_idx = s3_idx + w_idx; + let dst_idx = dst_idx + w_idx; + dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx] + } + } + }); + let dst = candle::WithDType::to_cpu_storage_owned(dst); + Ok((dst, (b, q_h * q_w, k_h * k_w).into())) + } +} + +#[derive(Debug)] +struct Attention { + qkv: super::Linear, + proj: super::Linear, + num_heads: usize, + scale: f64, + rel_pos_hw: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_matmul: tracing::Span, + span_rel_pos: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = super::linear(vb.pp("proj"), dim, dim, true)?; + let head_dim = dim / num_heads; + let scale = 1. / (head_dim as f64).sqrt(); + let rel_pos_hw = if use_rel_pos { + let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?; + let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?; + Some((h, w)) + } else { + None + }; + Ok(Self { + qkv, + proj, + num_heads, + scale, + rel_pos_hw, + span, + span_matmul, + span_rel_pos, + span_softmax, + }) + } + + fn add_decomposed_rel_pos( + &self, + attn: Tensor, + q: &Tensor, + (q_h, q_w): (usize, usize), + (k_h, k_w): (usize, usize), + ) -> Result<Tensor> { + match &self.rel_pos_hw { + Some((rel_pos_h, rel_pos_w)) => { + let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; + let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; + let (b, _, dim) = q.dims3()?; + let r_q = q.reshape((b, q_h, q_w, dim))?; + // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; + // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + let rel_w = r_q + .transpose(1, 2)? // -> bwhc + .contiguous()? + .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? 
// bwhc,bwck -> bwhk + .transpose(1, 2)? + .contiguous()?; + if attn.device().is_cpu() { + let op = Add3(b, q_h, q_w, k_h, k_w); + attn.apply_op3_no_bwd(&rel_h, &rel_w, &op) + } else { + (attn.reshape((b, q_h, q_w, k_h, k_w))? + + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? + .reshape((b, q_h * q_w, k_h * k_w)) + } + } + None => Ok(attn), + } + } +} + +fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> { + let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; + let dev = rel_pos.device(); + let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { + todo!("interpolation") + } else { + rel_pos + }; + let q_coords = Tensor::arange(0u32, q_size as u32, dev)? + .reshape((q_size, 1))? + .to_dtype(DType::F32)?; + let k_coords = Tensor::arange(0u32, k_size as u32, dev)? + .reshape((1, k_size))? + .to_dtype(DType::F32)?; + let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; + let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; + let relative_coords = (q_coords.broadcast_sub(&k_coords)? + + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; + let (d1, d2) = relative_coords.dims2()?; + let relative_coords = relative_coords.to_dtype(DType::U32)?; + rel_pos_resized + .index_select(&relative_coords.reshape(d1 * d2)?, 0)? + .reshape((d1, d2, ())) +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, h, w, c) = xs.dims4()?; + let qkv = self + .qkv + .forward(&xs.flatten_to(1)?)? + .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? + .permute((2, 0, 3, 1, 4))? + .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; + let q = qkv.i(0)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = { + let _enter = self.span_matmul.enter(); + (&q * self.scale)?.matmul(&k.t()?)? + }; + let attn = { + let _enter = self.span_rel_pos.enter(); + self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))? + }; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + let attn = attn + .reshape((b, self.num_heads, h, w, c / self.num_heads))? + .permute((0, 2, 3, 1, 4))? 
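+ // (b, heads, h, w, c/heads) -> (b, h, w, heads, c/heads); the reshape
+ // below then merges the heads back into the channel dimension.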
+ .reshape((b, h * w, c))?; + self.proj.forward(&attn)?.reshape((b, h, w, c)) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + window_size: usize, + span: tracing::Span, +} + +impl Block { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + window_size: usize, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let input_size_attn = if window_size == 0 { + input_size + } else { + (window_size, window_size) + }; + let attn = Attention::new( + dim, + num_heads, + qkv_bias, + use_rel_pos, + input_size_attn, + vb.pp("attn"), + )?; + let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; + let span = tracing::span!(tracing::Level::TRACE, "ie-block"); + Ok(Self { + norm1, + attn, + norm2, + mlp, + window_size, + span, + }) + } +} + +fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { + let (b, h, w, c) = xs.dims4()?; + let pad_h = (window_size - h % window_size) % window_size; + let pad_w = (window_size - w % window_size) % window_size; + let xs = if pad_h > 0 { + xs.pad_with_zeros(1, 0, pad_h)? + } else { + xs + }; + let xs = if pad_w > 0 { + xs.pad_with_zeros(2, 0, pad_w)? + } else { + xs + }; + let (h_p, w_p) = (h + pad_h, w + pad_w); + let windows = xs + .reshape(( + b, + h_p / window_size, + window_size, + w_p / window_size, + window_size, + c, + ))? + .transpose(2, 3)? + .contiguous()? + .flatten_to(2)?; + Ok((windows, (h_p, w_p))) +} + +fn window_unpartition( + windows: Tensor, + window_size: usize, + (h_p, w_p): (usize, usize), + (h, w): (usize, usize), +) -> Result<Tensor> { + let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); + let xs = windows + .reshape(( + b, + h_p / window_size, + w_p / window_size, + window_size, + window_size, + windows.elem_count() / b / h_p / w_p, + ))? + .transpose(2, 3)? + .contiguous()? + .reshape((b, h_p, w_p, ()))?; + let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; + let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; + Ok(xs) +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let shortcut = xs; + let xs = self.norm1.forward(xs)?; + let hw = (xs.dim(1)?, xs.dim(2)?); + let (xs, pad_hw) = if self.window_size > 0 { + window_partition(xs, self.window_size)? + } else { + (xs, (0, 0)) + }; + let xs = self.attn.forward(&xs)?; + let xs = if self.window_size > 0 { + window_unpartition(xs, self.window_size, pad_hw, hw)? + } else { + xs + }; + let xs = (xs + shortcut)?; + &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? 
+ } +} + +#[derive(Debug)] +pub struct ImageEncoderViT { + patch_embed: PatchEmbed, + blocks: Vec<Block>, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + pos_embed: Option<Tensor>, + span: tracing::Span, +} + +impl ImageEncoderViT { + #[allow(clippy::too_many_arguments)] + pub fn new( + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + depth: usize, + num_heads: usize, + out_chans: usize, + qkv_bias: bool, + use_rel_pos: bool, + use_abs_pos: bool, + window_size: usize, + global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new( + in_chans, + embed_dim, + patch_size, + patch_size, + 0, + vb.pp("patch_embed"), + )?; + let mut blocks = Vec::with_capacity(depth); + let vb_b = vb.pp("blocks"); + for i in 0..depth { + let window_size = if global_attn_indexes.contains(&i) { + 0 + } else { + window_size + }; + let block = Block::new( + embed_dim, + num_heads, + qkv_bias, + use_rel_pos, + window_size, + (img_size / patch_size, img_size / patch_size), + vb_b.pp(i), + )?; + blocks.push(block) + } + let neck_conv1 = candle_nn::conv2d_no_bias( + embed_dim, + out_chans, + 1, + Default::default(), + vb.pp("neck.0"), + )?; + let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; + let pos_embed = if use_abs_pos { + let p = vb.get( + (1, img_size / patch_size, img_size / patch_size, embed_dim), + "pos_embed", + )?; + Some(p) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit"); + Ok(Self { + patch_embed, + blocks, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + pos_embed, + span, + }) + } +} + +impl Module for ImageEncoderViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = match &self.pos_embed { + Some(pos_embed) => (xs + pos_embed)?, + None => xs, + }; + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + xs.permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? 
+ .apply(&self.neck_ln2) + } +} diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs new file mode 100644 index 00000000..2a91cd44 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -0,0 +1,239 @@ +use candle::{IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +use super::transformer::TwoWayTransformer; + +#[derive(Debug)] +struct MlpMaskDecoder { + layers: Vec<super::Linear>, + sigmoid_output: bool, + span: tracing::Span, +} + +impl MlpMaskDecoder { + fn new( + input_dim: usize, + hidden_dim: usize, + output_dim: usize, + num_layers: usize, + sigmoid_output: bool, + vb: VarBuilder, + ) -> Result<Self> { + let mut layers = Vec::with_capacity(num_layers); + let vb = vb.pp("layers"); + for i in 0..num_layers { + let in_dim = if i == 0 { input_dim } else { hidden_dim }; + let out_dim = if i + 1 == num_layers { + output_dim + } else { + hidden_dim + }; + let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?; + layers.push(layer) + } + let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); + Ok(Self { + layers, + sigmoid_output, + span, + }) + } +} + +impl Module for MlpMaskDecoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for (i, layer) in self.layers.iter().enumerate() { + xs = layer.forward(&xs)?; + if i + 1 < self.layers.len() { + xs = xs.relu()? + } + } + if self.sigmoid_output { + candle_nn::ops::sigmoid(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +pub struct MaskDecoder { + iou_token: candle_nn::Embedding, + mask_tokens: candle_nn::Embedding, + iou_prediction_head: MlpMaskDecoder, + output_upscaling_conv1: candle_nn::ConvTranspose2d, + output_upscaling_ln: super::LayerNorm2d, + output_upscaling_conv2: candle_nn::ConvTranspose2d, + num_mask_tokens: usize, + output_hypernetworks_mlps: Vec<MlpMaskDecoder>, + transformer: TwoWayTransformer, + span: tracing::Span, +} + +impl MaskDecoder { + pub fn new( + transformer_dim: usize, + num_multimask_outputs: usize, + iou_head_depth: usize, + iou_head_hidden_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let num_mask_tokens = num_multimask_outputs + 1; + let iou_prediction_head = MlpMaskDecoder::new( + transformer_dim, + iou_head_hidden_dim, + num_mask_tokens, + iou_head_depth, + false, + vb.pp("iou_prediction_head"), + )?; + let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?; + let mask_tokens = + candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?; + let cfg = candle_nn::ConvTranspose2dConfig { + stride: 2, + ..Default::default() + }; + let output_upscaling_conv1 = candle_nn::conv_transpose2d( + transformer_dim, + transformer_dim / 4, + 2, + cfg, + vb.pp("output_upscaling.0"), + )?; + let output_upscaling_ln = + super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + let output_upscaling_conv2 = candle_nn::conv_transpose2d( + transformer_dim / 4, + transformer_dim / 8, + 2, + cfg, + vb.pp("output_upscaling.3"), + )?; + let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens); + let vb_o = vb.pp("output_hypernetworks_mlps"); + for i in 0..num_mask_tokens { + let mlp = MlpMaskDecoder::new( + transformer_dim, + transformer_dim, + transformer_dim / 8, + 3, + false, + vb_o.pp(i), + )?; + output_hypernetworks_mlps.push(mlp) + } + let transformer = TwoWayTransformer::new( + /* depth */ 2, + /* embedding_dim 
*/ transformer_dim,
+ /* num_heads */ 8,
+ /* mlp_dim */ 2048,
+ vb.pp("transformer"),
+ )?;
+ let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
+ Ok(Self {
+ iou_token,
+ mask_tokens,
+ iou_prediction_head,
+ output_upscaling_conv1,
+ output_upscaling_ln,
+ output_upscaling_conv2,
+ num_mask_tokens,
+ output_hypernetworks_mlps,
+ transformer,
+ span,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let (masks, iou_pred) = self.predict_masks(
+ image_embeddings,
+ image_pe,
+ sparse_prompt_embeddings,
+ dense_prompt_embeddings,
+ )?;
+ let masks = if multimask_output {
+ masks.i((.., 1..))?
+ } else {
+ masks.i((.., 0..1))?
+ };
+ let iou_pred = if multimask_output {
+ iou_pred.i((.., 1..))?
+ } else {
+ iou_pred.i((.., 0..1))?
+ };
+ Ok((masks, iou_pred))
+ }
+
+ fn predict_masks(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Concatenate output tokens.
+ let output_tokens = Tensor::cat(
+ &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+ 0,
+ )?;
+ let (d1, d2) = output_tokens.dims2()?;
+ let output_tokens =
+ output_tokens
+ .unsqueeze(0)?
+ .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+ let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+ // Expand per-image data in the batch direction to be per-mask.
+ let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+ let src = src.broadcast_add(dense_prompt_embeddings)?;
+ let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+ let (b, c, h, w) = src.dims4()?;
+
+ // Run the transformer
+ let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
+ let iou_token_out = hs.i((.., 0))?;
+ let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
+
+ // Upscale mask embeddings and predict masks using the mask tokens.
+ let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+ let upscaled_embedding = self
+ .output_upscaling_conv1
+ .forward(&src)?
+ .apply(&self.output_upscaling_ln)?
+ .gelu()?
+ .apply(&self.output_upscaling_conv2)?
+ .gelu()?;
+ let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+ for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+ let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+ hyper_in_list.push(h)
+ }
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
+ let (b, c, h, w) = upscaled_embedding.dims4()?;
+ let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
+ let masks = masks.reshape((b, (), h, w))?;
+
+ // Generate mask quality predictions.
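+ // The iou token's output embedding is mapped to one predicted IoU score
+ // per mask token.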
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; + Ok((masks, iou_pred)) + } +} + +// Equivalent to torch.repeat_interleave +fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> { + let img = img.unsqueeze(dim + 1)?; + let mut dims = img.dims().to_vec(); + dims[dim + 1] = repeats; + img.broadcast_as(dims)?.flatten(dim, dim + 1) +} diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs new file mode 100644 index 00000000..c29db70a --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -0,0 +1,100 @@ +use candle::{Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +pub mod image_encoder; +pub mod mask_decoder; +pub mod prompt_encoder; +pub mod sam; +pub mod tiny_vit; +pub mod transformer; + +pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +#[derive(Debug)] +pub struct LayerNorm2d { + weight: Tensor, + bias: Tensor, + num_channels: usize, + eps: f64, +} + +impl LayerNorm2d { + pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(num_channels, "weight")?; + let bias = vb.get(num_channels, "bias")?; + Ok(Self { + weight, + bias, + num_channels, + eps, + }) + } +} + +impl Module for LayerNorm2d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let u = xs.mean_keepdim(1)?; + let xs = xs.broadcast_sub(&u)?; + let s = xs.sqr()?.mean_keepdim(1)?; + let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; + xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? + .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) + } +} + +#[derive(Debug)] +pub struct MlpBlock { + lin1: Linear, + lin2: Linear, + activation: candle_nn::Activation, + span: tracing::Span, +} + +impl MlpBlock { + pub fn new( + embedding_dim: usize, + mlp_dim: usize, + activation: candle_nn::Activation, + vb: VarBuilder, + ) -> Result<Self> { + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); + Ok(Self { + lin1, + lin2, + activation, + span, + }) + } +} + +impl Module for MlpBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.lin1)? + .apply(&self.activation)? 
+ .apply(&self.lin2)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
new file mode 100644
index 00000000..9d0074b1
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -0,0 +1,239 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct PositionEmbeddingRandom {
+ positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PositionEmbeddingRandom {
+ fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+ let positional_encoding_gaussian_matrix =
+ vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+ Ok(Self {
+ positional_encoding_gaussian_matrix,
+ })
+ }
+
+ fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+ let coords = coords.affine(2., -1.)?;
+ let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = (coords * (2. * std::f64::consts::PI))?;
+ Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+ }
+
+ fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+ let device = self.positional_encoding_gaussian_matrix.device();
+ let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let x_embed = (x_embed / w as f64)?
+ .reshape((1, ()))?
+ .broadcast_as((h, w))?;
+ let y_embed = (y_embed / h as f64)?
+ .reshape(((), 1))?
+ .broadcast_as((h, w))?;
+ let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+ self.pe_encoding(&coords)?.permute((2, 0, 1))
+ }
+
+ fn forward_with_coords(
+ &self,
+ coords_input: &Tensor,
+ image_size: (usize, usize),
+ ) -> Result<Tensor> {
+ let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+ let coords1 = (coords_input.narrow(D::Minus1, 1, 1)?
/ image_size.0 as f64)?;
+ let c = coords_input.dim(D::Minus1)?;
+ let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+ let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+ self.pe_encoding(&coords)
+ }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+ pe_layer: PositionEmbeddingRandom,
+ point_embeddings: Vec<candle_nn::Embedding>,
+ not_a_point_embed: candle_nn::Embedding,
+ mask_downscaling_conv1: candle_nn::Conv2d,
+ mask_downscaling_ln1: super::LayerNorm2d,
+ mask_downscaling_conv2: candle_nn::Conv2d,
+ mask_downscaling_ln2: super::LayerNorm2d,
+ mask_downscaling_conv3: candle_nn::Conv2d,
+ no_mask_embed: candle_nn::Embedding,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ embed_dim: usize,
+ span: tracing::Span,
+}
+
+impl PromptEncoder {
+ pub fn new(
+ embed_dim: usize,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ mask_in_chans: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_points_embeddings = 4;
+ let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+ let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let mask_downscaling_conv1 =
+ candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+ let mask_downscaling_conv2 = candle_nn::conv2d(
+ mask_in_chans / 4,
+ mask_in_chans,
+ 2,
+ cfg,
+ vb.pp("mask_downscaling.3"),
+ )?;
+ let mask_downscaling_conv3 = candle_nn::conv2d(
+ mask_in_chans,
+ embed_dim,
+ 1,
+ Default::default(),
+ vb.pp("mask_downscaling.6"),
+ )?;
+ let mask_downscaling_ln1 =
+ super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 =
+ super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+ let vb_e = vb.pp("point_embeddings");
+ for i in 0..num_points_embeddings {
+ let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+ point_embeddings.push(emb)
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
+ Ok(Self {
+ pe_layer,
+ point_embeddings,
+ not_a_point_embed,
+ mask_downscaling_conv1,
+ mask_downscaling_ln1,
+ mask_downscaling_conv2,
+ mask_downscaling_ln2,
+ mask_downscaling_conv3,
+ no_mask_embed,
+ image_embedding_size,
+ input_image_size,
+ embed_dim,
+ span,
+ })
+ }
+
+ pub fn get_dense_pe(&self) -> Result<Tensor> {
+ self.pe_layer
+ .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+ .unsqueeze(0)
+ }
+
+ fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+ masks
+ .apply(&self.mask_downscaling_conv1)?
+ .apply(&self.mask_downscaling_ln1)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv2)?
+ .apply(&self.mask_downscaling_ln2)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv3)
+ }
+
+ fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+ let points = (points + 0.5)?;
+ let dev = points.device();
+ let (points, labels) = if pad {
+ let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+ let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)?
* (-1f64))?; + let points = Tensor::cat(&[&points, &padding_point], 1)?; + let labels = Tensor::cat(&[labels, &padding_label], 1)?; + (points, labels) + } else { + (points, labels.clone()) + }; + let point_embedding = self + .pe_layer + .forward_with_coords(&points, self.input_image_size)?; + let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; + let zeros = point_embedding.zeros_like()?; + let point_embedding = labels.lt(0f32)?.where_cond( + &self + .not_a_point_embed + .embeddings() + .broadcast_as(zeros.shape())?, + &point_embedding, + )?; + let labels0 = labels.eq(0f32)?.where_cond( + &self.point_embeddings[0] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels0)?; + let labels1 = labels.eq(1f32)?.where_cond( + &self.point_embeddings[1] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels1)?; + Ok(point_embedding) + } + + fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> { + let boxes = (boxes + 0.5)?; + let coords = boxes.reshape(((), 2, 2))?; + let corner_embedding = self + .pe_layer + .forward_with_coords(&coords, self.input_image_size)?; + let ce1 = corner_embedding.i((.., 0))?; + let ce2 = corner_embedding.i((.., 1))?; + let ce1 = (ce1 + self.point_embeddings[2].embeddings())?; + let ce2 = (ce2 + self.point_embeddings[3].embeddings())?; + Tensor::cat(&[&ce1, &ce2], 1) + } + + pub fn forward( + &self, + points: Option<(&Tensor, &Tensor)>, + boxes: Option<&Tensor>, + masks: Option<&Tensor>, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + let se_points = match points { + Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?), + None => None, + }; + let se_boxes = match boxes { + Some(boxes) => Some(self.embed_boxes(boxes)?), + None => None, + }; + let sparse_embeddings = match (se_points, se_boxes) { + (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?, + (Some(se_points), None) => se_points, + (None, Some(se_boxes)) => se_boxes, + (None, None) => { + Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? + } + }; + + let dense_embeddings = match masks { + None => { + let emb = self.no_mask_embed.embeddings(); + emb.reshape((1, (), 1, 1))?.expand(( + 1, + emb.elem_count(), + self.image_embedding_size.0, + self.image_embedding_size.1, + ))? 
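+                // Without a mask prompt, the single learned no-mask embedding is
+                // broadcast over the whole (h, w) grid of image embeddings.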
+ } + Some(masks) => self.embed_masks(masks)?, + }; + Ok((sparse_embeddings, dense_embeddings)) + } +} diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs new file mode 100644 index 00000000..07e9a759 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -0,0 +1,433 @@ +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +use super::image_encoder::ImageEncoderViT; +use super::mask_decoder::MaskDecoder; +use super::prompt_encoder::PromptEncoder; +use super::tiny_vit::{tiny_vit_5m, TinyViT}; + +const PROMPT_EMBED_DIM: usize = 256; +pub const IMAGE_SIZE: usize = 1024; +const VIT_PATCH_SIZE: usize = 16; +const PRED_IOU_THRESH: f32 = 0.88; +const STABILITY_SCORE_OFFSET: f32 = 1.0; +const STABILITY_SCORE_THRESHOLD: f32 = 0.95; +const MODEL_MASK_THRESHOLD: f32 = 0.0; +const CROP_NMS_THRESH: f32 = 0.7; + +#[derive(Debug)] +enum ImageEncoder { + Original(ImageEncoderViT), + TinyViT(TinyViT), +} + +impl Module for ImageEncoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::Original(vit) => vit.forward(xs), + Self::TinyViT(vit) => vit.forward(xs), + } + } +} + +#[derive(Debug)] +pub struct Sam { + image_encoder: ImageEncoder, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: Tensor, + pixel_std: Tensor, +} + +impl Sam { + pub fn new( + encoder_embed_dim: usize, + encoder_depth: usize, + encoder_num_heads: usize, + encoder_global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = ImageEncoderViT::new( + IMAGE_SIZE, + VIT_PATCH_SIZE, + 3, + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + PROMPT_EMBED_DIM, + /* qkv_bias */ true, + /* use_rel_pos */ true, + /* use_abs_pos */ true, + /* window_size */ 14, + /* global_attn_indexes */ encoder_global_attn_indexes, + vb.pp("image_encoder"), + )?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder: ImageEncoder::Original(image_encoder), + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) + } + + pub fn new_tiny(vb: VarBuilder) -> Result<Self> { + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder: ImageEncoder::TinyViT(image_encoder), + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) 
+ } + + pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> { + let img = self.preprocess(img)?.unsqueeze(0)?; + self.image_encoder.forward(&img) + } + + pub fn forward( + &self, + img: &Tensor, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let (_c, original_h, original_w) = img.dims3()?; + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let (low_res_mask, iou) = self.forward_for_embeddings( + &img_embeddings, + original_h, + original_w, + point, + multimask_output, + )?; + let mask = low_res_mask + .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? + .get(0)? + .i((.., ..original_h, ..original_w))?; + Ok((mask, iou)) + } + + pub fn forward_for_embeddings( + &self, + img_embeddings: &Tensor, + original_h: usize, + original_w: usize, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = match point { + None => None, + Some((x, y)) => { + let points = Tensor::new( + &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], + img_embeddings.device(), + )?; + let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?; + Some((points, labels)) + } + }; + let points = points.as_ref().map(|(x, y)| (x, y)); + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(points, None, None)?; + self.mask_decoder.forward( + img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + ) + } + + pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { + let img = img + .broadcast_mul(&self.pixel_std)? + .broadcast_add(&self.pixel_mean)?; + img.maximum(&img.zeros_like()?)? + .minimum(&(img.ones_like()? * 255.)?) + } + + pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + let (_c, h, w) = img.dims3()?; + let img = img + .to_dtype(DType::F32)? + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } + + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { + // Crop the image and calculate embeddings. + let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; + let img = self.preprocess(&img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + + let crop_w = cb.x1 - cb.x0; + let crop_h = cb.y1 - cb.y0; + + // Generate masks for this crop. + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = point_grids + .iter() + .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) + .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); + for points in points.chunks(64) { + // Run the model on this batch. 
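+            // Each batch holds up to 64 prompts, each being a single foreground
+            // point (label 1); since no box prompt is given, the prompt encoder
+            // pads every prompt with an extra not-a-point entry (label -1).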
+ let points_len = points.len(); + let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; + let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder + .forward(Some((&in_points, &in_labels)), None, None)?; + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + /* multimask_output */ true, + )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } + + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = crate::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } + } + + let mut bboxes = vec![bboxes]; + // Remove duplicates within this crop. + crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); + + // TODO: Return to the original image frame. 
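+        // Until that is implemented, the boxes and masks below are expressed in
+        // crop-local coordinates.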
+        Ok(bboxes.remove(0))
+    }
+
+    pub fn generate_masks(
+        &self,
+        img: &Tensor,
+        points_per_side: usize,
+        crop_n_layer: usize,
+        crop_overlap_ratio: f64,
+        crop_n_points_downscale_factor: usize,
+    ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
+        let (_c, h, w) = img.dims3()?;
+        let point_grids = build_all_layer_point_grids(
+            points_per_side,
+            crop_n_layer,
+            crop_n_points_downscale_factor,
+        );
+        let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+        let mut bboxes = Vec::new();
+        for crop_box in crop_boxes.into_iter() {
+            let layer_idx = crop_box.layer_idx;
+            let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
+            bboxes.extend(b)
+        }
+        // TODO: remove duplicates
+        Ok(bboxes)
+    }
+}
+
+// Return the first and last indices i for which values[i] > 0
+fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
+    let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
+    for (i, &s) in values.iter().enumerate() {
+        if s == 0 {
+            continue;
+        }
+        min_i = usize::min(i, min_i);
+        max_i = usize::max(i, max_i);
+    }
+    if max_i < min_i {
+        None
+    } else {
+        Some((min_i, max_i))
+    }
+}
+
+#[derive(Debug)]
+struct CropBox {
+    x0: usize,
+    y0: usize,
+    x1: usize,
+    y1: usize,
+    layer_idx: usize,
+}
+
+impl CropBox {
+    fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
+        Self {
+            x0,
+            y0,
+            x1,
+            y1,
+            layer_idx,
+        }
+    }
+}
+
+fn generate_crop_boxes(
+    (im_h, im_w): (usize, usize),
+    n_layers: usize,
+    overlap_ratio: f64,
+) -> Vec<CropBox> {
+    fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
+        f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
+    }
+
+    let short_side = usize::min(im_h, im_w);
+
+    let mut crop_boxes = Vec::new();
+
+    // Original image.
+    crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
+
+    for layer_idx in 1..=n_layers {
+        let n_crops_per_side = 1 << layer_idx;
+        let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
+        let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+        let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+        for i_x in 0..n_crops_per_side {
+            let x0 = (crop_w - overlap) * i_x;
+            for i_y in 0..n_crops_per_side {
+                let y0 = (crop_h - overlap) * i_y;
+                let x1 = usize::min(im_w, x0 + crop_w);
+                let y1 = usize::min(im_h, y0 + crop_h);
+                crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+            }
+        }
+    }
+
+    crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
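+// E.g. n_per_side = 2 gives offset = 0.25 and the four points
+// (0.25, 0.25), (0.25, 0.75), (0.75, 0.25), (0.75, 0.75).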
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> { + let offset = 1f64 / (2 * n_per_side) as f64; + let mut points = Vec::with_capacity(n_per_side * n_per_side); + for i_x in 0..n_per_side { + let x = offset + i_x as f64 / n_per_side as f64; + for i_y in 0..n_per_side { + let y = offset + i_y as f64 / n_per_side as f64; + points.push((x, y)) + } + } + points +} + +fn build_all_layer_point_grids( + n_per_side: usize, + n_layers: usize, + scale_per_layer: usize, +) -> Vec<Vec<(f64, f64)>> { + let mut points_by_layer = Vec::with_capacity(n_layers + 1); + for i in 0..=n_layers { + let n_points = n_per_side / scale_per_layer.pow(i as u32); + points_by_layer.push(build_point_grid(n_points)) + } + points_by_layer +} diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs new file mode 100644 index 00000000..cd2936ab --- /dev/null +++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs @@ -0,0 +1,633 @@ +// Adapted from: +// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{Conv2dConfig, Module, VarBuilder}; + +const MBCONV_EXPAND_RATIO: usize = 4; +const MLP_RATIO: usize = 4; +const LOCAL_CONV_SIZE: usize = 3; +const IMG_SIZE: usize = 1024; +const IN_CHANNELS: usize = 3; + +#[derive(Debug)] +struct Conv2dBN { + c: candle_nn::Conv2d, + bn: candle_nn::BatchNorm, + span: tracing::Span, +} + +impl Conv2dBN { + fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> { + let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?; + let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?; + let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn"); + Ok(Self { c, bn, span }) + } +} + +impl Module for Conv2dBN { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.c)?.apply(&self.bn) + } +} + +#[derive(Debug)] +struct PatchEmbed { + conv1: Conv2dBN, + conv2: Conv2dBN, + span: tracing::Span, +} + +impl PatchEmbed { + fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + stride: 2, + padding: 1, + ..Default::default() + }; + let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?; + let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { conv1, conv2, span }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2) + } +} + +#[derive(Debug)] +struct MBConv { + conv1: Conv2dBN, + conv2: Conv2dBN, + conv3: Conv2dBN, + span: tracing::Span, +} + +impl MBConv { + fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> { + let hidden = in_ * expand_ratio; + let cfg2 = candle_nn::Conv2dConfig { + padding: 1, + groups: hidden, + ..Default::default() + }; + let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "mb-conv"); + Ok(Self { + conv1, + conv2, + conv3, + span, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let 
_enter = self.span.enter(); + let shortcut = xs; + let xs = xs + .apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)?; + (xs + shortcut)?.gelu() + } +} + +#[derive(Debug)] +struct PatchMerging { + conv1: Conv2dBN, + conv2: Conv2dBN, + conv3: Conv2dBN, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl PatchMerging { + fn new( + input_resolution: (usize, usize), + dim: usize, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 }; + let cfg2 = candle_nn::Conv2dConfig { + padding: 1, + stride, + groups: out, + ..Default::default() + }; + let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-merging"); + Ok(Self { + conv1, + conv2, + conv3, + input_resolution, + span, + }) + } +} + +impl Module for PatchMerging { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = if xs.rank() == 3 { + let (h, w) = self.input_resolution; + let b = xs.dim(0)?; + xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))? + } else { + xs.clone() + }; + xs.apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)? + .flatten_from(2)? + .transpose(1, 2) + } +} + +#[derive(Debug)] +struct ConvLayer { + blocks: Vec<MBConv>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl ConvLayer { + fn new( + dim: usize, + out: usize, + input_resolution: (usize, usize), + depth: usize, + downsample: bool, + conv_expand_ratio: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "conv-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for ConvLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +struct Mlp { + norm: candle_nn::LayerNorm, + fc1: super::Linear, + fc2: super::Linear, + span: tracing::Span, +} + +impl Mlp { + fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> { + let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?; + let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?; + let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + Ok(Self { + norm, + fc1, + fc2, + span, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.norm)? + .apply(&self.fc1)? + .gelu()? 
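+            // Project back from the hidden width to the input width.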
+ .apply(&self.fc2) + } +} + +#[derive(Debug)] +struct Attention { + norm: candle_nn::LayerNorm, + qkv: super::Linear, + proj: super::Linear, + ab: Tensor, + key_dim: usize, + num_heads: usize, + d: usize, + dh: usize, + scale: f64, + span: tracing::Span, + span_matmul: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + key_dim: usize, + num_heads: usize, + attn_ratio: usize, + resolution: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let d = attn_ratio * key_dim; + let dh = d * num_heads; + let nh_kd = key_dim * num_heads; + let h = dh + nh_kd * 2; + let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?; + let qkv = super::linear(vb.pp("qkv"), dim, h, true)?; + let proj = super::linear(vb.pp("proj"), dh, dim, true)?; + + let points = (0..resolution.0) + .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64))) + .collect::<Vec<_>>(); + let mut idxs = Vec::with_capacity(points.len() * points.len()); + let mut attention_offsets = std::collections::HashMap::new(); + for &(x1, y1) in points.iter() { + for &(x2, y2) in points.iter() { + let offset = ((x2 - x1).abs(), (y2 - y1).abs()); + let l = attention_offsets.len(); + let idx = attention_offsets.entry(offset).or_insert(l); + idxs.push(*idx as u32) + } + } + let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; + let idxs = Tensor::new(idxs, attention_biases.device())?; + let ab = + attention_biases + .index_select(&idxs, 1)? + .reshape(((), points.len(), points.len()))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + Ok(Self { + norm, + qkv, + proj, + ab, + key_dim, + num_heads, + d, + dh, + scale: 1f64 / (key_dim as f64).sqrt(), + span, + span_matmul, + span_softmax, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, n, _) = xs.dims3()?; + let xs = xs.apply(&self.norm)?; + let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; + let q = qkv + .narrow(D::Minus1, 0, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let k = qkv + .narrow(D::Minus1, self.key_dim, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let v = qkv + .narrow(D::Minus1, 2 * self.key_dim, self.d)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let attn = { + let _enter = self.span_matmul.enter(); + (q.matmul(&k.t()?)? * self.scale)? + }; + let attn = attn.broadcast_add(&self.ab)?; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + attn.transpose(1, 2)? + .reshape((b, n, self.dh))? 
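+            // The heads are merged back into a single (b, n, dh) tensor before
+            // the output projection.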
+ .apply(&self.proj) + } +} + +#[derive(Debug)] +struct TinyViTBlock { + attn: Attention, + local_conv: Conv2dBN, + mlp: Mlp, + window_size: usize, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl TinyViTBlock { + fn new( + dim: usize, + input_resolution: (usize, usize), + num_heads: usize, + window_size: usize, + vb: VarBuilder, + ) -> Result<Self> { + let head_dim = dim / num_heads; + let attn = Attention::new( + dim, + head_dim, + num_heads, + 1, + (window_size, window_size), + vb.pp("attn"), + )?; + let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; + let cfg = candle_nn::Conv2dConfig { + padding: LOCAL_CONV_SIZE / 2, + groups: dim, + ..Default::default() + }; + let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + Ok(Self { + attn, + local_conv, + mlp, + window_size, + input_resolution, + span, + }) + } +} + +impl Module for TinyViTBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (h, w) = self.input_resolution; + let (b, l, c) = xs.dims3()?; + let res_x = xs; + let xs = if h == self.window_size && w == self.window_size { + self.attn.forward(xs)? + } else { + let xs = xs.reshape((b, h, w, c))?; + let pad_b = (self.window_size - h % self.window_size) % self.window_size; + let pad_r = (self.window_size - w % self.window_size) % self.window_size; + + let xs = if pad_b > 0 { + xs.pad_with_zeros(1, 0, pad_b)? + } else { + xs + }; + let xs = if pad_r > 0 { + xs.pad_with_zeros(2, 0, pad_r)? + } else { + xs + }; + let (p_h, p_w) = (h + pad_b, w + pad_r); + let n_h = p_h / self.window_size; + let n_w = p_w / self.window_size; + let xs = xs + .reshape((b, n_h, self.window_size, n_w, self.window_size, c))? + .transpose(2, 3)? + .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?; + let xs = self.attn.forward(&xs)?; + let xs = xs + .reshape((b, n_h, n_w, self.window_size, self.window_size, c))? + .transpose(2, 3)? + .reshape((b, p_h, p_w, c))?; + let xs = if pad_r > 0 { + xs.i((.., .., ..w))?.contiguous()? + } else { + xs + }; + let xs = if pad_b > 0 { + xs.i((.., ..h, ..))?.contiguous()? + } else { + xs + }; + xs.reshape((b, l, c))? + }; + let xs = (xs + res_x)?; + let xs = xs + .transpose(1, 2)? + .reshape((b, c, h, w))? + .apply(&self.local_conv)? + .reshape((b, c, l))? + .transpose(1, 2)?; + &xs + self.mlp.forward(&xs)? 
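+        // Second residual connection of the block, around the MLP.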
+ } +} + +#[derive(Debug)] +struct BasicLayer { + blocks: Vec<TinyViTBlock>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl BasicLayer { + #[allow(clippy::too_many_arguments)] + fn new( + dim: usize, + input_resolution: (usize, usize), + depth: usize, + num_heads: usize, + window_size: usize, + downsample: bool, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = TinyViTBlock::new( + dim, + input_resolution, + num_heads, + window_size, + vb_b.pp(index), + )?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "basic-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for BasicLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +pub struct TinyViT { + patch_embed: PatchEmbed, + layer0: ConvLayer, + layers: Vec<BasicLayer>, + // norm_head: candle_nn::LayerNorm, + // head: candle_nn::Linear, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + span: tracing::Span, + span_neck: tracing::Span, +} + +impl TinyViT { + pub fn new( + embed_dims: &[usize], + depths: &[usize], + num_heads: &[usize], + window_sizes: &[usize], + _num_classes: usize, + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; + let patches_resolution = IMG_SIZE / 4; + + let vb_l = vb.pp("layers"); + let layer0 = ConvLayer::new( + /* dim */ embed_dims[0], + /* out */ embed_dims[1], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[0], + /* downsample */ true, + /* conv_expand_ratio */ MBCONV_EXPAND_RATIO, + vb_l.pp(0), + )?; + + let num_layers = embed_dims.len(); + let mut layers = Vec::with_capacity(num_layers - 1); + for i_layer in 1..num_layers { + let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2)); + let layer = BasicLayer::new( + /* dim */ embed_dims[i_layer], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[i_layer], + /* num_heads */ num_heads[i_layer], + /* window_size */ window_sizes[i_layer], + /* downsample */ i_layer < num_layers - 1, + /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)], + vb_l.pp(i_layer), + )?; + layers.push(layer) + } + + let last_embed_dim = embed_dims[embed_dims.len() - 1]; + // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; + // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; + let neck_conv1 = + candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; + let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; + + let span = tracing::span!(tracing::Level::TRACE, "tiny-vit"); + let span_neck = 
tracing::span!(tracing::Level::TRACE, "neck"); + Ok(Self { + patch_embed, + layer0, + layers, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + span, + span_neck, + }) + } +} + +impl Module for TinyViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = self.layer0.forward(&xs)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + let (b, _, c) = xs.dims3()?; + let _enter = self.span_neck.enter(); + xs.reshape((b, 64, 64, c))? + .permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? + .apply(&self.neck_ln2) + } +} + +pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> { + TinyViT::new( + /* embed_dims */ &[64, 128, 160, 320], + /* depths */ &[2, 2, 6, 2], + /* num_heads */ &[2, 4, 5, 10], + /* window_sizes */ &[7, 7, 14, 7], + /* num_classes */ 1000, + vb, + ) +} diff --git a/candle-transformers/src/models/segment_anything/transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs new file mode 100644 index 00000000..80efb38c --- /dev/null +++ b/candle-transformers/src/models/segment_anything/transformer.rs @@ -0,0 +1,221 @@ +use candle::{Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +#[derive(Debug)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + num_heads: usize, +} + +impl Attention { + fn new( + embedding_dim: usize, + num_heads: usize, + downsample_rate: usize, + vb: VarBuilder, + ) -> Result<Self> { + let internal_dim = embedding_dim / downsample_rate; + let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?; + let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?; + let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?; + let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + num_heads, + }) + } + + fn separate_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n, c) = x.dims3()?; + x.reshape((b, n, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? + .contiguous() + } + + fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n_heads, n_tokens, c_per_head) = x.dims4()?; + x.transpose(1, 2)? + .reshape((b, n_tokens, n_heads * c_per_head)) + } + + fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { + let q = self.q_proj.forward(&q.contiguous()?)?; + let k = self.k_proj.forward(&k.contiguous()?)?; + let v = self.v_proj.forward(&v.contiguous()?)?; + + let q = self.separate_heads(&q)?; + let k = self.separate_heads(&k)?; + let v = self.separate_heads(&v)?; + + let (_, _, _, c_per_head) = q.dims4()?; + let attn = (q.matmul(&k.t()?)? 
/ (c_per_head as f64).sqrt())?; + let attn = candle_nn::ops::softmax_last_dim(&attn)?; + + let out = attn.matmul(&v)?; + self.recombine_heads(&out)?.apply(&self.out_proj) + } +} + +#[derive(Debug)] +struct TwoWayAttentionBlock { + self_attn: Attention, + norm1: LayerNorm, + cross_attn_token_to_image: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + norm3: LayerNorm, + norm4: LayerNorm, + cross_attn_image_to_token: Attention, + skip_first_layer_pe: bool, +} + +impl TwoWayAttentionBlock { + fn new( + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + skip_first_layer_pe: bool, + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?; + let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?; + let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?; + let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?; + let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?; + let cross_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_token_to_image"), + )?; + let cross_attn_image_to_token = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_image_to_token"), + )?; + let mlp = super::MlpBlock::new( + embedding_dim, + mlp_dim, + candle_nn::Activation::Relu, + vb.pp("mlp"), + )?; + Ok(Self { + self_attn, + norm1, + cross_attn_image_to_token, + norm2, + mlp, + norm3, + norm4, + cross_attn_token_to_image, + skip_first_layer_pe, + }) + } + + fn forward( + &self, + queries: &Tensor, + keys: &Tensor, + query_pe: &Tensor, + key_pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // Self attention block + let queries = if self.skip_first_layer_pe { + self.self_attn.forward(queries, queries, queries)? + } else { + let q = (queries + query_pe)?; + let attn_out = self.self_attn.forward(&q, &q, queries)?; + (queries + attn_out)? 
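+            // Residual connection around the self-attention output.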
+ }; + let queries = self.norm1.forward(&queries)?; + + // Cross attention block, tokens attending to image embedding + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?; + let queries = (&queries + attn_out)?; + let queries = self.norm2.forward(&queries)?; + + // MLP block + let mlp_out = self.mlp.forward(&queries); + let queries = (queries + mlp_out)?; + let queries = self.norm3.forward(&queries)?; + + // Cross attention block, image embedding attending to tokens + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?; + let keys = (keys + attn_out)?; + let keys = self.norm4.forward(&keys)?; + + Ok((queries, keys)) + } +} + +#[derive(Debug)] +pub struct TwoWayTransformer { + layers: Vec<TwoWayAttentionBlock>, + final_attn_token_to_image: Attention, + norm_final_attn: LayerNorm, +} + +impl TwoWayTransformer { + pub fn new( + depth: usize, + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_l = vb.pp("layers"); + let mut layers = Vec::with_capacity(depth); + for i in 0..depth { + let layer = + TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?; + layers.push(layer) + } + let final_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("final_attn_token_to_image"), + )?; + let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?; + Ok(Self { + layers, + final_attn_token_to_image, + norm_final_attn, + }) + } + + pub fn forward( + &self, + image_embedding: &Tensor, + image_pe: &Tensor, + point_embedding: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?; + let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?; + + let mut queries = point_embedding.clone(); + let mut keys = image_embedding; + + for layer in self.layers.iter() { + (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)? 
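+            // The query and key positional encodings are re-added inside every layer.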
+ } + + let q = (&queries + point_embedding)?; + let k = (&keys + image_pe)?; + let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?; + let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?; + + Ok((queries, keys)) + } +} diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 1ae1bfc3..b3ea91f9 100644 --- a/candle-examples/examples/stable-diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -17,7 +17,7 @@ impl GeGlu { } } -impl GeGlu { +impl Module for GeGlu { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; @@ -53,7 +53,7 @@ impl FeedForward { } } -impl FeedForward { +impl Module for FeedForward { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let xs = self.project_in.forward(xs)?; @@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten } #[derive(Debug)] -struct CrossAttention { +pub struct CrossAttention { to_q: nn::Linear, to_k: nn::Linear, to_v: nn::Linear, @@ -94,7 +94,7 @@ struct CrossAttention { impl CrossAttention { // Defaults should be heads = 8, dim_head = 64, context_dim = None - fn new( + pub fn new( vs: nn::VarBuilder, query_dim: usize, context_dim: Option<usize>, @@ -198,14 +198,14 @@ impl CrossAttention { let xs = query.matmul(&(key.t()? * self.scale)?)?; let xs = { let _enter = self.span_softmax.enter(); - nn::ops::softmax(&xs, D::Minus1)? + nn::ops::softmax_last_dim(&xs)? }; xs.matmul(&value)?.to_dtype(in_dtype)? }; self.reshape_batch_dim_to_heads(&xs) } - fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { let _enter = self.span.enter(); let query = self.to_q.forward(xs)?; let context = context.unwrap_or(xs).contiguous()?; @@ -501,8 +501,10 @@ impl AttentionBlock { xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))? 
.transpose(1, 2) } +} - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for AttentionBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let in_dtype = xs.dtype(); let residual = xs; diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index d26c1c46..e7a20270 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -12,13 +12,15 @@ use candle_nn::Module; pub enum Activation { QuickGelu, Gelu, + GeluErf, } -impl Activation { +impl Module for Activation { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, Activation::Gelu => xs.gelu(), + Activation::GeluErf => xs.gelu_erf(), } } } @@ -99,6 +101,36 @@ impl Config { activation: Activation::Gelu, } } + + // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json + pub fn wuerstchen() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1024, + intermediate_size: 4096, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 24, + num_attention_heads: 16, + projection_dim: 1024, + activation: Activation::GeluErf, + } + } + + // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json + pub fn wuerstchen_prior() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1280, + intermediate_size: 5120, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 32, + num_attention_heads: 20, + projection_dim: 512, + activation: Activation::GeluErf, + } + } } // CLIP Text Model @@ -129,7 +161,7 @@ impl ClipTextEmbeddings { } } -impl ClipTextEmbeddings { +impl Module for ClipTextEmbeddings { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let token_embedding = self.token_embedding.forward(xs)?; let position_embedding = self.position_embedding.forward(&self.position_ids)?; @@ -319,21 +351,39 @@ impl ClipTextTransformer { } // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 - fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> { + fn build_causal_attention_mask( + bsz: usize, + seq_len: usize, + mask_after: usize, + device: &Device, + ) -> Result<Tensor> { let mask: Vec<_> = (0..seq_len) - .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. })) + .flat_map(|i| { + (0..seq_len).map(move |j| { + if j > i || j > mask_after { + f32::MIN + } else { + 0. 
+                }
+            })
+        })
+        .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
     }
-}
 
-impl ClipTextTransformer {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
         let (bsz, seq_len) = xs.dims2()?;
         let xs = self.embeddings.forward(xs)?;
-        let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
         let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
         self.final_layer_norm.forward(&xs)
     }
 }
+
+impl Module for ClipTextTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.forward_with_mask(xs, usize::MAX)
+    }
+}
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index f2e021ce..916b7349 100644
--- a/candle-examples/examples/stable-diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,7 @@
 //!
 //! Denoising Diffusion Implicit Models, J. Song et al, 2020.
 //! https://arxiv.org/abs/2010.02502
-use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
 use candle::{Result, Tensor};
 
 /// The configuration for the DDIM scheduler.
@@ -67,14 +67,14 @@ impl DDIMScheduler {
         .rev()
         .collect();
         let betas = match config.beta_schedule {
-            BetaSchedule::ScaledLinear => crate::utils::linspace(
+            BetaSchedule::ScaledLinear => super::utils::linspace(
                 config.beta_start.sqrt(),
                 config.beta_end.sqrt(),
                 config.train_timesteps,
             )?
             .sqr()?,
             BetaSchedule::Linear => {
-                crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
             }
             BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
         };
@@ -163,6 +163,17 @@ impl DDIMScheduler {
         }
     }
 
+    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+        let timestep = if timestep >= self.alphas_cumprod.len() {
+            self.alphas_cumprod.len() - 1
+        } else {
+            timestep
+        };
+        let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+        let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+        (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+    }
+
     pub fn init_noise_sigma(&self) -> f64 {
         self.init_noise_sigma
     }
diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs
new file mode 100644
index 00000000..d393f39a
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs
@@ -0,0 +1,205 @@
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum DDPMVarianceType {
+    FixedSmall,
+    FixedSmallLog,
+    FixedLarge,
+    FixedLargeLog,
+    Learned,
+}
+
+impl Default for DDPMVarianceType {
+    fn default() -> Self {
+        Self::FixedSmall
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct DDPMSchedulerConfig {
+    /// The value of beta at the beginning of training.
+    pub beta_start: f64,
+    /// The value of beta at the end of training.
+    pub beta_end: f64,
+    /// How beta evolves during training.
+    pub beta_schedule: BetaSchedule,
+    /// Option to clip the predicted sample between -1 and 1 for numerical stability.
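+    /// The clipping is applied to the predicted original sample inside `step`.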
+ pub clip_sample: bool, + /// Option to clip the variance used when adding noise to the denoised sample. + pub variance_type: DDPMVarianceType, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model. + pub train_timesteps: usize, +} + +impl Default for DDPMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + clip_sample: false, + variance_type: DDPMVarianceType::FixedSmall, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +pub struct DDPMScheduler { + alphas_cumprod: Vec<f64>, + init_noise_sigma: f64, + timesteps: Vec<usize>, + step_ratio: usize, + pub config: DDPMSchedulerConfig, +} + +impl DDPMScheduler { + pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + + // min(train_timesteps, inference_steps) + // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187 + let inference_steps = inference_steps.min(config.train_timesteps); + // arange the number of the scheduler's timesteps + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect(); + + Ok(Self { + alphas_cumprod, + init_noise_sigma: 1.0, + timesteps, + step_ratio, + config, + }) + } + + fn get_variance(&self, timestep: usize) -> f64 { + let prev_t = timestep as isize - self.step_ratio as isize; + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev; + + // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + // and sample from it to get previous sample + // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t; + + // retrieve variance + match self.config.variance_type { + DDPMVarianceType::FixedSmall => variance.max(1e-20), + // for rl-diffuser https://arxiv.org/abs/2205.09991 + DDPMVarianceType::FixedSmallLog => { + let variance = variance.max(1e-20).ln(); + (variance * 0.5).exp() + } + DDPMVarianceType::FixedLarge => current_beta_t, + DDPMVarianceType::FixedLargeLog => current_beta_t.ln(), + DDPMVarianceType::Learned => variance, + } + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. 
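+    /// DDPM applies no such scaling, so the sample is returned unchanged.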
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let prev_t = timestep as isize - self.step_ratio as isize; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272 + // 1. compute alphas, betas + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + let current_alpha_t = alpha_prod_t / alpha_prod_t_prev; + let current_beta_t = 1. - current_alpha_t; + + // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15) + let mut pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => { + ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())? + } + PredictionType::Sample => model_output.clone(), + PredictionType::VPrediction => { + ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())? + } + }; + + // 3. clip predicted x_0 + if self.config.clip_sample { + pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?; + } + + // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t; + let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t; + + // 5. Compute predicted previous sample µ_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)? + + sample * current_sample_coeff)?; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305 + // 6. Add noise + let mut variance = model_output.zeros_like()?; + if timestep > 0 { + let variance_noise = model_output.randn_like(0., 1.)?; + if self.config.variance_type == DDPMVarianceType::FixedSmallLog { + variance = (variance_noise * self.get_variance(timestep))?; + } else { + variance = (variance_noise * self.get_variance(timestep).sqrt())?; + } + } + &pred_prev_sample + variance + } + + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Result<Tensor> { + (original_samples * self.alphas_cumprod[timestep].sqrt())? + + noise * (1. 
- self.alphas_cumprod[timestep]).sqrt() + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-transformers/src/models/stable_diffusion/embeddings.rs index 97bc61f1..0de5f9a7 100644 --- a/candle-examples/examples/stable-diffusion/embeddings.rs +++ b/candle-transformers/src/models/stable_diffusion/embeddings.rs @@ -17,8 +17,8 @@ impl TimestepEmbedding { } } -impl TimestepEmbedding { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for TimestepEmbedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; self.linear_2.forward(&xs) } @@ -41,8 +41,8 @@ impl Timesteps { } } -impl Timesteps { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for Timesteps { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let half_dim = (self.num_channels / 2) as u32; let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)? * -f64::ln(10000.))?; diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index cffc00d8..c6f1b904 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -1,5 +1,15 @@ -use crate::schedulers::PredictionType; -use crate::{clip, ddim, unet_2d, vae}; +pub mod attention; +pub mod clip; +pub mod ddim; +pub mod ddpm; +pub mod embeddings; +pub mod resnet; +pub mod schedulers; +pub mod unet_2d; +pub mod unet_2d_blocks; +pub mod utils; +pub mod vae; + use candle::{DType, Device, Result}; use candle_nn as nn; @@ -80,7 +90,7 @@ impl StableDiffusionConfig { sliced_attention_size: Option<usize>, height: Option<usize>, width: Option<usize>, - prediction_type: PredictionType, + prediction_type: schedulers::PredictionType, ) -> Self { let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { out_channels, @@ -154,7 +164,7 @@ impl StableDiffusionConfig { sliced_attention_size, height, width, - PredictionType::VPrediction, + schedulers::PredictionType::VPrediction, ) } @@ -162,7 +172,7 @@ impl StableDiffusionConfig { sliced_attention_size: Option<usize>, height: Option<usize>, width: Option<usize>, - prediction_type: PredictionType, + prediction_type: schedulers::PredictionType, ) -> Self { let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { out_channels, @@ -235,7 +245,7 @@ impl StableDiffusionConfig { height, width, // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json - PredictionType::Epsilon, + schedulers::PredictionType::Epsilon, ) } diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs index 4cfd386d..0d818115 100644 --- a/candle-examples/examples/stable-diffusion/resnet.rs +++ b/candle-transformers/src/models/stable_diffusion/resnet.rs @@ -4,7 +4,7 @@ //! //! Denoising Diffusion Implicit Models, K. He and al, 2015. //! 
https://arxiv.org/abs/1512.03385 -use crate::utils::{conv2d, Conv2d}; +use super::utils::{conv2d, Conv2d}; use candle::{Result, Tensor, D}; use candle_nn as nn; use candle_nn::Module; diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs index 3f6a1d72..3f6a1d72 100644 --- a/candle-examples/examples/stable-diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs index 81bd9547..a3ed136e 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs @@ -2,9 +2,9 @@ //! //! The 2D Unet models take as input a noisy sample and the current diffusion //! timestep and return a denoised version of the input. -use crate::embeddings::{TimestepEmbedding, Timesteps}; -use crate::unet_2d_blocks::*; -use crate::utils::{conv2d, Conv2d}; +use super::embeddings::{TimestepEmbedding, Timesteps}; +use super::unet_2d_blocks::*; +use super::utils::{conv2d, Conv2d}; use candle::{Result, Tensor}; use candle_nn as nn; use candle_nn::Module; diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs index 26a1035b..29510cef 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs @@ -1,11 +1,11 @@ //! 2D UNet Building Blocks //! -use crate::attention::{ +use super::attention::{ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, }; -use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; -use crate::utils::{conv2d, Conv2d}; -use candle::{Result, Tensor, D}; +use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; +use super::utils::{conv2d, Conv2d}; +use candle::{Module, Result, Tensor, D}; use candle_nn as nn; #[derive(Debug)] @@ -43,7 +43,7 @@ impl Downsample2D { } } -impl Downsample2D { +impl Module for Downsample2D { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); match &self.conv { @@ -172,8 +172,8 @@ impl DownEncoderBlock2D { } } -impl DownEncoderBlock2D { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for DownEncoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let mut xs = xs.clone(); for resnet in self.resnets.iter() { @@ -256,8 +256,8 @@ impl UpDecoderBlock2D { } } -impl UpDecoderBlock2D { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for UpDecoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let mut xs = xs.clone(); for resnet in self.resnets.iter() { @@ -754,6 +754,7 @@ impl UpBlock2D { let mut xs = xs.clone(); for (index, resnet) in self.resnets.iter().enumerate() { xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = xs.contiguous()?; xs = resnet.forward(&xs, temb)?; } match &self.upsampler { @@ -855,6 +856,7 @@ impl CrossAttnUpBlock2D { let mut xs = xs.clone(); for (index, resnet) in self.upblock.resnets.iter().enumerate() { xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = xs.contiguous()?; xs = resnet.forward(&xs, temb)?; xs = self.attentions[index].forward(&xs, encoder_hidden_states)?; } diff --git a/candle-examples/examples/stable-diffusion/utils.rs 
b/candle-transformers/src/models/stable_diffusion/utils.rs index c62f17af..c62f17af 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index aa8e13a0..21709afe 100644 --- a/candle-examples/examples/stable-diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -4,7 +4,7 @@ //! Auto-encoder models compress their input to a usually smaller latent space //! before expanding it back to its original shape. This results in the latent values //! compressing the original information. -use crate::unet_2d_blocks::{ +use super::unet_2d_blocks::{ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, UpDecoderBlock2D, UpDecoderBlock2DConfig, }; @@ -132,14 +132,15 @@ impl Encoder { impl Encoder { fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.conv_in.forward(xs)?; + let mut xs = xs.apply(&self.conv_in)?; for down_block in self.down_blocks.iter() { - xs = down_block.forward(&xs)? + xs = xs.apply(down_block)? } - let xs = self.mid_block.forward(&xs, None)?; - let xs = self.conv_norm_out.forward(&xs)?; - let xs = nn::ops::silu(&xs)?; - self.conv_out.forward(&xs) + let xs = self + .mid_block + .forward(&xs, None)? + .apply(&self.conv_norm_out)?; + nn::ops::silu(&xs)?.apply(&self.conv_out) } } @@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution { } pub fn sample(&self) -> Result<Tensor> { - let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device()); + let sample = self.mean.randn_like(0., 1.); &self.mean + &self.std * sample } } diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs new file mode 100644 index 00000000..539ae89b --- /dev/null +++ b/candle-transformers/src/models/t5.rs @@ -0,0 +1,841 @@ +// T5 Text Model +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug)] +struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let inner = candle_nn::embedding(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug)] +struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let inner = candle_nn::linear_no_bias(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Self { inner, span }) + } +} + +impl Module for Linear { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +fn default_relative_attention_max_distance() -> usize { + 128 +} + +fn default_is_decoder() -> bool { + false +} + +fn default_use_cache() -> bool { + true +} + +fn default_tie_word_embeddings() -> bool { + true +} + +fn get_mask(size: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..size) + .flat_map(|i| 
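        // Builds an upper-triangular (size, size) mask whose (i, j) entry is 1
        // whenever j > i; `masked_fill` later swaps those positions for -inf so
        // a query can only attend to itself and the past. For size = 3 the
        // rows are [0 1 1], [0 0 1], [0 0 0].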
(0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + d_model: usize, + d_kv: usize, + d_ff: usize, + num_layers: usize, + num_decoder_layers: Option<usize>, + num_heads: usize, + relative_attention_num_buckets: usize, + #[serde(default = "default_relative_attention_max_distance")] + relative_attention_max_distance: usize, + dropout_rate: f64, + layer_norm_epsilon: f64, + initializer_factor: f64, + #[serde(default)] + feed_forward_proj: Activation, + #[serde(default = "default_tie_word_embeddings")] + tie_word_embeddings: bool, + #[serde(default = "default_is_decoder")] + is_decoder: bool, + is_encoder_decoder: bool, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 32128, + d_model: 512, + d_kv: 64, + d_ff: 2048, + num_layers: 6, + num_decoder_layers: None, + num_heads: 8, + relative_attention_num_buckets: 32, + relative_attention_max_distance: 128, + dropout_rate: 0.1, + layer_norm_epsilon: 1e-6, + initializer_factor: 1.0, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + is_decoder: false, + is_encoder_decoder: true, + use_cache: true, + pad_token_id: 0, + eos_token_id: 1, + } + } +} + +impl Config { + // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184 + pub fn musicgen_small() -> Self { + Self { + d_ff: 3072, + d_kv: 64, + d_model: 768, + dropout_rate: 0.1, + eos_token_id: 1, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + initializer_factor: 1.0, + is_decoder: false, + is_encoder_decoder: true, + layer_norm_epsilon: 1e-6, + num_decoder_layers: Some(12), + num_heads: 12, + num_layers: 12, + pad_token_id: 0, + relative_attention_max_distance: 128, + relative_attention_num_buckets: 32, + use_cache: true, + vocab_size: 32128, + } + } +} + +#[derive(Debug)] +struct T5LayerNorm { + weight: Tensor, + variance_epsilon: f64, + span: tracing::Span, +} + +impl T5LayerNorm { + fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(h, "weight")?; + Ok(Self { + weight, + variance_epsilon: eps, + span: tracing::span!(tracing::Level::TRACE, "layer-norm"), + }) + } +} + +impl Module for T5LayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let dtype = xs.dtype(); + let xs_f32 = xs.to_dtype(DType::F32)?; + // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; + let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs.to_dtype(dtype)?; + let xs = xs.broadcast_mul(&self.weight)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseActDense { + wi: Linear, + wo: Linear, + act: Activation, + span: tracing::Span, +} + +impl T5DenseActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi, + wo, + act: Activation::Relu, + span: 
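        // The non-gated T5 feed-forward block: d_model -> d_ff, ReLU, then
        // d_ff -> d_model, both projections bias-free. The gated GELU variant
        // used by t5-v1.1 style checkpoints is `T5DenseGatedActDense` below.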
tracing::span!(tracing::Level::TRACE, "dense-act-dense"), + }) + } +} + +impl Module for T5DenseActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.wi.forward(xs)?; + let xs = self.act.forward(&xs)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseGatedActDense { + wi_0: Linear, + wi_1: Linear, + wo: Linear, + act: Activation, + span: tracing::Span, +} + +impl T5DenseGatedActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; + let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; + let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi_0, + wi_1, + wo, + act: Activation::NewGelu, + span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), + }) + } +} + +impl Module for T5DenseGatedActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?; + let hidden_linear = self.wi_1.forward(xs)?; + let xs = hidden_gelu.broadcast_mul(&hidden_linear)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5LayerFF { + dense_act: Option<T5DenseActDense>, + gated_dense_act: Option<T5DenseGatedActDense>, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerFF { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu { + ( + None, + Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), + ) + } else { + ( + Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?), + None, + ) + }; + Ok(Self { + dense_act, + gated_dense_act, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer-ff"), + }) + } +} + +impl Module for T5LayerFF { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let ys = self.layer_norm.forward(xs)?; + let ys = match &self.dense_act { + Some(dense_act) => dense_act.forward(&ys)?, + None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?, + }; + let xs = (xs + ys)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5Attention { + q: Linear, + k: Linear, + v: Linear, + o: Linear, + n_heads: usize, + d_kv: usize, + relative_attention_bias: Option<Embedding>, + relative_attention_num_buckets: usize, + relative_attention_max_distance: usize, + inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_cache: tracing::Span, + span_mm: tracing::Span, + span_sm: tracing::Span, +} + +impl T5Attention { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let inner_dim = cfg.num_heads * cfg.d_kv; + let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?; + let relative_attention_bias = if has_relative_attention_bias { + let emb = Embedding::new( + cfg.relative_attention_num_buckets, + cfg.num_heads, + vb.pp("relative_attention_bias"), + )?; + Some(emb) + } else { + None + }; + Ok(Self { + q, + k, + v, + o, + n_heads: cfg.num_heads, + d_kv: cfg.d_kv, + relative_attention_bias, + relative_attention_num_buckets: 
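            // Only instantiated for the first block of a stack: a learned
            // (num_buckets, n_heads) table from which the shared position bias
            // is looked up once and then reused by all later blocks.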
cfg.relative_attention_num_buckets, + relative_attention_max_distance: cfg.relative_attention_max_distance, + inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "attention"), + span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"), + span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"), + span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + // Performs Self-attention (if key_value_states is None) or attention + // over source sentence (provided by key_value_states). + let _enter = self.span.enter(); + let kv_input = match key_value_states { + None => xs, + Some(key_value_states) => key_value_states, + }; + let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?); + let kv_len = kv_input.dim(1)?; + let q = self.q.forward(xs)?; + let k = self.k.forward(kv_input)?; + let v = self.v.forward(kv_input)?; + let q = q + .reshape((b_sz, q_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut k = k + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + + if self.use_cache { + let _enter = self.span_cache.enter(); + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; + // TODO: Use flash_attn. + let scores = { + let _enter = self.span_mm.enter(); + q.matmul(&k.t()?)? + }; + let scores = match mask { + None => scores, + Some(mask) => masked_fill( + &scores, + &mask + .unsqueeze(0)? + .unsqueeze(0)? + .repeat((b_sz, self.n_heads))?, + f32::NEG_INFINITY, + )?, + }; + + let (scores, position_bias) = match position_bias { + Some(position_bias) => ( + scores.broadcast_add(position_bias)?, + Some(position_bias.clone()), + ), + None => match &self.relative_attention_bias { + None => (scores, None), + Some(relative_attention_bias) => { + // This only handles the bidirectional case. + let kv_len = k.dim(2)?; + let (q_start, q_end) = match self.use_cache { + true => ((kv_len - q_len) as u32, kv_len as u32), + false => (0_u32, kv_len as u32), + }; + let num_buckets = self.relative_attention_num_buckets as u32 / 2; + let max_exact = num_buckets / 2; + let relative_position = (q_start..q_end) + .map(|i| { + (0..kv_len as u32) + .map(|j| { + if i < j { + if j - i < max_exact { + j - i + num_buckets + } else { + let b = f32::log( + (j - i) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + u32::min( + max_exact + num_buckets + b as u32, + self.relative_attention_num_buckets as u32 - 1, + ) + } + } else if i - j < max_exact { + i - j + } else { + let b = f32::log( + (i - j) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + max_exact + b as u32 + } + }) + .collect::<Vec<u32>>() + }) + .collect::<Vec<Vec<_>>>(); + let relative_buckets = Tensor::new(relative_position, q.device())?; + let position_bias = relative_attention_bias + .forward(&relative_buckets)? + .permute((2, 0, 1))? 
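                    // The bias embedding lookup above yields
                    // (q_len, kv_len, n_heads); the permute moves heads first
                    // and the unsqueeze below broadcasts the result to
                    // (1, n_heads, q_len, kv_len) so it can be added to `scores`.
                    //
                    // Bucketing, sketched for the default config (32 buckets
                    // split between the two directions, so num_buckets = 16 and
                    // max_exact = 8, with max_distance = 128): a distance of 3
                    // into the past maps to bucket 3, the same distance into the
                    // future to bucket 3 + 16 = 19, and longer distances are
                    // binned logarithmically, e.g. 20 steps into the past lands
                    // in bucket 8 + (ln(20/8) / ln(128/8) * 8) as u32 = 10.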
+ .unsqueeze(0)?; + (scores.broadcast_add(&position_bias)?, Some(position_bias)) + // TODO: position_bias_masked? + } + }, + }; + + let attn_weights = { + let _enter = self.span_sm.enter(); + candle_nn::ops::softmax(&scores, D::Minus1)? + }; + let attn_output = attn_weights.matmul(&v)?; + let attn_output = attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.inner_dim))?; + let attn_output = self.o.forward(&attn_output)?; + Ok((attn_output, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Debug)] +struct T5LayerSelfAttention { + self_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerSelfAttention { + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + self_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_xs = self.layer_norm.forward(xs)?; + let (ys, position_bias) = + self.self_attention + .forward(&normed_xs, position_bias, None, mask)?; + let ys = (xs + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5LayerCrossAttention { + cross_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerCrossAttention { + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + cross_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "cross-attn"), + }) + } + + fn forward( + &mut self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: &Tensor, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_hidden_states = self.layer_norm.forward(hidden_states)?; + let (ys, position_bias) = self.cross_attention.forward( + &normed_hidden_states, + position_bias, + Some(key_value_states), + None, + )?; + let ys = (hidden_states + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.cross_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5Block { + self_attn: T5LayerSelfAttention, + cross_attn: Option<T5LayerCrossAttention>, + ff: T5LayerFF, + span: tracing::Span, +} + +impl T5Block { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let vb = vb.pp("layer"); + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; + let cross_attn = if cfg.is_decoder { + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) 
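        // Weight layout inside a block: `layer.0` is self-attention, `layer.1`
        // cross-attention (decoder blocks only), and the feed-forward layer
        // takes the next free index, which is what `ff_i` computes below.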
+ } else { + None + }; + let ff_i = if cross_attn.is_some() { 2 } else { 1 }; + let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?; + Ok(Self { + self_attn, + cross_attn, + ff, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + // TODO: Cache masks + let mask = match self.cross_attn.is_some() { + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. + if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } + false => None, + }; + let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; + // TODO: clamp for f16? + if let Some(cross_attn) = &mut self.cross_attn { + (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; + // TODO: clamp for f16? + } + let xs = self.ff.forward(&xs)?; + // TODO: clamp for f16? + Ok((xs, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache()); + } +} + +#[derive(Debug)] +struct T5Stack { + block: Vec<T5Block>, + shared: Arc<Embedding>, + final_layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5Stack { + fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { + let block = (0..cfg.num_layers) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg)) + .collect::<Result<Vec<_>>>()?; + let final_layer_norm = T5LayerNorm::load( + cfg.d_model, + cfg.layer_norm_epsilon, + vb.pp("final_layer_norm"), + )?; + Ok(Self { + block, + shared: shared.clone(), + final_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "stack"), + }) + } + + fn forward( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let input_embeds = self.shared.as_ref().forward(input_ids)?; + let mut hidden_states = input_embeds; + let mut position_bias = None; + for block in self.block.iter_mut() { + (hidden_states, position_bias) = block.forward( + &hidden_states, + position_bias.as_ref(), + encoder_hidden_states, + )? 
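            // Only the first block owns a relative-attention embedding; it
            // returns the resulting `position_bias`, which is threaded back
            // into every later block so all layers share one bias tensor.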
+ } + self.final_layer_norm.forward(&hidden_states) + } + + fn clear_kv_cache(&mut self) { + self.block.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} + +#[derive(Debug)] +pub struct T5EncoderModel { + encoder: T5Stack, + device: Device, + span: tracing::Span, +} + +impl T5EncoderModel { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; + Ok(Self { + encoder, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "encoder"), + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.encoder.forward(input_ids, None) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache() + } +} + +#[derive(Debug)] +pub struct T5ForConditionalGeneration { + encoder: T5Stack, + decoder: T5Stack, + d_model: usize, + tie_word_embeddings: bool, + lm_head: Option<Linear>, + shared: Arc<Embedding>, + device: Device, + span_decode: tracing::Span, + span_decode_head: tracing::Span, +} + +impl T5ForConditionalGeneration { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + assert!(cfg.is_encoder_decoder); + let d_model = cfg.d_model; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + + let mut encoder_cfg = cfg.clone(); + encoder_cfg.is_decoder = false; + encoder_cfg.use_cache = false; + encoder_cfg.is_encoder_decoder = false; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?; + + let mut decoder_cfg = cfg.clone(); + decoder_cfg.is_decoder = true; + decoder_cfg.is_encoder_decoder = false; + decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers); + let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?; + + let tie_word_embeddings = cfg.tie_word_embeddings; + let lm_head = if tie_word_embeddings { + None + } else { + Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?) + }; + + Ok(Self { + encoder, + decoder, + d_model, + tie_word_embeddings, + lm_head, + shared, + device: vb.device().clone(), + span_decode: tracing::span!(tracing::Level::TRACE, "decode"), + span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"), + }) + } + + pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> { + self.encoder.forward(input_ids, None) + } + + pub fn decode( + &mut self, + decoder_input_ids: &Tensor, + encoder_output: &Tensor, + ) -> Result<Tensor> { + let _enter = self.span_decode.enter(); + let decoder_output = self + .decoder + .forward(decoder_input_ids, Some(encoder_output))?; + + let scaling_factor = if self.tie_word_embeddings { + // Rescale output before projecting on vocab + // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + (self.d_model as f64).sqrt() + } else { + 1.0 + }; + let sequence_output = ((decoder_output + .narrow(1, decoder_output.dim(1)? - 1, 1)? + .squeeze(1)?) + * scaling_factor)?; + let output = { + let _enter = self.span_decode_head.enter(); + match self.lm_head { + None => sequence_output.matmul(&self.shared.embeddings().t()?)?, + Some(ref lm_head) => lm_head.forward(&sequence_output)?, + } + }; + + // TODO: Rescale output before projecting on vocab? 
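        // A rough greedy-decoding sketch for this method, left as comments
        // since the surrounding bindings (`model`, `device`, `cfg`,
        // `encoder_output`) are hypothetical:
        //
        //     let mut tokens = vec![cfg.pad_token_id as u32];
        //     loop {
        //         // With the decoder KV cache enabled, only the newest token
        //         // needs to be fed after the first step.
        //         let input =
        //             Tensor::new(&tokens[tokens.len() - 1..], &device)?.unsqueeze(0)?;
        //         let logits = model.decode(&input, &encoder_output)?;
        //         let next = logits.argmax(D::Minus1)?.to_vec1::<u32>()?[0];
        //         if next as usize == cfg.eos_token_id {
        //             break;
        //         }
        //         tokens.push(next);
        //     }
        //
        // Note that `transformers` rescales in the opposite direction when the
        // embeddings are tied, multiplying the hidden states by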
* (self.model_dim**-0.5) + Ok(output) + } + + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + let encoder_output = self.encode(input_ids)?; + self.decode(decoder_input_ids, &encoder_output) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache(); + self.decoder.clear_kv_cache(); + } +} diff --git a/candle-examples/examples/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs index 2ceed065..4e01de32 100644 --- a/candle-examples/examples/whisper/audio.rs +++ b/candle-transformers/src/models/whisper/audio.rs @@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>( mel } -pub fn pcm_to_mel<T: Float + std::fmt::Display>( - samples: &[T], - filters: &[T], -) -> anyhow::Result<Vec<T>> { - let mel = log_mel_spectrogram_( +pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> { + log_mel_spectrogram_( samples, filters, super::N_FFT, super::HOP_LENGTH, super::N_MELS, false, - ); - Ok(mel) + ) } diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs new file mode 100644 index 00000000..7dc8107b --- /dev/null +++ b/candle-transformers/src/models/whisper/mod.rs @@ -0,0 +1,26 @@ +pub mod audio; +pub mod model; + +pub const DTYPE: candle::DType = candle::DType::F32; + +// Audio parameters. +pub const SAMPLE_RATE: usize = 16000; +pub const N_FFT: usize = 400; +pub const N_MELS: usize = 80; +pub const HOP_LENGTH: usize = 160; +pub const CHUNK_LENGTH: usize = 30; +pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input + +pub const NO_SPEECH_THRESHOLD: f64 = 0.6; +pub const LOGPROB_THRESHOLD: f64 = -1.0; +pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. +pub const SOT_TOKEN: &str = "<|startoftranscript|>"; +pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; +pub const TRANSLATE_TOKEN: &str = "<|translate|>"; +pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; +pub const EOT_TOKEN: &str = "<|endoftext|>"; +pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; diff --git a/candle-examples/examples/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index e58ab2ca..d2eda796 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,5 +1,5 @@ use candle::{Device, IndexOp, Result, Tensor, D}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; +use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: @@ -166,7 +166,7 @@ impl MultiHeadAttention { } let w = { let _enter = self.softmax_span.enter(); - softmax(&qk, D::Minus1)? + candle_nn::ops::softmax_last_dim(&qk)? 
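            // `softmax_last_dim` is a fused kernel specialised for softmax over
            // the final dimension; it computes the same thing as the previous
            // `softmax(&qk, D::Minus1)`, just with fewer intermediate ops on
            // CPU and CUDA.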
}; let wv = { let _enter = self.matmul_span.enter(); diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs new file mode 100644 index 00000000..0b90cb9d --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs @@ -0,0 +1,118 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::{linear, Linear, VarBuilder}; + +// A simplified version of: +// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38 +#[derive(Debug)] +pub struct Attention { + to_q: Linear, + to_k: Linear, + to_v: Linear, + to_out: Linear, + heads: usize, + scale: f64, + use_flash_attn: bool, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +impl Attention { + pub fn new( + query_dim: usize, + heads: usize, + dim_head: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?; + let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?; + let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?; + let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?; + Ok(Self { + to_q, + to_k, + to_v, + to_out, + scale, + heads, + use_flash_attn, + }) + } + + fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, dim) = xs.dims3()?; + xs.reshape((b_size / self.heads, self.heads, seq_len, dim))? + .permute((0, 2, 1, 3))? + .reshape((b_size / self.heads, seq_len, dim * self.heads)) + } + + fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, dim) = xs.dims3()?; + xs.reshape((b_size, seq_len, self.heads, dim / self.heads))? + .permute((0, 2, 1, 3))? + .reshape((b_size * self.heads, seq_len, dim / self.heads)) + } + + fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> { + let attn_probs = (query.matmul(&key.t()?)? * self.scale)?; + candle_nn::ops::softmax_last_dim(&attn_probs) + } + + pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { + let (b_size, channel, h, w) = xs.dims4()?; + let xs = xs.reshape((b_size, channel, h * w))?.t()?; + + let query = self.to_q.forward(&xs)?; + let key = self.to_k.forward(encoder_hidden_states)?; + let value = self.to_v.forward(encoder_hidden_states)?; + + let query = self.head_to_batch_dim(&query)?; + let key = self.head_to_batch_dim(&key)?; + let value = self.head_to_batch_dim(&value)?; + + let xs = if self.use_flash_attn { + let init_dtype = query.dtype(); + let q = query + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + let k = key + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + let v = value + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + flash_attn(&q, &k, &v, self.scale as f32, false)? + .transpose(1, 2)? + .squeeze(0)? + .to_dtype(init_dtype)? + } else { + let attn_prs = self.get_attention_scores(&query, &key)?; + attn_prs.matmul(&value)? 
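            // The non-flash path is plain scaled dot-product attention on
            // (b * heads, seq, dim_head) tensors:
            // softmax(q k^T / sqrt(dim_head)) v. The flash-attn branch above
            // computes the same quantity but wants f16 tensors laid out as
            // (batch, seq, heads, dim_head), hence its unsqueeze/transpose and
            // dtype round-trip.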
+ }; + let xs = self.batch_to_head_dim(&xs)?; + + self.to_out + .forward(&xs)? + .t()? + .reshape((b_size, channel, h, w)) + } +} diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs new file mode 100644 index 00000000..c89ec919 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/common.rs @@ -0,0 +1,203 @@ +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22 +#[derive(Debug)] +pub struct WLayerNorm { + eps: f64, +} + +impl WLayerNorm { + pub fn new(_size: usize) -> Result<Self> { + Ok(Self { eps: 1e-6 }) + } +} + +impl Module for WLayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = xs.permute((0, 2, 3, 1))?; + + let x_dtype = xs.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + + let hidden_size = xs.dim(D::Minus1)?; + let xs = xs.to_dtype(internal_dtype)?; + let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let xs = xs.broadcast_sub(&mean_x)?; + let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype)? + .permute((0, 3, 1, 2)) + } +} + +#[derive(Debug)] +pub struct LayerNormNoWeights { + eps: f64, +} + +impl LayerNormNoWeights { + pub fn new(_size: usize) -> Result<Self> { + Ok(Self { eps: 1e-6 }) + } +} + +impl Module for LayerNormNoWeights { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let x_dtype = xs.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = xs.dim(D::Minus1)?; + let xs = xs.to_dtype(internal_dtype)?; + let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let xs = xs.broadcast_sub(&mean_x)?; + let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype) + } +} + +#[derive(Debug)] +pub struct TimestepBlock { + mapper: candle_nn::Linear, +} + +impl TimestepBlock { + pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> { + let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?; + Ok(Self { mapper }) + } + + pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> { + let ab = self + .mapper + .forward(t)? + .unsqueeze(2)? + .unsqueeze(3)? + .chunk(2, 1)?; + xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1]) + } +} + +#[derive(Debug)] +pub struct GlobalResponseNorm { + gamma: Tensor, + beta: Tensor, +} + +impl GlobalResponseNorm { + pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> { + let gamma = vb.get((1, 1, 1, dim), "gamma")?; + let beta = vb.get((1, 1, 1, dim), "beta")?; + Ok(Self { gamma, beta }) + } +} + +impl Module for GlobalResponseNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + let stand_div_norm = + agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?; + xs.broadcast_mul(&stand_div_norm)? + .broadcast_mul(&self.gamma)? + .broadcast_add(&self.beta)? 
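        // Global Response Norm in the ConvNeXt-V2 style: each channel is scaled
        // by the L2 norm of its spatial response relative to the per-sample
        // mean of those norms, modulated by the learned gamma/beta, before the
        // residual `+ xs` that follows.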
+ + xs + } +} + +#[derive(Debug)] +pub struct ResBlock { + depthwise: candle_nn::Conv2d, + norm: WLayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_grn: GlobalResponseNorm, + channelwise_lin2: candle_nn::Linear, +} + +impl ResBlock { + pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + padding: ksize / 2, + groups: c, + ..Default::default() + }; + let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?; + let norm = WLayerNorm::new(c)?; + let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?; + let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?; + let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; + Ok(Self { + depthwise, + norm, + channelwise_lin1, + channelwise_grn, + channelwise_lin2, + }) + } + + pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> { + let x_res = xs; + let xs = match x_skip { + None => xs.clone(), + Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?, + }; + let xs = xs + .apply(&self.depthwise)? + .apply(&self.norm)? + .permute((0, 2, 3, 1))?; + let xs = xs + .apply(&self.channelwise_lin1)? + .gelu_erf()? + .apply(&self.channelwise_grn)? + .apply(&self.channelwise_lin2)? + .permute((0, 3, 1, 2))?; + xs + x_res + } +} +use super::attention_processor::Attention; +#[derive(Debug)] +pub struct AttnBlock { + self_attn: bool, + norm: WLayerNorm, + attention: Attention, + kv_mapper_lin: candle_nn::Linear, +} + +impl AttnBlock { + pub fn new( + c: usize, + c_cond: usize, + nhead: usize, + self_attn: bool, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let norm = WLayerNorm::new(c)?; + let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?; + let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?; + Ok(Self { + self_attn, + norm, + attention, + kv_mapper_lin, + }) + } + + pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> { + let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?; + let norm_xs = self.norm.forward(xs)?; + let kv = if self.self_attn { + let (b_size, channel, _, _) = xs.dims4()?; + let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?; + Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()? + } else { + kv + }; + xs + self.attention.forward(&norm_xs, &kv) + } +} diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs new file mode 100644 index 00000000..9e69b868 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/ddpm.rs @@ -0,0 +1,103 @@ +use candle::{Result, Tensor}; + +#[derive(Debug, Clone)] +pub struct DDPMWSchedulerConfig { + scaler: f64, + s: f64, +} + +impl Default for DDPMWSchedulerConfig { + fn default() -> Self { + Self { + scaler: 1f64, + s: 0.008f64, + } + } +} + +pub struct DDPMWScheduler { + init_alpha_cumprod: f64, + init_noise_sigma: f64, + timesteps: Vec<f64>, + pub config: DDPMWSchedulerConfig, +} + +impl DDPMWScheduler { + pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> { + let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI) + .cos() + .powi(2); + let timesteps = (0..=inference_steps) + .map(|i| 1. 
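                // inference_steps + 1 evenly spaced values from 1.0 down to
                // 0.0; the trailing 0.0 is only consumed as the "previous"
                // timestep of the last denoising step (see `previous_timestep`
                // below).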
- i as f64 / inference_steps as f64) + .collect::<Vec<_>>(); + Ok(Self { + init_alpha_cumprod, + init_noise_sigma: 1.0, + timesteps, + config, + }) + } + + pub fn timesteps(&self) -> &[f64] { + &self.timesteps + } + + fn alpha_cumprod(&self, t: f64) -> f64 { + let scaler = self.config.scaler; + let s = self.config.s; + let t = if scaler > 1. { + 1. - (1. - t).powf(scaler) + } else if scaler < 1. { + t.powf(scaler) + } else { + t + }; + let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5) + .cos() + .powi(2) + / self.init_alpha_cumprod; + alpha_cumprod.clamp(0.0001, 0.9999) + } + + fn previous_timestep(&self, ts: f64) -> f64 { + let index = self + .timesteps + .iter() + .enumerate() + .map(|(idx, v)| (idx, (v - ts).abs())) + .min_by(|x, y| x.1.total_cmp(&y.1)) + .unwrap() + .0; + self.timesteps[index + 1] + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> { + let prev_t = self.previous_timestep(ts); + + let alpha_cumprod = self.alpha_cumprod(ts); + let alpha_cumprod_prev = self.alpha_cumprod(prev_t); + let alpha = alpha_cumprod / alpha_cumprod_prev; + + let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?; + let mu = (mu * (1. / alpha).sqrt())?; + + let std_noise = mu.randn_like(0., 1.)?; + let std = + std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt(); + if prev_t == 0. { + Ok(mu) + } else { + mu + std + } + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs new file mode 100644 index 00000000..64a48c8a --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -0,0 +1,396 @@ +use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm}; +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct ResBlockStageB { + depthwise: candle_nn::Conv2d, + norm: WLayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_grn: GlobalResponseNorm, + channelwise_lin2: candle_nn::Linear, +} + +impl ResBlockStageB { + pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + groups: c, + padding: ksize / 2, + ..Default::default() + }; + let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?; + let norm = WLayerNorm::new(c)?; + let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?; + let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?; + let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; + Ok(Self { + depthwise, + norm, + channelwise_lin1, + channelwise_grn, + channelwise_lin2, + }) + } + + pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> { + let x_res = xs; + let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?; + let xs = match x_skip { + None => xs.clone(), + Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?, + }; + let xs = xs + .permute((0, 2, 3, 1))? + .contiguous()? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_grn)? + .apply(&self.channelwise_lin2)? 
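            // The permutes bracket the channel-wise MLP: NCHW -> NHWC so the
            // linear layers and the response norm act on the trailing channel
            // axis, then back to NCHW below for the residual add.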
+ .permute((0, 3, 1, 2))?; + xs + x_res + } +} + +#[derive(Debug)] +struct SubBlock { + res_block: ResBlockStageB, + ts_block: TimestepBlock, + attn_block: Option<AttnBlock>, +} + +#[derive(Debug)] +struct DownBlock { + layer_norm: Option<WLayerNorm>, + conv: Option<candle_nn::Conv2d>, + sub_blocks: Vec<SubBlock>, +} + +#[derive(Debug)] +struct UpBlock { + sub_blocks: Vec<SubBlock>, + layer_norm: Option<WLayerNorm>, + conv: Option<candle_nn::ConvTranspose2d>, +} + +#[derive(Debug)] +pub struct WDiffNeXt { + clip_mapper: candle_nn::Linear, + effnet_mappers: Vec<Option<candle_nn::Conv2d>>, + seq_norm: LayerNormNoWeights, + embedding_conv: candle_nn::Conv2d, + embedding_ln: WLayerNorm, + down_blocks: Vec<DownBlock>, + up_blocks: Vec<UpBlock>, + clf_ln: WLayerNorm, + clf_conv: candle_nn::Conv2d, + c_r: usize, + patch_size: usize, +} + +impl WDiffNeXt { + #[allow(clippy::too_many_arguments)] + pub fn new( + c_in: usize, + c_out: usize, + c_r: usize, + c_cond: usize, + clip_embd: usize, + patch_size: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280]; + const BLOCKS: [usize; 4] = [4, 4, 14, 4]; + const NHEAD: [usize; 4] = [1, 10, 20, 20]; + const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; + const EFFNET_EMBD: usize = 16; + + let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?; + let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len()); + let vb_e = vb.pp("effnet_mappers"); + for (i, &inject) in INJECT_EFFNET.iter().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i), + )?) + } else { + None + }; + effnet_mappers.push(c) + } + for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i + INJECT_EFFNET.len()), + )?) 
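                // The second half of `effnet_mappers` belongs to the up blocks;
                // its weights are indexed after the down-block mappers, hence
                // the `i + INJECT_EFFNET.len()` offset into the VarBuilder.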
+ } else { + None + }; + effnet_mappers.push(c) + } + let seq_norm = LayerNormNoWeights::new(c_cond)?; + let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?; + let embedding_conv = candle_nn::conv2d( + c_in * patch_size * patch_size, + C_HIDDEN[0], + 1, + Default::default(), + vb.pp("embedding.1"), + )?; + + let mut down_blocks = Vec::with_capacity(C_HIDDEN.len()); + for (i, &c_hidden) in C_HIDDEN.iter().enumerate() { + let vb = vb.pp("down_blocks").pp(i); + let (layer_norm, conv, start_layer_i) = if i > 0 { + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; + let cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?; + (Some(layer_norm), Some(conv), 1) + } else { + (None, None, 0) + }; + let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); + let mut layer_i = start_layer_i; + for _j in 0..BLOCKS[i] { + let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; + let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?; + layer_i += 1; + let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; + layer_i += 1; + let attn_block = if i == 0 { + None + } else { + let attn_block = AttnBlock::new( + c_hidden, + c_cond, + NHEAD[i], + true, + use_flash_attn, + vb.pp(layer_i), + )?; + layer_i += 1; + Some(attn_block) + }; + let sub_block = SubBlock { + res_block, + ts_block, + attn_block, + }; + sub_blocks.push(sub_block) + } + let down_block = DownBlock { + layer_norm, + conv, + sub_blocks, + }; + down_blocks.push(down_block) + } + + let mut up_blocks = Vec::with_capacity(C_HIDDEN.len()); + for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() { + let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i); + let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); + let mut layer_i = 0; + for j in 0..BLOCKS[i] { + let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; + let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 { + c_hidden + c_skip + } else { + c_skip + }; + let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?; + layer_i += 1; + let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; + layer_i += 1; + let attn_block = if i == 0 { + None + } else { + let attn_block = AttnBlock::new( + c_hidden, + c_cond, + NHEAD[i], + true, + use_flash_attn, + vb.pp(layer_i), + )?; + layer_i += 1; + Some(attn_block) + }; + let sub_block = SubBlock { + res_block, + ts_block, + attn_block, + }; + sub_blocks.push(sub_block) + } + let (layer_norm, conv) = if i > 0 { + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; + let cfg = candle_nn::ConvTranspose2dConfig { + stride: 2, + ..Default::default() + }; + let conv = candle_nn::conv_transpose2d( + c_hidden, + C_HIDDEN[i - 1], + 2, + cfg, + vb.pp(layer_i).pp(1), + )?; + (Some(layer_norm), Some(conv)) + } else { + (None, None) + }; + let up_block = UpBlock { + layer_norm, + conv, + sub_blocks, + }; + up_blocks.push(up_block) + } + + let clf_ln = WLayerNorm::new(C_HIDDEN[0])?; + let clf_conv = candle_nn::conv2d( + C_HIDDEN[0], + 2 * c_out * patch_size * patch_size, + 1, + Default::default(), + vb.pp("clf.1"), + )?; + Ok(Self { + clip_mapper, + effnet_mappers, + seq_norm, + embedding_conv, + embedding_ln, + down_blocks, + up_blocks, + clf_ln, + clf_conv, + c_r, + patch_size, + }) + } + + fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> { + const MAX_POSITIONS: usize = 10000; + let r = (r * MAX_POSITIONS as f64)?; + let half_dim = self.c_r / 2; + let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as 
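        // The usual transformer sinusoidal embedding: r is scaled up to
        // [0, MAX_POSITIONS], multiplied by half_dim log-spaced frequencies
        // exp(-k * ln(MAX_POSITIONS) / (half_dim - 1)), and the sin/cos of each
        // product are concatenated (with a zero pad when c_r is odd).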
f64; + let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)? + * -emb)? + .exp()?; + let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?; + let emb = if self.c_r % 2 == 1 { + emb.pad_with_zeros(D::Minus1, 0, 1)? + } else { + emb + }; + emb.to_dtype(r.dtype()) + } + + fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> { + clip.apply(&self.clip_mapper)?.apply(&self.seq_norm) + } + + pub fn forward( + &self, + xs: &Tensor, + r: &Tensor, + effnet: &Tensor, + clip: Option<&Tensor>, + ) -> Result<Tensor> { + const EPS: f64 = 1e-3; + + let r_embed = self.gen_r_embedding(r)?; + let clip = match clip { + None => None, + Some(clip) => Some(self.gen_c_embeddings(clip)?), + }; + let x_in = xs; + + let mut xs = xs + .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))? + .apply(&self.embedding_conv)? + .apply(&self.embedding_ln)?; + + let mut level_outputs = Vec::new(); + for (i, down_block) in self.down_blocks.iter().enumerate() { + if let Some(ln) = &down_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &down_block.conv { + xs = xs.apply(conv)? + } + let skip = match &self.effnet_mappers[i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for block in down_block.sub_blocks.iter() { + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + level_outputs.push(xs.clone()) + } + level_outputs.reverse(); + let mut xs = level_outputs[0].clone(); + + for (i, up_block) in self.up_blocks.iter().enumerate() { + let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for (j, block) in up_block.sub_blocks.iter().enumerate() { + let skip = if j == 0 && i > 0 { + Some(&level_outputs[i]) + } else { + None + }; + let skip = match (skip, effnet_c.as_ref()) { + (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?), + (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()), + (None, None) => None, + }; + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + if let Some(ln) = &up_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &up_block.conv { + xs = xs.apply(conv)? + } + } + + let ab = xs + .apply(&self.clf_ln)? + .apply(&self.clf_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))? + .chunk(2, 1)?; + let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?; + (x_in - &ab[0])? 
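        // i.e. return the normalised residual (x_in - a) / b, where `a` = ab[0]
        // is a predicted offset and `b` a per-pixel scale that the sigmoid
        // above keeps inside (EPS, 1 - EPS).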
/ b + } +} diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs new file mode 100644 index 00000000..7b076f06 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -0,0 +1,6 @@ +pub mod attention_processor; +pub mod common; +pub mod ddpm; +pub mod diffnext; +pub mod paella_vq; +pub mod prior; diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs new file mode 100644 index 00000000..4a69cca0 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -0,0 +1,211 @@ +use super::common::LayerNormNoWeights; +use candle::{Module, Result, Tensor}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct MixingResidualBlock { + norm1: LayerNormNoWeights, + depthwise_conv: candle_nn::Conv2d, + norm2: LayerNormNoWeights, + channelwise_lin1: candle_nn::Linear, + channelwise_lin2: candle_nn::Linear, + gammas: Vec<f32>, +} + +impl MixingResidualBlock { + pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { + let norm1 = LayerNormNoWeights::new(inp)?; + let norm2 = LayerNormNoWeights::new(inp)?; + let cfg = candle_nn::Conv2dConfig { + groups: inp, + ..Default::default() + }; + let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?; + let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?; + let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?; + let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?; + Ok(Self { + norm1, + depthwise_conv, + norm2, + channelwise_lin1, + channelwise_lin2, + gammas, + }) + } +} + +impl Module for MixingResidualBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mods = &self.gammas; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm1)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[0] as f64, mods[1] as f64)?; + let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?; + let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm2)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[3] as f64, mods[4] as f64)?; + let x_temp = x_temp + .permute((0, 2, 3, 1))? + .contiguous()? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_lin2)? 
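            // Back to NCHW below. The six `gammas` are learned modulation
            // scalars: (1 + g0, g1) and (1 + g3, g4) form affine pairs around
            // the two norms, while g2 and g5 scale the depthwise and
            // channel-wise branches before their residual adds.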
+ .permute((0, 3, 1, 2))?; + xs + x_temp * mods[5] as f64 + } +} + +#[derive(Debug)] +pub struct PaellaVQ { + in_block_conv: candle_nn::Conv2d, + out_block_conv: candle_nn::Conv2d, + down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>, + down_blocks_conv: candle_nn::Conv2d, + down_blocks_bn: candle_nn::BatchNorm, + up_blocks_conv: candle_nn::Conv2d, + up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>, +} + +impl PaellaVQ { + pub fn new(vb: VarBuilder) -> Result<Self> { + const IN_CHANNELS: usize = 3; + const OUT_CHANNELS: usize = 3; + const LATENT_CHANNELS: usize = 4; + const EMBED_DIM: usize = 384; + const BOTTLENECK_BLOCKS: usize = 12; + const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM]; + + let in_block_conv = candle_nn::conv2d( + IN_CHANNELS * 4, + C_LEVELS[0], + 1, + Default::default(), + vb.pp("in_block.1"), + )?; + let out_block_conv = candle_nn::conv2d( + C_LEVELS[0], + OUT_CHANNELS * 4, + 1, + Default::default(), + vb.pp("out_block.0"), + )?; + + let mut down_blocks = Vec::new(); + let vb_d = vb.pp("down_blocks"); + let mut d_idx = 0; + for (i, &c_level) in C_LEVELS.iter().enumerate() { + let conv_block = if i > 0 { + let cfg = candle_nn::Conv2dConfig { + padding: 1, + stride: 2, + ..Default::default() + }; + let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?; + d_idx += 1; + Some(block) + } else { + None + }; + let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?; + d_idx += 1; + down_blocks.push((conv_block, res_block)) + } + let vb_d = vb_d.pp(d_idx); + let down_blocks_conv = candle_nn::conv2d_no_bias( + C_LEVELS[1], + LATENT_CHANNELS, + 1, + Default::default(), + vb_d.pp(0), + )?; + let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?; + + let mut up_blocks = Vec::new(); + let vb_u = vb.pp("up_blocks"); + let mut u_idx = 0; + let up_blocks_conv = candle_nn::conv2d( + LATENT_CHANNELS, + C_LEVELS[1], + 1, + Default::default(), + vb_u.pp(u_idx).pp(0), + )?; + u_idx += 1; + for (i, &c_level) in C_LEVELS.iter().rev().enumerate() { + let mut res_blocks = Vec::new(); + let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 }; + for _j in 0..n_bottleneck_blocks { + let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?; + u_idx += 1; + res_blocks.push(res_block) + } + let conv_block = if i < C_LEVELS.len() - 1 { + let cfg = candle_nn::ConvTranspose2dConfig { + padding: 1, + stride: 2, + ..Default::default() + }; + let block = candle_nn::conv_transpose2d( + c_level, + C_LEVELS[C_LEVELS.len() - i - 2], + 4, + cfg, + vb_u.pp(u_idx), + )?; + u_idx += 1; + Some(block) + } else { + None + }; + up_blocks.push((res_blocks, conv_block)) + } + Ok(Self { + in_block_conv, + down_blocks, + down_blocks_conv, + down_blocks_bn, + up_blocks, + up_blocks_conv, + out_block_conv, + }) + } + + pub fn encode(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?; + for down_block in self.down_blocks.iter() { + if let Some(conv) = &down_block.0 { + xs = xs.apply(conv)? + } + xs = xs.apply(&down_block.1)? + } + xs.apply(&self.down_blocks_conv)? + .apply(&self.down_blocks_bn) + } + + pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + // TODO: quantizer if we want to support `force_not_quantize=False`. 
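        // Decoding mirrors `encode`: the 4-channel latent is mapped back up to
        // EMBED_DIM, run through the mixing blocks (BOTTLENECK_BLOCKS of them
        // at the bottleneck, one after the stride-2 transposed conv), and the
        // final conv's 12 output channels are pixel-shuffled back into a
        // 3-channel image.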
+ let mut xs = xs.apply(&self.up_blocks_conv)?; + for up_block in self.up_blocks.iter() { + for b in up_block.0.iter() { + xs = xs.apply(b)?; + } + if let Some(conv) = &up_block.1 { + xs = xs.apply(conv)? + } + } + xs.apply(&self.out_block_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2)) + } +} + +impl Module for PaellaVQ { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.decode(&self.encode(xs)?) + } +} diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs new file mode 100644 index 00000000..97ccf0e2 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/prior.rs @@ -0,0 +1,103 @@ +use super::common::{AttnBlock, ResBlock, TimestepBlock}; +use candle::{DType, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct Block { + res_block: ResBlock, + ts_block: TimestepBlock, + attn_block: AttnBlock, +} + +#[derive(Debug)] +pub struct WPrior { + projection: candle_nn::Conv2d, + cond_mapper_lin1: candle_nn::Linear, + cond_mapper_lin2: candle_nn::Linear, + blocks: Vec<Block>, + out_ln: super::common::WLayerNorm, + out_conv: candle_nn::Conv2d, + c_r: usize, +} + +impl WPrior { + #[allow(clippy::too_many_arguments)] + pub fn new( + c_in: usize, + c: usize, + c_cond: usize, + c_r: usize, + depth: usize, + nhead: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?; + let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?; + let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?; + let out_ln = super::common::WLayerNorm::new(c)?; + let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?; + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?; + let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?; + let attn_block = AttnBlock::new( + c, + c, + nhead, + true, + use_flash_attn, + vb.pp(format!("blocks.{}", 3 * index + 2)), + )?; + blocks.push(Block { + res_block, + ts_block, + attn_block, + }) + } + Ok(Self { + projection, + cond_mapper_lin1, + cond_mapper_lin2, + blocks, + out_ln, + out_conv, + c_r, + }) + } + + pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> { + const MAX_POSITIONS: usize = 10000; + let r = (r * MAX_POSITIONS as f64)?; + let half_dim = self.c_r / 2; + let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64; + let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)? + * -emb)? + .exp()?; + let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?; + let emb = if self.c_r % 2 == 1 { + emb.pad_with_zeros(D::Minus1, 0, 1)? + } else { + emb + }; + emb.to_dtype(r.dtype()) + } + + pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> { + let x_in = xs; + let mut xs = xs.apply(&self.projection)?; + let c_embed = c + .apply(&self.cond_mapper_lin1)? + .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))? + .apply(&self.cond_mapper_lin2)?; + let r_embed = self.gen_r_embedding(r)?; + for block in self.blocks.iter() { + xs = block.res_block.forward(&xs, None)?; + xs = block.ts_block.forward(&xs, &r_embed)?; + xs = block.attn_block.forward(&xs, &c_embed)?; + } + let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?; + (x_in - &ab[0])? 
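        // Same pattern as WDiffNeXt's classifier head: the conv output is
        // chunked into an offset and a scale, and the prior returns
        // (x_in - a) / (|b - 1| + 1e-5).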
/ ((&ab[1] - 1.)?.abs()? + 1e-5) + } +} diff --git a/candle-examples/src/object_detection.rs b/candle-transformers/src/object_detection.rs index c7c60136..ce579316 100644 --- a/candle-examples/src/object_detection.rs +++ b/candle-transformers/src/object_detection.rs @@ -1,12 +1,12 @@ /// A bounding box around an object. #[derive(Debug, Clone)] -pub struct Bbox { +pub struct Bbox<D> { pub xmin: f32, pub ymin: f32, pub xmax: f32, pub ymax: f32, pub confidence: f32, - pub keypoints: Vec<KeyPoint>, + pub data: D, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -17,7 +17,7 @@ pub struct KeyPoint { } /// Intersection over union of two bounding boxes. -pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 { +pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 { let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); let i_xmin = b1.xmin.max(b2.xmin); @@ -28,7 +28,7 @@ pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 { i_area / (b1_area + b2_area - i_area) } -pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) { +pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) { // Perform non-maximum suppression. for bboxes_for_class in bboxes.iter_mut() { bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs new file mode 100644 index 00000000..76f994d0 --- /dev/null +++ b/candle-transformers/tests/generation_tests.rs @@ -0,0 +1,29 @@ +use candle::{Device, Result, Tensor}; +use candle_transformers::generation::LogitsProcessor; + +#[test] +fn sample_with_zero_temperature() -> Result<()> { + let mut logits_process = LogitsProcessor::new(1337, None, None); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 3); + Ok(()) +} + +#[test] +fn sample_with_temperature() -> Result<()> { + let mut logits_process = LogitsProcessor::new(42, Some(0.9), None); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 0); + Ok(()) +} + +#[test] +fn sample_with_top_p() -> Result<()> { + let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5)); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 2); + Ok(()) +} diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml new file mode 100644 index 00000000..81a043de --- /dev/null +++ b/candle-wasm-examples/bert/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "candle-wasm-example-bert" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.2.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.2.2" } +num-traits = { workspace = true } +tokenizers = { workspace = true, features = ["unstable_wasm"] } + +# App crates. 
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/bert/README.md b/candle-wasm-examples/bert/README.md
new file mode 100644
index 00000000..c34d33cc
--- /dev/null
+++ b/candle-wasm-examples/bert/README.md
@@ -0,0 +1,26 @@
+## Running BERT with Candle and WASM
+
+Here, we provide an example of how to run Bert using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/bert/bertWorker.js b/candle-wasm-examples/bert/bertWorker.js
new file mode 100644
index 00000000..fd796c2b
--- /dev/null
+++ b/candle-wasm-examples/bert/bertWorker.js
@@ -0,0 +1,77 @@
+// Load the candle BERT wasm module.
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+  const cacheName = "bert-candle-cache";
+  const cache = await caches.open(cacheName);
+  const cachedResponse = await cache.match(url);
+  if (cachedResponse) {
+    const data = await cachedResponse.arrayBuffer();
+    return new Uint8Array(data);
+  }
+  const res = await fetch(url, { cache: "force-cache" });
+  cache.put(url, res.clone());
+  return new Uint8Array(await res.arrayBuffer());
+}
+class Bert {
+  static instance = {};
+
+  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+        await Promise.all([
+          fetchArrayBuffer(weightsURL),
+          fetchArrayBuffer(tokenizerURL),
+          fetchArrayBuffer(configURL),
+        ]);
+
+      this.instance[modelID] = new Model(
+        weightsArrayU8,
+        tokenizerArrayU8,
+        configArrayU8
+      );
+    } else {
+      self.postMessage({ status: "ready", message: "Model Already Loaded" });
+    }
+    return this.instance[modelID];
+  }
+}
+
+self.addEventListener("message", async (event) => {
+  const {
+    weightsURL,
+    tokenizerURL,
+    configURL,
+    modelID,
+    sentences,
+    normalize = true,
+  } = event.data;
+  try {
+    self.postMessage({ status: "ready", message: "Starting Bert Model" });
+    const model = await Bert.getInstance(
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID
+    );
+    self.postMessage({
+      status: "embedding",
+      message: "Calculating Embeddings",
+    });
+    const output = model.get_embeddings({
+      sentences: sentences,
+      normalize_embeddings: normalize,
+    });
+
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: output.data,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+});
diff --git 
a/candle-wasm-examples/bert/build-lib.sh b/candle-wasm-examples/bert/build-lib.sh new file mode 100644 index 00000000..b0ebb182 --- /dev/null +++ b/candle-wasm-examples/bert/build-lib.sh @@ -0,0 +1,2 @@ +cargo build --target wasm32-unknown-unknown --release +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/candle-wasm-examples/bert/lib-example.html b/candle-wasm-examples/bert/lib-example.html new file mode 100644 index 00000000..d10ea1db --- /dev/null +++ b/candle-wasm-examples/bert/lib-example.html @@ -0,0 +1,368 @@ +<html> + <head> + <meta content="text/html;charset=utf-8" http-equiv="Content-Type" /> + <title>Candle Bert</title> + </head> + <body></body> +</html> + +<!DOCTYPE html> +<html> + <head> + <meta charset="UTF-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + <style> + @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap"); + html, + body { + font-family: "Source Sans 3", sans-serif; + } + </style> + <script src="https://cdn.tailwindcss.com"></script> + <script type="module" src="./code.js"></script> + <script type="module"> + import { hcl } from "https://cdn.skypack.dev/d3-color@3"; + import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3"; + import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4"; + import { + getModelInfo, + getEmbeddings, + getWikiText, + cosineSimilarity, + } from "./utils.js"; + + const bertWorker = new Worker("./bertWorker.js", { + type: "module", + }); + + const inputContainerEL = document.querySelector("#input-container"); + const textAreaEl = document.querySelector("#input-area"); + const outputAreaEl = document.querySelector("#output-area"); + const formEl = document.querySelector("#form"); + const searchInputEl = document.querySelector("#search-input"); + const formWikiEl = document.querySelector("#form-wiki"); + const searchWikiEl = document.querySelector("#search-wiki"); + const outputStatusEl = document.querySelector("#output-status"); + const modelSelectEl = document.querySelector("#model"); + + const sentencesRegex = + /(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm; + + let sentenceEmbeddings = []; + let currInputText = ""; + let isCalculating = false; + + function toggleTextArea(state) { + if (state) { + textAreaEl.hidden = false; + textAreaEl.focus(); + } else { + textAreaEl.hidden = true; + } + } + inputContainerEL.addEventListener("focus", (e) => { + toggleTextArea(true); + }); + textAreaEl.addEventListener("blur", (e) => { + toggleTextArea(false); + }); + textAreaEl.addEventListener("focusout", (e) => { + toggleTextArea(false); + if (currInputText === textAreaEl.value || isCalculating) return; + populateOutputArea(textAreaEl.value); + calculateEmbeddings(textAreaEl.value); + }); + + modelSelectEl.addEventListener("change", (e) => { + if (currInputText === "" || isCalculating) return; + populateOutputArea(textAreaEl.value); + calculateEmbeddings(textAreaEl.value); + }); + + function populateOutputArea(text) { + currInputText = text; + const sentences = text.split(sentencesRegex); + + outputAreaEl.innerHTML = ""; + for (const [id, sentence] of sentences.entries()) { + const sentenceEl = document.createElement("span"); + sentenceEl.id = `sentence-${id}`; + sentenceEl.innerText = sentence + " "; + outputAreaEl.appendChild(sentenceEl); + } + } + formEl.addEventListener("submit", async (e) => { + 
e.preventDefault(); + if (isCalculating || currInputText === "") return; + toggleInputs(true); + const modelID = modelSelectEl.value; + const { modelURL, tokenizerURL, configURL, search_prefix } = + getModelInfo(modelID); + + const text = searchInputEl.value; + const query = search_prefix + searchInputEl.value; + outputStatusEl.classList.remove("invisible"); + outputStatusEl.innerText = "Calculating embeddings for query..."; + isCalculating = true; + const out = await getEmbeddings( + bertWorker, + modelURL, + tokenizerURL, + configURL, + modelID, + [query] + ); + outputStatusEl.classList.add("invisible"); + const queryEmbeddings = out.output[0]; + // calculate cosine similarity with all sentences given the query + const distances = sentenceEmbeddings + .map((embedding, id) => ({ + id, + similarity: cosineSimilarity(queryEmbeddings, embedding), + })) + .sort((a, b) => b.similarity - a.similarity) + // getting top 10 most similar sentences + .slice(0, 10); + + const colorScale = scaleLinear() + .domain([ + distances[distances.length - 1].similarity, + distances[0].similarity, + ]) + .range([0, 1]) + .interpolate(() => interpolateReds); + outputAreaEl.querySelectorAll("span").forEach((el) => { + el.style.color = "unset"; + el.style.backgroundColor = "unset"; + }); + distances.forEach((d) => { + const el = outputAreaEl.querySelector(`#sentence-${d.id}`); + const color = colorScale(d.similarity); + const fontColor = hcl(color).l < 70 ? "white" : "black"; + el.style.color = fontColor; + el.style.backgroundColor = color; + }); + + outputAreaEl + .querySelector(`#sentence-${distances[0].id}`) + .scrollIntoView({ + behavior: "smooth", + block: "center", + inline: "nearest", + }); + + isCalculating = false; + toggleInputs(false); + }); + async function calculateEmbeddings(text) { + isCalculating = true; + toggleInputs(true); + const modelID = modelSelectEl.value; + const { modelURL, tokenizerURL, configURL, document_prefix } = + getModelInfo(modelID); + + const sentences = text.split(sentencesRegex); + const allEmbeddings = []; + outputStatusEl.classList.remove("invisible"); + for (const [id, sentence] of sentences.entries()) { + const query = document_prefix + sentence; + outputStatusEl.innerText = `Calculating embeddings: sentence ${ + id + 1 + } of ${sentences.length}`; + const embeddings = await getEmbeddings( + bertWorker, + modelURL, + tokenizerURL, + configURL, + modelID, + [query], + updateStatus + ); + allEmbeddings.push(embeddings); + } + outputStatusEl.classList.add("invisible"); + sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]); + isCalculating = false; + toggleInputs(false); + } + + function updateStatus(data) { + if ("status" in data) { + if (data.status === "loading") { + outputStatusEl.innerText = data.message; + outputStatusEl.classList.remove("invisible"); + } + } + } + function toggleInputs(state) { + const interactive = document.querySelectorAll(".interactive"); + interactive.forEach((el) => { + if (state) { + el.disabled = true; + } else { + el.disabled = false; + } + }); + } + + searchWikiEl.addEventListener("input", () => { + searchWikiEl.setCustomValidity(""); + }); + + formWikiEl.addEventListener("submit", async (e) => { + e.preventDefault(); + if ("example" in e.submitter.dataset) { + searchWikiEl.value = e.submitter.innerText; + } + const text = searchWikiEl.value; + + if (isCalculating || text === "") return; + try { + const wikiText = await getWikiText(text); + searchWikiEl.setCustomValidity(""); + textAreaEl.innerHTML = wikiText; + 
populateOutputArea(wikiText);
+          calculateEmbeddings(wikiText);
+          searchWikiEl.value = "";
+        } catch {
+          searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
+          searchWikiEl.reportValidity();
+        }
+      });
+    </script>
+  </head>
+  <body class="container max-w-4xl mx-auto p-4">
+    <main class="grid grid-cols-1 gap-5 relative">
+      <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+      <div>
+        <h1 class="text-5xl font-bold">Candle BERT</h1>
+        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+        <p class="max-w-lg">
+          Running sentence embeddings and similarity search in the browser using
+          the Bert Model written with
+          <a
+            href="https://github.com/huggingface/candle/"
+            target="_blank"
+            class="underline hover:text-blue-500 hover:no-underline"
+            >Candle
+          </a>
+          and compiled to Wasm. Embedding models are from
+          <a
+            href="https://huggingface.co/sentence-transformers/"
+            target="_blank"
+            class="underline hover:text-blue-500 hover:no-underline"
+          >
+            Sentence Transformers
+          </a>
+          and
+          <a
+            href="https://huggingface.co/intfloat/"
+            target="_blank"
+            class="underline hover:text-blue-500 hover:no-underline"
+          >
+            Liang Wang - e5 Models
+          </a>
+        </p>
+      </div>
+
+      <div>
+        <label for="model" class="font-medium block">Model Options: </label>
+        <select
+          id="model"
+          class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
+        >
+          <option value="intfloat_e5_small_v2" selected>
+            intfloat/e5-small-v2 (133 MB)
+          </option>
+          <option value="intfloat_e5_base_v2">
+            intfloat/e5-base-v2 (438 MB)
+          </option>
+          <option value="intfloat_multilingual_e5_small">
+            intfloat/multilingual-e5-small (471 MB)
+          </option>
+          <option value="sentence_transformers_all_MiniLM_L6_v2">
+            sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
+          </option>
+          <option value="sentence_transformers_all_MiniLM_L12_v2">
+            sentence-transformers/all-MiniLM-L12-v2 (133 MB)
+          </option>
+        </select>
+      </div>
+      <div>
+        <h3 class="font-medium">Examples:</h3>
+        <form
+          id="form-wiki"
+          class="flex text-xs rounded-md justify-between w-min gap-3"
+        >
+          <input type="submit" hidden />
+
+          <button data-example class="disabled:cursor-not-allowed interactive">
+            Pizza
+          </button>
+          <button data-example class="disabled:cursor-not-allowed interactive">
+            Paris
+          </button>
+          <button data-example class="disabled:cursor-not-allowed interactive">
+            Physics
+          </button>
+          <input
+            type="text"
+            id="search-wiki"
+            title="Search Wikipedia article by title"
+            class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
+            placeholder="Load Wikipedia article..."
+          />
+          <button
+            title="Search Wikipedia article and load into input"
+            class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
+          >
+            Load
+          </button>
+        </form>
+      </div>
+      <form
+        id="form"
+        class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+      >
+        <input type="submit" hidden />
+        <input
+          type="text"
+          id="search-input"
+          class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
+          placeholder="Search query here..."
+ /> + <button + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive" + > + Search + </button> + </form> + <div> + <h3 class="font-medium">Input text:</h3> + <div class="flex justify-between items-center"> + <div class="rounded-md inline text-xs"> + <span id="output-status" class="m-auto font-light invisible" + >C</span + > + </div> + </div> + <div + id="input-container" + tabindex="0" + class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative" + > + <textarea + id="input-area" + hidden + value="" + placeholder="Input text to perform semantic similarity search..." + class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible" + ></textarea> + <p id="output-area" class="grid-rows-2"> + Input text to perform semantic similarity search... + </p> + </div> + </div> + </main> + </body> +</html> diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs new file mode 100644 index 00000000..f5521abd --- /dev/null +++ b/candle-wasm-examples/bert/src/bin/m.rs @@ -0,0 +1,92 @@ +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{BertModel, Config}; +use candle_wasm_example_bert::console_log; +use tokenizers::{PaddingParams, Tokenizer}; +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +pub struct Model { + bert: BertModel, + tokenizer: Tokenizer, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> { + console_error_panic_hook::set_once(); + console_log!("loading model"); + let device = &Device::Cpu; + let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device); + let config: Config = serde_json::from_slice(&config)?; + let tokenizer = + Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; + let bert = BertModel::load(vb, &config)?; + + Ok(Self { bert, tokenizer }) + } + + pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> { + let input: Params = + serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; + let sentences = input.sentences; + let normalize_embeddings = input.normalize_embeddings; + + let device = &Device::Cpu; + if let Some(pp) = self.tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + self.tokenizer.with_padding(Some(pp)); + } + let tokens = self + .tokenizer + .encode_batch(sentences.to_vec(), true) + .map_err(|m| JsError::new(&m.to_string()))?; + + let token_ids: Vec<Tensor> = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), device) + }) + .collect::<Result<Vec<_>, _>>()?; + + let token_ids = Tensor::stack(&token_ids, 0)?; + let token_type_ids = token_ids.zeros_like()?; + console_log!("running inference on batch {:?}", token_ids.shape()); + let embeddings = self.bert.forward(&token_ids, &token_type_ids)?; + console_log!("generated embeddings {:?}", embeddings.shape()); + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; + let 
embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; + let embeddings = if normalize_embeddings { + embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)? + } else { + embeddings + }; + let embeddings_data = embeddings.to_vec2()?; + Ok(serde_wasm_bindgen::to_value(&Embeddings { + data: embeddings_data, + })?) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct Embeddings { + data: Vec<Vec<f64>>, +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct Params { + sentences: Vec<String>, + normalize_embeddings: bool, +} +fn main() { + console_error_panic_hook::set_once(); +} diff --git a/candle-wasm-examples/bert/src/lib.rs b/candle-wasm-examples/bert/src/lib.rs new file mode 100644 index 00000000..1e3657be --- /dev/null +++ b/candle-wasm-examples/bert/src/lib.rs @@ -0,0 +1,20 @@ +use candle_transformers::models::bert; +use wasm_bindgen::prelude::*; + +pub use bert::{BertModel, Config, DTYPE}; +pub use tokenizers::{PaddingParams, Tokenizer}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string())) +} diff --git a/candle-wasm-examples/bert/utils.js b/candle-wasm-examples/bert/utils.js new file mode 100644 index 00000000..9d8bd7bd --- /dev/null +++ b/candle-wasm-examples/bert/utils.js @@ -0,0 +1,99 @@ +export async function getEmbeddings( + worker, + weightsURL, + tokenizerURL, + configURL, + modelID, + sentences, + updateStatus = null +) { + return new Promise((resolve, reject) => { + worker.postMessage({ + weightsURL, + tokenizerURL, + configURL, + modelID, + sentences, + }); + function messageHandler(event) { + if ("error" in event.data) { + worker.removeEventListener("message", messageHandler); + reject(new Error(event.data.error)); + } + if (event.data.status === "complete") { + worker.removeEventListener("message", messageHandler); + resolve(event.data); + } + if (updateStatus) updateStatus(event.data); + } + worker.addEventListener("message", messageHandler); + }); +} + +const MODELS = { + intfloat_e5_small_v2: { + base_url: "https://huggingface.co/intfloat/e5-small-v2/resolve/main/", + search_prefix: "query: ", + document_prefix: "passage: ", + }, + intfloat_e5_base_v2: { + base_url: "https://huggingface.co/intfloat/e5-base-v2/resolve/main/", + search_prefix: "query: ", + document_prefix: "passage:", + }, + intfloat_multilingual_e5_small: { + base_url: + "https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/", + search_prefix: "query: ", + document_prefix: "passage: ", + }, + sentence_transformers_all_MiniLM_L6_v2: { + base_url: + "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/refs%2Fpr%2F21/", + search_prefix: "", + document_prefix: "", + }, + sentence_transformers_all_MiniLM_L12_v2: { + base_url: + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/refs%2Fpr%2F4/", + search_prefix: "", + document_prefix: "", + }, +}; +export function getModelInfo(id) { + return { + modelURL: MODELS[id].base_url + "model.safetensors", + configURL: MODELS[id].base_url + "config.json", + tokenizerURL: MODELS[id].base_url + "tokenizer.json", + search_prefix: MODELS[id].search_prefix, + document_prefix: MODELS[id].document_prefix, + }; +} + +export function cosineSimilarity(vec1, 
vec2) { + const dot = vec1.reduce((acc, val, i) => acc + val * vec2[i], 0); + const a = Math.sqrt(vec1.reduce((acc, val) => acc + val * val, 0)); + const b = Math.sqrt(vec2.reduce((acc, val) => acc + val * val, 0)); + return dot / (a * b); +} +export async function getWikiText(article) { + // thanks to wikipedia for the API + const URL = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${article}&explaintext=1&exsectionformat=plain&format=json&origin=*`; + return fetch(URL, { + method: "GET", + headers: { + Accept: "application/json", + }, + }) + .then((r) => r.json()) + .then((data) => { + const pages = data.query.pages; + const pageId = Object.keys(pages)[0]; + const extract = pages[pageId].extract; + if (extract === undefined || extract === "") { + throw new Error("No article found"); + } + return extract; + }) + .catch((error) => console.error("Error:", error)); +} diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 51eac694..601f5e34 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.2.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.2.1" } +candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.2.3" } +candle-transformers = { path = "../../candle-transformers", version = "0.2.3" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/llama2-c/README.md b/candle-wasm-examples/llama2-c/README.md new file mode 100644 index 00000000..0b41e064 --- /dev/null +++ b/candle-wasm-examples/llama2-c/README.md @@ -0,0 +1,47 @@ +## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples + +Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c) written in Rust using a Candle-compiled WASM binary and runtimes. + +### Pure Rust UI + +To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install) +From the `candle-wasm-examples/llama2-c` directory run: + +Download assets: + +```bash +# Model and tokenizer + +wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin +wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json + +``` + +Run hot reload server: + +```bash +trunk serve --release --public-url / --port 8080 +``` + +### Vanilla JS and WebWorkers + +To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library: + +```bash +sh build-lib.sh +``` + +This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module: + +```js +import init, { Model } from "./build/m.js"; +``` + +The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything. +Finally, you can preview the example by running a local HTTP server. For example: + +```bash +python -m http.server +``` + +Then open `http://localhost:8000/lib-example.html` in your browser. 
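+
+For reference, here is a minimal sketch of driving the worker from your own page. It follows the message protocol used by `llama2cWorker.js` in this example (see `lib-example.html` for the full version); the field values below are illustrative:
+
+```js
+const worker = new Worker("./llama2cWorker.js", { type: "module" });
+worker.postMessage({
+  weightsURL:
+    "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin",
+  modelID: "stories15M",
+  tokenizerURL: "tokenizer.json",
+  prompt: "Once upon a time",
+  temp: 0.5, // <= 0 disables temperature sampling
+  top_p: 1.0, // values outside (0, 1) disable top-p sampling
+  repeatPenalty: 1.1,
+  seed: BigInt(299792458),
+  maxSeqLen: 200,
+  command: "start", // send { command: "abort" } to stop generation
+});
+worker.addEventListener("message", ({ data }) => {
+  if (data.status === "generating") console.log(data.sentence);
+  if (data.status === "complete") console.log(data.output);
+});
+```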
diff --git a/candle-wasm-examples/llama2-c/build-lib.sh b/candle-wasm-examples/llama2-c/build-lib.sh new file mode 100644 index 00000000..b0ebb182 --- /dev/null +++ b/candle-wasm-examples/llama2-c/build-lib.sh @@ -0,0 +1,2 @@ +cargo build --target wasm32-unknown-unknown --release +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html new file mode 100644 index 00000000..86fe9811 --- /dev/null +++ b/candle-wasm-examples/llama2-c/lib-example.html @@ -0,0 +1,359 @@ +<html> + <head> + <meta content="text/html;charset=utf-8" http-equiv="Content-Type" /> + <title>Candle Llama.c Rust/WASM</title> + </head> + <body></body> +</html> + +<!DOCTYPE html> +<html> + <head> + <meta charset="UTF-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + <style> + @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap"); + html, + body { + font-family: "Source Sans 3", sans-serif; + } + code, + output, + select, + pre { + font-family: "Source Code Pro", monospace; + } + </style> + <script src="https://cdn.tailwindcss.com"></script> + <script type="module"> + // base url for audio examples + const MODELS_BASE_URL = + "https://huggingface.co/karpathy/tinyllamas/resolve/main"; + + // models base url + const MODELS = { + stories15M: { + url: "stories15M.bin", + seq_len: 256, + }, + stories42M: { + url: "stories42M.bin", + seq_len: 1024, + }, + stories110M: { + url: "stories110M.bin", + seq_len: 1024, + }, + }; + + const llamaWorker = new Worker("./llama2cWorker.js", { + type: "module", + }); + async function generateSequence(controller) { + const getValue = (id) => document.querySelector(`#${id}`).value; + const modelID = getValue("model"); + const model = MODELS[modelID]; + const weightsURL = `${MODELS_BASE_URL}/${model.url}`; + const prompt = getValue("prompt"); + const temperature = getValue("temperature"); + const topP = getValue("top-p"); + const repeatPenalty = getValue("repeat_penalty"); + const seed = getValue("seed"); + const maxSeqLen = getValue("max-seq"); + + function updateStatus(data) { + const outStatus = document.querySelector("#output-status"); + const outGen = document.querySelector("#output-generation"); + const outCounter = document.querySelector("#output-counter"); + + switch (data.status) { + case "loading": + outStatus.hidden = false; + outStatus.textContent = data.message; + outGen.hidden = true; + outCounter.hidden = true; + break; + case "generating": + const { message, prompt, sentence, tokensSec, totalTime } = data; + outStatus.hidden = true; + outCounter.hidden = false; + outGen.hidden = false; + outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace( + /\<s\>|\<\/s\>/g, + "" + )}`; + outCounter.innerHTML = `${(totalTime / 1000).toFixed( + 2 + )}s (${tokensSec.toFixed(2)} tok/s)`; + break; + case "complete": + outStatus.hidden = true; + outGen.hidden = false; + break; + } + } + + return new Promise((resolve, reject) => { + llamaWorker.postMessage({ + weightsURL, + modelID, + tokenizerURL: "tokenizer.json", + prompt, + temp: temperature, + top_p: topP, + repeatPenalty, + seed: BigInt(seed), + maxSeqLen, + command: "start", + }); + + const handleAbort = () => { + llamaWorker.postMessage({ command: "abort" }); + }; + const handleMessage = (event) => { + const { status, error, 
message, prompt, sentence } = event.data;
+          if (status) updateStatus(event.data);
+          if (error) {
+            llamaWorker.removeEventListener("message", handleMessage);
+            reject(new Error(error));
+          }
+          if (status === "complete") {
+            llamaWorker.removeEventListener("message", handleMessage);
+            resolve(event.data);
+          }
+        };
+
+        controller.signal.addEventListener("abort", handleAbort);
+        llamaWorker.addEventListener("message", handleMessage);
+      });
+    }
+
+    const form = document.querySelector("#form");
+    const prompt = document.querySelector("#prompt");
+    const clearBtn = document.querySelector("#clear-btn");
+    const runBtn = document.querySelector("#run");
+    const modelSelect = document.querySelector("#model");
+    let runController = new AbortController();
+    let isRunning = false;
+
+    modelSelect.addEventListener("change", (e) => {
+      const model = MODELS[e.target.value];
+      document.querySelector("#max-seq").max = model.seq_len;
+      document.querySelector("#max-seq").nextElementSibling.value =
+        model.seq_len;
+    });
+
+    form.addEventListener("submit", async (e) => {
+      e.preventDefault();
+      if (isRunning) {
+        stopRunning();
+      } else {
+        startRunning();
+        await generateSequence(runController);
+        stopRunning();
+      }
+    });
+
+    function startRunning() {
+      isRunning = true;
+      runBtn.textContent = "Stop";
+    }
+
+    function stopRunning() {
+      runController.abort();
+      runController = new AbortController();
+      runBtn.textContent = "Run";
+      isRunning = false;
+    }
+    clearBtn.addEventListener("click", (e) => {
+      e.preventDefault();
+      prompt.value = "";
+      clearBtn.classList.add("invisible");
+      runBtn.disabled = true;
+      stopRunning();
+    });
+    prompt.addEventListener("input", (e) => {
+      runBtn.disabled = false;
+      if (e.target.value.length > 0) {
+        clearBtn.classList.remove("invisible");
+      } else {
+        clearBtn.classList.add("invisible");
+      }
+    });
+  </script>
+</head>
+<body class="container max-w-4xl mx-auto p-4 text-gray-800">
+  <main class="grid grid-cols-1 gap-8 relative">
+    <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+    <div>
+      <h1 class="text-5xl font-bold">Candle Llama2.c</h1>
+      <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+      <p class="max-w-lg">
+        <a
+          href="https://github.com/karpathy/llama2.c"
+          target="_blank"
+          class="underline hover:text-blue-500 hover:no-underline"
+          >Llama2.c</a
+        >
+        is Andrej Karpathy's C implementation of the Llama 2 LLM. This demo
+        uses
+        <a
+          href="https://github.com/huggingface/candle/"
+          target="_blank"
+          class="underline hover:text-blue-500 hover:no-underline"
+          >Candle
+        </a>
+        to run Llama2.c in the browser using Rust/WASM.
+      </p>
+    </div>
+
+    <div>
+      <label for="model" class="font-medium">Model Options: </label>
+      <select
+        id="model"
+        class="border-2 border-gray-500 rounded-md font-light"
+      >
+        <option value="stories15M" selected>stories 15M (60.8 MB)</option>
+        <option value="stories42M">stories 42M (167 MB)</option>
+        <option value="stories110M">stories 110M (438 MB)</option>
+      </select>
+    </div>
+    <form
+      id="form"
+      class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+    >
+      <input type="submit" hidden />
+      <input
+        type="text"
+        id="prompt"
+        class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+        placeholder="Add your prompt here..."
+ value="Once upon a time" + /> + <button id="clear-btn"> + <svg + fill="none" + xmlns="http://www.w3.org/2000/svg" + width="40" + viewBox="0 0 70 40" + > + <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" /> + <path + d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1" + opacity=".5" + stroke="#1F2937" + stroke-width="2" + /> + </svg> + </button> + <button + id="run" + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed" + > + Run + </button> + </form> + <div class="grid grid-cols-3 max-w-md items-center gap-3"> + <label class="text-sm font-medium" for="max-seq">Maximum length </label> + <input + type="range" + id="max-seq" + name="max-seq" + min="1" + max="256" + step="1" + value="200" + oninput="this.nextElementSibling.value = Number(this.value)" + /> + <output + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > + 200</output + > + <label class="text-sm font-medium" for="temperature">Temperature</label> + <input + type="range" + id="temperature" + name="temperature" + min="0" + max="2" + step="0.01" + value="0.50" + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> + <output + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > + 0.50</output + > + <label class="text-sm font-medium" for="top-p">Top-p</label> + <input + type="range" + id="top-p" + name="top-p" + min="0" + max="1" + step="0.01" + value="1.00" + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> + <output + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > + 1.00</output + > + + <label class="text-sm font-medium" for="repeat_penalty" + >Repeat Penalty</label + > + + <input + type="range" + id="repeat_penalty" + name="repeat_penalty" + min="-2" + max="2" + step="0.01" + value="1.10" + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> + <output + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + >1.10</output + > + <label class="text-sm font-medium" for="seed">Seed</label> + <input + type="number" + id="seed" + name="seed" + value="299792458" + class="font-light border border-gray-700 text-right rounded-md p-2" + /> + <button + id="run" + onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))" + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm" + > + Rand + </button> + </div> + <div> + <h3 class="font-medium">Generation:</h3> + <div + class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2" + > + <div + id="output-counter" + hidden + class="ml-auto font-semibold grid-rows-1 text-sm" + ></div> + <p hidden id="output-generation" class="grid-rows-2"></p> + <span id="output-status" class="m-auto font-light" + >No output yet</span + > + </div> + </div> + </main> + </body> +</html> diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js new file mode 100644 index 00000000..abaf3401 --- /dev/null +++ b/candle-wasm-examples/llama2-c/llama2cWorker.js @@ -0,0 +1,106 @@ +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url) { + const cacheName = "llama2c-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) 
{
+    const data = await cachedResponse.arrayBuffer();
+    return new Uint8Array(data);
+  }
+  const res = await fetch(url, { cache: "force-cache" });
+  cache.put(url, res.clone());
+  return new Uint8Array(await res.arrayBuffer());
+}
+class Llama2C {
+  static instance = {};
+
+  static async getInstance(weightsURL, modelID, tokenizerURL) {
+    // load individual modelID only once
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+
+      const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
+        fetchArrayBuffer(weightsURL),
+        fetchArrayBuffer(tokenizerURL),
+      ]);
+
+      this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
+    }
+    return this.instance[modelID];
+  }
+}
+
+let controller = null;
+self.addEventListener("message", (event) => {
+  if (event.data.command === "start") {
+    controller = new AbortController();
+    generate(event.data);
+  } else if (event.data.command === "abort") {
+    controller.abort();
+  }
+});
+
+async function generate(data) {
+  const {
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    prompt,
+    temp,
+    top_p,
+    repeatPenalty,
+    seed,
+    maxSeqLen,
+  } = data;
+  try {
+    self.postMessage({ status: "loading", message: "Starting llama2.c" });
+    const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
+
+    self.postMessage({ status: "loading", message: "Initializing model" });
+    model.init_with_prompt(prompt, temp, top_p, repeatPenalty, seed);
+
+    const seq_len = model.get_seq_len();
+
+    let sentence = "";
+    let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
+    let startTime = performance.now();
+    let tokensCount = 0;
+    while (tokensCount < maxTokens) {
+      await new Promise(async (resolve) => {
+        if (controller && controller.signal.aborted) {
+          self.postMessage({
+            status: "aborted",
+            message: "Aborted",
+            output: prompt + sentence,
+          });
+          return;
+        }
+        const token = await model.next_token();
+        const tokensSec =
+          ((tokensCount + 1) / (performance.now() - startTime)) * 1000;
+
+        sentence += token;
+        self.postMessage({
+          status: "generating",
+          message: "Generating token",
+          token: token,
+          sentence: sentence,
+          totalTime: performance.now() - startTime,
+          tokensSec,
+          prompt: prompt,
+        });
+        setTimeout(resolve, 0);
+      });
+      tokensCount++;
+    }
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: prompt + sentence,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+}
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index 782026a4..ea04a810 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -46,6 +46,7 @@ pub struct App {
     status: String,
     loaded: bool,
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    top_p: std::rc::Rc<std::cell::RefCell<f64>>,
     prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
             prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
             generated: String::new(),
             current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
                 self.n_tokens = 0;
                 self.generated.clear();
                 let temp = *self.temperature.borrow();
+                let top_p = *self.top_p.borrow();
                 let prompt = self.prompt.borrow().clone();
-                console_log!("temp: {}, prompt: {}", temp, prompt);
+                console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
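+                // Forward both sampling settings (temperature and top_p) to the worker.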
ctx.link() - .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt))) + .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt))) } true } @@ -177,13 +180,21 @@ impl Component for App { fn view(&self, ctx: &Context<Self>) -> Html { use yew::TargetCast; let temperature = self.temperature.clone(); - let oninput = ctx.link().callback(move |e: yew::InputEvent| { + let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| { let input: web_sys::HtmlInputElement = e.target_unchecked_into(); if let Ok(temp) = f64::from_str(&input.value()) { *temperature.borrow_mut() = temp } Msg::Refresh }); + let top_p = self.top_p.clone(); + let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| { + let input: web_sys::HtmlInputElement = e.target_unchecked_into(); + if let Ok(top_p_input) = f64::from_str(&input.value()) { + *top_p.borrow_mut() = top_p_input + } + Msg::Refresh + }); let prompt = self.prompt.clone(); let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| { let input: web_sys::HtmlInputElement = e.target_unchecked_into(); @@ -201,9 +212,13 @@ impl Component for App { </p> </div> {"temperature \u{00a0} "} - <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/> + <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/> {format!(" \u{00a0} {}", self.temperature.borrow())} <br/ > + {"top_p \u{00a0} "} + <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/> + {format!(" \u{00a0} {}", self.top_p.borrow())} + <br/ > {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/> <br/ > { diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index d014e38a..61de9d7f 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -47,7 +47,7 @@ impl Model { tokenizer, model: weights, }); - let logits_processor = LogitsProcessor::new(299792458, None); + let logits_processor = LogitsProcessor::new(299792458, None, None); match model { Ok(inner) => Ok(Self { inner, @@ -60,11 +60,18 @@ impl Model { } #[wasm_bindgen] + pub fn get_seq_len(&mut self) -> usize { + self.inner.config.seq_len + } + + #[wasm_bindgen] pub fn init_with_prompt( &mut self, prompt: String, temp: f64, + top_p: f64, repeat_penalty: f32, + seed: u64, ) -> Result<String, JsError> { // First reset the cache. { @@ -74,13 +81,18 @@ impl Model { } } let temp = if temp <= 0. { None } else { Some(temp) }; - self.logits_processor = LogitsProcessor::new(299792458, temp); + let top_p = if top_p <= 0. || top_p >= 1. { + None + } else { + Some(top_p) + }; + self.logits_processor = LogitsProcessor::new(seed, temp, top_p); self.repeat_penalty = repeat_penalty; self.tokens.clear(); let tokens = self .inner .tokenizer - .encode(prompt.to_string(), true) + .encode(prompt, true) .map_err(|m| JsError::new(&m.to_string()))? 
.get_ids()
+            .to_vec();
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3d187fcc..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
 pub struct Model {
     pub cache: Cache,
-    config: Config,
+    pub config: Config,
     pub llama: Llama,
     pub tokenizer: Tokenizer,
 }
@@ -62,12 +62,18 @@ impl Model {
         link: &WorkerLink<Worker>,
         id: HandlerId,
         temp: f64,
+        top_p: f64,
         prompt: String,
     ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?} {prompt}");
-        let mut logits_processor = LogitsProcessor::new(299792458, temp);
+        let top_p = if top_p <= 0. || top_p >= 1.0 {
+            None
+        } else {
+            Some(top_p)
+        };
+        console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+        let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
         let mut index_pos = 0;
         let mut tokens = self
             .tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64, String),
+    Run(f64, f64, String),
 }
 
 #[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(temp, prompt) => match &mut self.model {
+            WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
                     {
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
                         }
                     }
                     let result = model
-                        .run(&self.link, id, temp, prompt)
+                        .run(&self.link, id, temp, top_p, prompt)
                         .map_err(|e| e.to_string());
                     Ok(WorkerOutput::GenerationDone(result))
                 }
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
new file mode 100644
index 00000000..46b85615
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+name = "candle-wasm-example-sam"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.3" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.2", features = ["js"] }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
new file mode 100644
index 00000000..04ff2033
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -0,0 +1,26 @@
+## Running Segment Anything Example
+
+Here, we provide an example of how to run the Segment Anything Model (SAM) using a Candle-compiled WASM binary and runtime.
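+
+The UI talks to a WebWorker that owns the model. As a minimal sketch, the message protocol accepted by `samWorker.js` in this example looks as follows (the image URL and click point are illustrative; points are normalized to the image size):
+
+```js
+const samWorker = new Worker("./samWorker.js", { type: "module" });
+// A message without `points` downloads the weights and computes the image
+// embeddings; a later message with `points` reuses them and returns a mask.
+samWorker.postMessage({
+  modelURL:
+    "https://huggingface.co/lmz/candle-sam/resolve/main/mobile_sam-tiny-vitt.safetensors",
+  modelID: "sam_mobile_tiny",
+  imageURL: "https://example.com/some-image.jpg", // illustrative
+  points: { x: 0.5, y: 0.5 }, // relative [0, 1] image coordinates
+});
+samWorker.addEventListener("message", ({ data }) => {
+  if (data.status === "complete") console.log(data.output.maskURL);
+});
+```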
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
new file mode 100644
index 00000000..5060f073
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -0,0 +1,407 @@
+<html>
+  <head>
+    <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+    <title>Candle Segment Anything Model (SAM) Rust/WASM</title>
+  </head>
+  <body></body>
+</html>
+
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <style>
+      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+      html,
+      body {
+        font-family: "Source Sans 3", sans-serif;
+      }
+    </style>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script type="module">
+      // base url for the model weights
+      const MODEL_BASEURL =
+        "https://huggingface.co/lmz/candle-sam/resolve/main/";
+
+      // available models
+      const MODELS = {
+        sam_mobile_tiny: {
+          url: "mobile_sam-tiny-vitt.safetensors",
+        },
+        sam_base: {
+          url: "sam_vit_b_01ec64.safetensors",
+        },
+      };
+      const samWorker = new Worker("./samWorker.js", { type: "module" });
+
+      async function segmentPoints(
+        modelURL, // URL to the weights file
+        modelID, // model ID
+        imageURL, // URL to the image file
+        points // {x, y} points to prompt image
+      ) {
+        return new Promise((resolve, reject) => {
+          function messageHandler(event) {
+            console.log(event.data);
+            if ("status" in event.data) {
+              updateStatus(event.data);
+            }
+            if ("error" in event.data) {
+              samWorker.removeEventListener("message", messageHandler);
+              reject(new Error(event.data.error));
+            }
+            if (event.data.status === "complete-embedding") {
+              samWorker.removeEventListener("message", messageHandler);
+              resolve();
+            }
+            if (event.data.status === "complete") {
+              samWorker.removeEventListener("message", messageHandler);
+              resolve(event.data.output);
+            }
+          }
+          samWorker.addEventListener("message", messageHandler);
+          samWorker.postMessage({
+            modelURL,
+            modelID,
+            imageURL,
+            points,
+          });
+        });
+      }
+      function updateStatus(statusMessage) {
+        statusOutput.innerText = statusMessage.message;
+      }
+
+      const clearBtn = document.querySelector("#clear-btn");
+      const canvas = document.querySelector("#canvas");
+      const mask = document.querySelector("#mask");
+      const ctxCanvas = canvas.getContext("2d");
+      const
ctxMask = mask.getContext("2d"); + const fileUpload = document.querySelector("#file-upload"); + const dropArea = document.querySelector("#drop-area"); + const dropButtons = document.querySelector("#drop-buttons"); + const imagesExamples = document.querySelector("#image-select"); + const modelSelection = document.querySelector("#model"); + const statusOutput = document.querySelector("#output-status"); + + //add event listener to file input + fileUpload.addEventListener("change", (e) => { + const target = e.target; + if (target.files.length > 0) { + const href = URL.createObjectURL(target.files[0]); + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } + }); + // add event listener to drop-area + dropArea.addEventListener("dragenter", (e) => { + e.preventDefault(); + dropArea.classList.add("border-blue-700"); + }); + dropArea.addEventListener("dragleave", (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); + }); + dropArea.addEventListener("dragover", (e) => { + e.preventDefault(); + }); + dropArea.addEventListener("drop", (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); + const url = e.dataTransfer.getData("text/uri-list"); + const files = e.dataTransfer.files; + + if (files.length > 0) { + const href = URL.createObjectURL(files[0]); + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } else if (url) { + cleanImageCanvas(); + drawImageCanvas(url); + setImageEmbeddings(url); + } + }); + + let hasImage = false; + let isSegmenting = false; + let isEmbedding = false; + let currentImageURL = ""; + //add event listener to image examples + imagesExamples.addEventListener("click", (e) => { + if (isEmbedding || isSegmenting) { + return; + } + const target = e.target; + if (target.nodeName === "IMG") { + const href = target.src; + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } + }); + //add event listener to clear button + clearBtn.addEventListener("click", () => { + cleanImageCanvas(); + }); + //add click event to canvas + canvas.addEventListener("click", async (event) => { + if (!hasImage || isEmbedding || isSegmenting) { + return; + } + const targetBox = event.target.getBoundingClientRect(); + const x = (event.clientX - targetBox.left) / targetBox.width; + const y = (event.clientY - targetBox.top) / targetBox.height; + isSegmenting = true; + const { maskURL } = await getSegmentationMask({ x, y }); + isSegmenting = false; + drawMask(maskURL); + }); + + async function getSegmentationMask(points) { + const modelID = modelSelection.value; + const modelURL = MODEL_BASEURL + MODELS[modelID].url; + const imageURL = currentImageURL; + const { maskURL } = await segmentPoints( + modelURL, + modelID, + imageURL, + points + ); + return { maskURL }; + } + async function setImageEmbeddings(imageURL) { + if (isEmbedding) { + return; + } + canvas.classList.remove("cursor-pointer"); + canvas.classList.add("cursor-wait"); + clearBtn.disabled = true; + const modelID = modelSelection.value; + const modelURL = MODEL_BASEURL + MODELS[modelID].url; + isEmbedding = true; + await segmentPoints(modelURL, modelID, imageURL); + canvas.classList.remove("cursor-wait"); + canvas.classList.add("cursor-pointer"); + clearBtn.disabled = false; + isEmbedding = false; + currentImageURL = imageURL; + } + + function cleanImageCanvas() { + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + ctxMask.clearRect(0, 0, canvas.width, canvas.height); + hasImage = false; + isEmbedding = false; + 
isSegmenting = false; + currentImageURL = ""; + clearBtn.classList.add("invisible"); + canvas.parentElement.style.height = "auto"; + dropButtons.classList.remove("invisible"); + } + function drawMask(maskURL) { + if (!maskURL) { + throw new Error("No mask URL provided"); + } + + const img = new Image(); + img.crossOrigin = "anonymous"; + + img.onload = () => { + mask.width = canvas.width; + mask.height = canvas.height; + ctxMask.drawImage(canvas, 0, 0); + ctxMask.globalCompositeOperation = "source-atop"; + ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)"; + ctxMask.fillRect(0, 0, canvas.width, canvas.height); + ctxMask.globalCompositeOperation = "destination-in"; + ctxMask.drawImage(img, 0, 0); + }; + img.src = maskURL; + } + function drawImageCanvas(imgURL) { + if (!imgURL) { + throw new Error("No image URL provided"); + } + + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + + const img = new Image(); + img.crossOrigin = "anonymous"; + + img.onload = () => { + canvas.width = img.width; + canvas.height = img.height; + ctxCanvas.drawImage(img, 0, 0); + canvas.parentElement.style.height = canvas.offsetHeight + "px"; + hasImage = true; + clearBtn.classList.remove("invisible"); + dropButtons.classList.add("invisible"); + }; + img.src = imgURL; + } + + const observer = new ResizeObserver((entries) => { + for (let entry of entries) { + if (entry.target === canvas) { + canvas.parentElement.style.height = canvas.offsetHeight + "px"; + } + } + }); + observer.observe(canvas); + </script> + </head> + <body class="container max-w-4xl mx-auto p-4"> + <main class="grid grid-cols-1 gap-8 relative"> + <span class="absolute text-5xl -ml-[1em]">🕯️</span> + <div> + <h1 class="text-5xl font-bold">Candle Segment Anything</h1> + <h2 class="text-2xl font-bold">Rust/WASM Demo</h2> + <p class="max-w-lg"> + Zero-shot image segmentation with + <a + href="https://segment-anything.com" + class="underline hover:text-blue-500 hover:no-underline" + target="_blank" + >Segment Anything Model (SAM)</a + > + and + <a + href="https://github.com/ChaoningZhang/MobileSAM" + class="underline hover:text-blue-500 hover:no-underline" + target="_blank" + >MobileSAM </a + >. It runs in the browser with a WASM runtime built with + <a + href="https://github.com/huggingface/candle/" + target="_blank" + class="underline hover:text-blue-500 hover:no-underline" + >Candle + </a> + </p> + </div> + <div> + <label for="model" class="font-medium">Models Options: </label> + <select + id="model" + class="border-2 border-gray-500 rounded-md font-light" + > + <option value="sam_mobile_tiny" selected> + Mobile SAM Tiny (40.6 MB) + </option> + <option value="sam_base">SAM Base (375 MB)</option> + </select> + </div> + <div> + <p class="text-xs italic max-w-lg"> + <b>Note:</b> + The model's first run may take a few seconds as it loads and caches + the model in the browser, and then creates the image embeddings. Any + subsequent clicks on points will be significantly faster. 
+ </p> + </div> + <div class="relative max-w-lg"> + <div class="flex justify-between items-center"> + <div class="px-2 rounded-md inline text-xs"> + <span id="output-status" class="m-auto font-light"></span> + </div> + <button + id="clear-btn" + class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible" + > + <svg + class="" + xmlns="http://www.w3.org/2000/svg" + viewBox="0 0 13 12" + height="1em" + > + <path + d="M1.6.7 12 11.1M12 .7 1.6 11.1" + stroke="#2E3036" + stroke-width="2" + /> + </svg> + Clear image + </button> + </div> + <div + id="drop-area" + class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden" + > + <div + id="drop-buttons" + class="flex flex-col items-center justify-center space-y-1 text-center relative z-10" + > + <svg + width="25" + height="25" + viewBox="0 0 25 25" + fill="none" + xmlns="http://www.w3.org/2000/svg" + > + <path + d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z" + fill="#000" + /> + </svg> + <div class="flex text-sm text-gray-600"> + <label + for="file-upload" + class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700" + > + <span>Drag and drop your image here</span> + <span class="block text-xs">or</span> + <span class="block text-xs">Click to upload</span> + </label> + </div> + <input + id="file-upload" + name="file-upload" + type="file" + class="sr-only" + /> + </div> + <canvas id="canvas" class="absolute w-full"></canvas> + <canvas + id="mask" + class="pointer-events-none absolute w-full" + ></canvas> + </div> + <div class="text-right py-2"> + <button + id="share-btn" + class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible" + > + <img + src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" + /> + </button> + </div> + </div> + <div> + <div + class="flex gap-3 items-center overflow-x-scroll" + id="image-select" + > + <h3 class="font-medium">Examples:</h3> + + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg" + class="cursor-pointer w-24 h-24 object-cover" + /> + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg" + class="cursor-pointer w-24 h-24 object-cover" + /> + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg" + class="cursor-pointer w-24 h-24 object-cover" + /> + </div> + </div> + </main> + </body> +</html> diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js new file mode 100644 index 00000000..c1a152ef --- /dev/null +++ b/candle-wasm-examples/segment-anything/samWorker.js @@ -0,0 +1,155 @@ +//load the candle SAM Model wasm module +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url, cacheModel = true) { + if (!cacheModel) + return new Uint8Array(await (await fetch(url)).arrayBuffer()); + const cacheName = "sam-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if 
(cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} +class SAMModel { + static instance = {}; + // keep current image embeddings state + static imageArrayHash = {}; + // Add a new property to hold the current modelID + static currentModelID = null; + + static async getInstance(modelURL, modelID) { + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ + status: "loading", + message: `Loading Model ${modelID}`, + }); + const weightsArrayU8 = await fetchArrayBuffer(modelURL); + this.instance[modelID] = new Model( + weightsArrayU8, + /tiny|mobile/.test(modelID) + ); + } else { + self.postMessage({ status: "loading", message: "Model Already Loaded" }); + } + // Set the current modelID to the modelID that was passed in + this.currentModelID = modelID; + return this.instance[modelID]; + } + + // Remove the modelID parameter from setImageEmbeddings + static setImageEmbeddings(imageArrayU8) { + // check if image embeddings are already set for this image and model + const imageArrayHash = this.getSimpleHash(imageArrayU8); + if ( + this.imageArrayHash[this.currentModelID] === imageArrayHash && + this.instance[this.currentModelID] + ) { + self.postMessage({ + status: "embedding", + message: "Embeddings Already Set", + }); + return; + } + this.imageArrayHash[this.currentModelID] = imageArrayHash; + this.instance[this.currentModelID].set_image_embeddings(imageArrayU8); + self.postMessage({ status: "embedding", message: "Embeddings Set" }); + } + + static getSimpleHash(imageArrayU8) { + // get simple hash of imageArrayU8 + let imageArrayHash = 0; + for (let i = 0; i < imageArrayU8.length; i += 100) { + imageArrayHash ^= imageArrayU8[i]; + } + return imageArrayHash.toString(16); + } +} + +async function createImageCanvas( + { mask_shape, mask_data }, // mask + { original_width, original_height, width, height } // original image +) { + const [_, __, shape_width, shape_height] = mask_shape; + const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask + const maskCtx = maskCanvas.getContext("2d"); + const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size + const ctx = canvas.getContext("2d"); + + const imageData = maskCtx.createImageData( + maskCanvas.width, + maskCanvas.height + ); + const data = imageData.data; + + for (let p = 0; p < data.length; p += 4) { + data[p] = 0; + data[p + 1] = 0; + data[p + 2] = 0; + data[p + 3] = mask_data[p / 4] * 255; + } + maskCtx.putImageData(imageData, 0, 0); + + let sx, sy; + if (original_height < original_width) { + sy = original_height / original_width; + sx = 1; + } else { + sy = 1; + sx = original_width / original_height; + } + ctx.drawImage( + maskCanvas, + 0, + 0, + maskCanvas.width * sx, + maskCanvas.height * sy, + 0, + 0, + original_width, + original_height + ); + + const blob = await canvas.convertToBlob(); + return URL.createObjectURL(blob); +} + +self.addEventListener("message", async (event) => { + const { modelURL, modelID, imageURL, points } = event.data; + try { + self.postMessage({ status: "loading", message: "Starting SAM" }); + const sam = await SAMModel.getInstance(modelURL, modelID); + + self.postMessage({ status: "loading", message: "Loading Image" }); + const imageArrayU8 = await fetchArrayBuffer(imageURL, false); + + self.postMessage({ status: 
"embedding", message: "Creating Embeddings" }); + SAMModel.setImageEmbeddings(imageArrayU8); + if (!points) { + // no points only do the embeddings + self.postMessage({ + status: "complete-embedding", + message: "Embeddings Complete", + }); + return; + } + + self.postMessage({ status: "segmenting", message: "Segmenting" }); + const { mask, image } = sam.mask_for_point(points.x, points.y); + const maskDataURL = await createImageCanvas(mask, image); + // Send the segment back to the main thread as JSON + self.postMessage({ + status: "complete", + message: "Segmentation Complete", + output: { maskURL: maskDataURL }, + }); + } catch (e) { + self.postMessage({ error: e }); + } +}); diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs new file mode 100644 index 00000000..5140b979 --- /dev/null +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -0,0 +1,140 @@ +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_wasm_example_sam as sam; +use wasm_bindgen::prelude::*; + +#[allow(unused)] +struct Embeddings { + original_width: u32, + original_height: u32, + width: u32, + height: u32, + data: Tensor, +} + +#[wasm_bindgen] +pub struct Model { + sam: sam::Sam, + embeddings: Option<Embeddings>, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> { + console_error_panic_hook::set_once(); + let dev = &Device::Cpu; + let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let sam = if use_tiny { + sam::Sam::new_tiny(vb)? // tiny vit_t + } else { + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + }; + Ok(Self { + sam, + embeddings: None, + }) + } + + pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> { + sam::console_log!("image data: {}", image_data.len()); + let image_data = std::io::Cursor::new(image_data); + let image = image::io::Reader::new(image_data) + .with_guessed_format()? + .decode() + .map_err(candle::Error::wrap)?; + let (original_height, original_width) = (image.height(), image.width()); + let (height, width) = (original_height, original_width); + let resize_longest = sam::IMAGE_SIZE as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + let image_t = { + let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom); + let data = img.to_rgb8().into_raw(); + Tensor::from_vec( + data, + (img.height() as usize, img.width() as usize, 3), + &Device::Cpu, + )? + .permute((2, 0, 1))? + }; + let data = self.sam.embeddings(&image_t)?; + self.embeddings = Some(Embeddings { + original_width, + original_height, + width, + height, + data, + }); + Ok(()) + } + + // x and y have to be between 0 and 1 + pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> { + if !(0. ..=1.).contains(&x) { + Err(JsError::new(&format!( + "x has to be between 0 and 1, got {x}" + )))? + } + if !(0. ..=1.).contains(&y) { + Err(JsError::new(&format!( + "y has to be between 0 and 1, got {y}" + )))? 
+        }
+        let embeddings = match &self.embeddings {
+            None => Err(JsError::new("image embeddings have not been set"))?,
+            Some(embeddings) => embeddings,
+        };
+        let (mask, iou_predictions) = self.sam.forward_for_embeddings(
+            &embeddings.data,
+            embeddings.height as usize,
+            embeddings.width as usize,
+            Some((x, y)),
+            false,
+        )?;
+        let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
+        let mask_shape = mask.dims().to_vec();
+        let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
+        let mask = Mask {
+            iou,
+            mask_shape,
+            mask_data,
+        };
+        let image = Image {
+            original_width: embeddings.original_width,
+            original_height: embeddings.original_height,
+            width: embeddings.width,
+            height: embeddings.height,
+        };
+        Ok(serde_wasm_bindgen::to_value(&MaskImage { mask, image })?)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Mask {
+    iou: f32,
+    mask_shape: Vec<usize>,
+    mask_data: Vec<u8>,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Image {
+    original_width: u32,
+    original_height: u32,
+    width: u32,
+    height: u32,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct MaskImage {
+    mask: Mask,
+    image: Image,
+}
+
+fn main() {
+    console_error_panic_hook::set_once();
+}
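Note that `mask_for_point` above serializes the whole `MaskImage` struct through `serde_wasm_bindgen`, so the Rust-side `iou` quality score is visible to JavaScript callers even though `samWorker.js` above only forwards the mask shape and image dimensions. A caller could use it to reject low-confidence masks; a minimal sketch (the `maskIfConfident` name and the 0.5 threshold are illustrative assumptions, not part of this patch):

    // `sam` is a wasm `Model` instance, as constructed in samWorker.js above.
    function maskIfConfident(sam, x, y, threshold = 0.5) {
      const { mask, image } = sam.mask_for_point(x, y);
      // `mask.iou` mirrors the `iou` field of the Rust `Mask` struct above.
      if (mask.iou < threshold) return null; // the point likely missed the object
      return { mask, image };
    }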
diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs
new file mode 100644
index 00000000..0f4f96fd
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/lib.rs
@@ -0,0 +1,19 @@
+use candle_transformers::models::segment_anything::sam;
+use wasm_bindgen::prelude::*;
+
+pub use sam::{Sam, IMAGE_SIZE};
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this uses the `log` function imported in the
+    // `wasm_bindgen` extern block above.
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 47e7e094..8f1df531 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
diff --git a/candle-wasm-examples/whisper/lib-example.html b/candle-wasm-examples/whisper/lib-example.html
index a8c49785..3cfd87a7 100644
--- a/candle-wasm-examples/whisper/lib-example.html
+++ b/candle-wasm-examples/whisper/lib-example.html
@@ -6,7 +6,7 @@
   <body></body>
 </html>
 
-<!doctype html>
+<!DOCTYPE html>
 <html>
   <head>
     <meta charset="UTF-8" />
@@ -51,18 +51,21 @@
           mel_filtersURL,
           audioURL,
         });
-        whisperWorker.addEventListener("message", (event) => {
+        function messageHandler(event) {
           console.log(event.data);
           if ("status" in event.data) {
             updateStatus(event.data);
           }
           if ("error" in event.data) {
+            whisperWorker.removeEventListener("message", messageHandler);
             reject(new Error(event.data.error));
           }
           if (event.data.status === "complete") {
+            whisperWorker.removeEventListener("message", messageHandler);
             resolve(event.data);
           }
-        });
+        }
+        whisperWorker.addEventListener("message", messageHandler);
       });
     }
@@ -141,7 +144,9 @@
           const { output } = result;
           const text = output.map((segment) => segment.dr.text).join(" ");
           console.log(text);
-          document.getElementById("output").textContent = text;
+          document.querySelector("#output-status").hidden = true;
+          document.querySelector("#output-generation").hidden = false;
+          document.querySelector("#output-generation").textContent = text;
         })
         .catch((error) => {
           console.error(error);
@@ -295,18 +300,21 @@
         <button
           id="detect"
           disabled
-          class="bg-orange-900 hover:bg-orange-800 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:cursor-not-allowed"
+          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
        >
          Transcribe Audio
        </button>
      </div>
      <div>
        <h3 class="font-medium">Transcription:</h3>
-        <div
-          id="output"
-          class="min-h-[100px] bg-slate-500 text-white p-4 rounded-md"
-        ></div>
+        <div
+          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
+        >
+          <p hidden id="output-generation" class="grid-rows-2"></p>
+          <span id="output-status" class="m-auto font-light"
+            >No transcription results yet</span
+          >
+        </div>
      </div>
    </main>
  </body>
diff --git a/candle-wasm-examples/whisper/whisperWorker.js b/candle-wasm-examples/whisper/whisperWorker.js
index 2598adde..d2ad8e0b 100644
--- a/candle-wasm-examples/whisper/whisperWorker.js
+++ b/candle-wasm-examples/whisper/whisperWorker.js
@@ -2,16 +2,17 @@
 import init, { Decoder } from "./build/m.js";
 
 async function fetchArrayBuffer(url) {
-  const res = await fetch(url, {
-    cache: "force-cache",
-    headers: {
-      "Cache-Control": "public, max-age=31536000",
-    },
-  });
-  const data = await res.arrayBuffer();
-  return new Uint8Array(data);
+  const cacheName = 
"whisper-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); } - class Whisper { static instance = {}; // Retrieve the Whisper model. When called for the first time, diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index b4daf6e6..71ef8049 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.2.1" } +candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.2.3" } num-traits = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-wasm-examples/yolo/lib-example.html b/candle-wasm-examples/yolo/lib-example.html index bab2ec13..d9f18975 100644 --- a/candle-wasm-examples/yolo/lib-example.html +++ b/candle-wasm-examples/yolo/lib-example.html @@ -6,7 +6,7 @@ <body></body> </html> -<!doctype html> +<!DOCTYPE html> <html> <head> <meta charset="UTF-8" /> @@ -145,6 +145,10 @@ } }); + document.querySelector("#clear-btn").addEventListener("click", () => { + drawImageCanvas(); + }); + function drawImageCanvas(imgURL) { const canvas = document.querySelector("#canvas"); const canvasResult = document.querySelector("#canvas-result"); @@ -153,21 +157,28 @@ .clearRect(0, 0, canvas.width, canvas.height); const ctx = canvas.getContext("2d"); ctx.clearRect(0, 0, canvas.width, canvas.height); - document.querySelector("#share-btn").hidden = true; + document.querySelector("#share-btn").classList.add("invisible"); + document.querySelector("#clear-btn").classList.add("invisible"); + document.querySelector("#detect").disabled = true; + hasImage = false; + canvas.parentElement.style.height = "auto"; - const img = new Image(); - img.crossOrigin = "anonymous"; + if (imgURL && imgURL !== "") { + const img = new Image(); + img.crossOrigin = "anonymous"; - img.onload = () => { - canvas.width = img.width; - canvas.height = img.height; - ctx.drawImage(img, 0, 0); + img.onload = () => { + canvas.width = img.width; + canvas.height = img.height; + ctx.drawImage(img, 0, 0); - canvas.parentElement.style.height = canvas.offsetHeight + "px"; - hasImage = true; - document.querySelector("#detect").disabled = false; - }; - img.src = imgURL; + canvas.parentElement.style.height = canvas.offsetHeight + "px"; + hasImage = true; + document.querySelector("#detect").disabled = false; + document.querySelector("#clear-btn").classList.remove("invisible"); + }; + img.src = imgURL; + } } async function classifyImage( @@ -188,17 +199,21 @@ confidence, iou_threshold, }); - yoloWorker.addEventListener("message", (event) => { + function handleMessage(event) { + console.log("message", event.data); if ("status" in event.data) { updateStatus(event.data.status); } if ("error" in event.data) { + yoloWorker.removeEventListener("message", handleMessage); reject(new Error(event.data.error)); } if (event.data.status === "complete") { + yoloWorker.removeEventListener("message", handleMessage); resolve(event.data); } - }); + } + 
yoloWorker.addEventListener("message", handleMessage); }); } // add event listener to detect button @@ -310,7 +325,7 @@ button.classList.add("bg-blue-950"); button.classList.remove("bg-blue-700"); button.textContent = "Predict"; - document.querySelector("#share-btn").hidden = false; + document.querySelector("#share-btn").classList.remove("invisible"); } } document.querySelector("#share-btn").addEventListener("click", () => { @@ -372,8 +387,37 @@ <option value="yolov8x_pose">yolov8x_pose (139 MB)</option> </select> </div> + <div> + <button + id="detect" + disabled + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed" + > + Predict + </button> + </div> <!-- drag and drop area --> - <div class="relative"> + <div class="relative max-w-lg"> + <div class="py-1"> + <button + id="clear-btn" + class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center ml-auto invisible" + > + <svg + class="" + xmlns="http://www.w3.org/2000/svg" + viewBox="0 0 13 12" + height="1em" + > + <path + d="M1.6.7 12 11.1M12 .7 1.6 11.1" + stroke="#2E3036" + stroke-width="2" + /> + </svg> + Clear image + </button> + </div> <div id="drop-area" class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative aspect-video w-full overflow-hidden" @@ -422,8 +466,7 @@ <div class="text-right py-2"> <button id="share-btn" - hidden - class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50" + class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible" > <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" @@ -432,7 +475,10 @@ </div> </div> <div> - <div class="flex gap-3 items-center" id="image-select"> + <div + class="flex gap-3 items-center overflow-x-scroll" + id="image-select" + > <h3 class="font-medium">Examples:</h3> <img @@ -489,15 +535,6 @@ > </div> </div> - <div> - <button - id="detect" - disabled - class="bg-blue-950 hover:bg-blue-700 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:hover:bg-blue-950" - > - Predict - </button> - </div> </main> </body> </html> diff --git a/candle-wasm-examples/yolo/yoloWorker.js b/candle-wasm-examples/yolo/yoloWorker.js index 93097372..8b5ef8b9 100644 --- a/candle-wasm-examples/yolo/yoloWorker.js +++ b/candle-wasm-examples/yolo/yoloWorker.js @@ -1,6 +1,19 @@ //load the candle yolo wasm module import init, { Model, ModelPose } from "./build/m.js"; +async function fetchArrayBuffer(url) { + const cacheName = "yolo-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} + class Yolo { static instance = {}; // Retrieve the YOLO model. When called for the first time, @@ -11,9 +24,7 @@ class Yolo { await init(); self.postMessage({ status: `loading model ${modelID}:${modelSize}` }); - const modelRes = await fetch(modelURL); - const yoloArrayBuffer = await modelRes.arrayBuffer(); - const weightsArrayU8 = new Uint8Array(yoloArrayBuffer); + const weightsArrayU8 = await fetchArrayBuffer(modelURL); if (/pose/.test(modelID)) { // if pose model, use ModelPose this.instance[modelID] = new ModelPose(weightsArrayU8, modelSize); |
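With this change all three workers (`samWorker.js`, `whisperWorker.js`, `yoloWorker.js`) share the same Cache API pattern for model weights, and each page drives its worker through `postMessage` status updates. The `segmentPoints` helper invoked by the segment-anything page is defined earlier in its `lib-example.html` and is not part of this excerpt; below is a sketch of what such a promise-based bridge can look like, assuming only the message shapes posted by `samWorker.js` above (`loading`, `embedding`, `segmenting`, `complete-embedding`, `complete`, or an `error` field). The real helper may differ in detail.

    // Module worker, since samWorker.js uses ES imports.
    const samWorker = new Worker("./samWorker.js", { type: "module" });

    function segmentPoints(modelURL, modelID, imageURL, points) {
      return new Promise((resolve, reject) => {
        samWorker.postMessage({ modelURL, modelID, imageURL, points });
        function handler(event) {
          if ("error" in event.data) {
            samWorker.removeEventListener("message", handler);
            reject(new Error(event.data.error));
          }
          if (event.data.status === "complete-embedding") {
            samWorker.removeEventListener("message", handler);
            resolve(event.data); // embeddings ready, no mask requested
          }
          if (event.data.status === "complete") {
            samWorker.removeEventListener("message", handler);
            resolve(event.data.output); // { maskURL }
          }
        }
        samWorker.addEventListener("message", handler);
      });
    }

Registering and removing the listener per call mirrors the fix this diff applies to the whisper and yolo pages, which previously stacked a new anonymous `message` listener on every transcription or detection run.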