-rw-r--r--  .gitignore  |  1
-rw-r--r--  CHANGELOG.md  |  47
-rw-r--r--  Cargo.toml  |  11
-rw-r--r--  README.md  |  117
-rw-r--r--  candle-book/Cargo.toml  |  10
-rw-r--r--  candle-book/src/SUMMARY.md  |  4
-rw-r--r--  candle-book/src/error_manage.md  |  2
-rw-r--r--  candle-book/src/guide/hello_world.md  |  40
-rw-r--r--  candle-book/src/inference/inference.md (renamed from candle-book/src/inference/README.md)  |  0
-rw-r--r--  candle-book/src/training/training.md (renamed from candle-book/src/training/README.md)  |  0
-rw-r--r--  candle-core/Cargo.toml  |  2
-rw-r--r--  candle-core/examples/cpu_benchmarks.rs  |  166
-rw-r--r--  candle-core/examples/tensor-tools.rs  |  53
-rw-r--r--  candle-core/src/accelerate.rs  |  32
-rw-r--r--  candle-core/src/backend.rs  |  1
-rw-r--r--  candle-core/src/backprop.rs  |  12
-rw-r--r--  candle-core/src/cpu/erf.rs  |  763
-rw-r--r--  candle-core/src/cpu/kernels.rs  |  95
-rw-r--r--  candle-core/src/cpu/mod.rs  |  1
-rw-r--r--  candle-core/src/cpu_backend.rs  |  265
-rw-r--r--  candle-core/src/cuda_backend.rs  |  223
-rw-r--r--  candle-core/src/cudnn.rs  |  6
-rw-r--r--  candle-core/src/dtype.rs  |  11
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  |  4
-rw-r--r--  candle-core/src/error.rs  |  6
-rw-r--r--  candle-core/src/indexer.rs  |  39
-rw-r--r--  candle-core/src/lib.rs  |  15
-rw-r--r--  candle-core/src/op.rs  |  91
-rw-r--r--  candle-core/src/quantized/k_quants.rs  |  8
-rw-r--r--  candle-core/src/quantized/mod.rs  |  2
-rw-r--r--  candle-core/src/safetensors.rs  |  8
-rw-r--r--  candle-core/src/scalar.rs  |  23
-rw-r--r--  candle-core/src/shape.rs  |  205
-rw-r--r--  candle-core/src/storage.rs  |  13
-rw-r--r--  candle-core/src/tensor.rs  |  133
-rw-r--r--  candle-core/tests/tensor_tests.rs  |  51
-rw-r--r--  candle-datasets/Cargo.toml  |  4
-rw-r--r--  candle-datasets/src/vision/mnist.rs  |  10
-rw-r--r--  candle-examples/Cargo.toml  |  22
-rw-r--r--  candle-examples/examples/bert/README.md  |  44
-rw-r--r--  candle-examples/examples/bert/main.rs  |  3
-rw-r--r--  candle-examples/examples/bigcode/README.md  |  19
-rw-r--r--  candle-examples/examples/bigcode/main.rs  |  19
-rw-r--r--  candle-examples/examples/dinov2/README.md  |  19
-rw-r--r--  candle-examples/examples/dinov2/main.rs  |  283
-rw-r--r--  candle-examples/examples/efficientnet/main.rs  |  335
-rw-r--r--  candle-examples/examples/falcon/README.md  |  3
-rw-r--r--  candle-examples/examples/falcon/main.rs  |  40
-rw-r--r--  candle-examples/examples/llama/main.rs  |  9
-rw-r--r--  candle-examples/examples/llama2-c/main.rs  |  7
-rw-r--r--  candle-examples/examples/llama_multiprocess/main.rs  |  6
-rw-r--r--  candle-examples/examples/musicgen/main.rs  |  3
-rw-r--r--  candle-examples/examples/musicgen/musicgen_model.rs  |  11
-rw-r--r--  candle-examples/examples/musicgen/t5_model.rs  |  397
-rw-r--r--  candle-examples/examples/quantized-t5/README.md  |  17
-rw-r--r--  candle-examples/examples/quantized-t5/main.rs  |  214
-rw-r--r--  candle-examples/examples/quantized/README.md  |  37
-rw-r--r--  candle-examples/examples/quantized/assets/aoc.gif  |  bin 0 -> 121923 bytes
-rw-r--r--  candle-examples/examples/quantized/main.rs  |  8
-rw-r--r--  candle-examples/examples/segment-anything/README.md  |  40
-rw-r--r--  candle-examples/examples/segment-anything/assets/sam_merged.jpg  |  bin 0 -> 160984 bytes
-rw-r--r--  candle-examples/examples/segment-anything/main.rs  |  164
-rw-r--r--  candle-examples/examples/stable-diffusion/README.md  |  63
-rw-r--r--  candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg  |  bin 0 -> 36573 bytes
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs  |  99
-rw-r--r--  candle-examples/examples/t5/README.md  |  25
-rw-r--r--  candle-examples/examples/t5/main.rs  |  314
-rw-r--r--  candle-examples/examples/whisper/README.md  |  39
-rw-r--r--  candle-examples/examples/whisper/main.rs  |  72
-rw-r--r--  candle-examples/examples/whisper/multilingual.rs  |  2
-rw-r--r--  candle-examples/examples/wuerstchen/README.md  |  27
-rw-r--r--  candle-examples/examples/wuerstchen/assets/cat.jpg  |  bin 0 -> 38638 bytes
-rw-r--r--  candle-examples/examples/wuerstchen/main.rs  |  396
-rw-r--r--  candle-examples/examples/yolo-v3/main.rs  |  6
-rw-r--r--  candle-examples/examples/yolo-v8/README.md  |  47
-rw-r--r--  candle-examples/examples/yolo-v8/assets/bike.jpg  |  bin 0 -> 182991 bytes
-rw-r--r--  candle-examples/examples/yolo-v8/assets/bike.od.jpg  |  bin 0 -> 179024 bytes
-rw-r--r--  candle-examples/examples/yolo-v8/assets/bike.pose.jpg  |  bin 0 -> 193397 bytes
-rw-r--r--  candle-examples/examples/yolo-v8/main.rs  |  14
-rw-r--r--  candle-examples/src/lib.rs  |  63
-rw-r--r--  candle-flash-attn/Cargo.toml  |  6
-rw-r--r--  candle-flash-attn/build.rs  |  45
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu  |  26
-rw-r--r--  candle-flash-attn/src/ffi.rs  |  1
-rw-r--r--  candle-flash-attn/src/lib.rs  |  128
-rw-r--r--  candle-kernels/Cargo.toml  |  2
-rw-r--r--  candle-kernels/build.rs  |  11
-rw-r--r--  candle-kernels/src/cast.cu  |  10
-rw-r--r--  candle-kernels/src/conv.cu  |  158
-rw-r--r--  candle-kernels/src/cuda_utils.cuh  |  8
-rw-r--r--  candle-kernels/src/reduce.cu  |  55
-rw-r--r--  candle-kernels/src/unary.cu  |  17
-rw-r--r--  candle-nn/Cargo.toml  |  7
-rw-r--r--  candle-nn/examples/cpu_benchmarks.rs  |  302
-rw-r--r--  candle-nn/src/activation.rs  |  13
-rw-r--r--  candle-nn/src/batch_norm.rs  |  2
-rw-r--r--  candle-nn/src/conv.rs  |  139
-rw-r--r--  candle-nn/src/embedding.rs  |  7
-rw-r--r--  candle-nn/src/group_norm.rs  |  2
-rw-r--r--  candle-nn/src/layer_norm.rs  |  12
-rw-r--r--  candle-nn/src/lib.rs  |  5
-rw-r--r--  candle-nn/src/linear.rs  |  7
-rw-r--r--  candle-nn/src/loss.rs  |  2
-rw-r--r--  candle-nn/src/ops.rs  |  154
-rw-r--r--  candle-nn/src/rnn.rs  |  4
-rw-r--r--  candle-nn/src/var_builder.rs  |  12
-rw-r--r--  candle-nn/tests/batch_norm.rs  |  4
-rw-r--r--  candle-nn/tests/ops.rs  |  10
-rw-r--r--  candle-pyo3/.gitignore  |  160
-rw-r--r--  candle-pyo3/Cargo.toml  |  5
-rw-r--r--  candle-pyo3/README.md  |  21
-rw-r--r--  candle-pyo3/py_src/candle/__init__.py  |  5
-rw-r--r--  candle-pyo3/py_src/candle/__init__.pyi  |  375
-rw-r--r--  candle-pyo3/py_src/candle/nn/__init__.py  |  5
-rw-r--r--  candle-pyo3/py_src/candle/nn/__init__.pyi  |  19
-rw-r--r--  candle-pyo3/py_src/candle/typing/__init__.py  |  16
-rw-r--r--  candle-pyo3/py_src/candle/utils/__init__.py  |  12
-rw-r--r--  candle-pyo3/py_src/candle/utils/__init__.pyi  |  70
-rw-r--r--  candle-pyo3/pyproject.toml  |  30
-rw-r--r--  candle-pyo3/quant-llama.py  |  65
-rw-r--r--  candle-pyo3/src/lib.rs  |  406
-rw-r--r--  candle-pyo3/stub.py  |  232
-rw-r--r--  candle-transformers/Cargo.toml  |  11
-rw-r--r--  candle-transformers/src/generation/mod.rs  |  79
-rw-r--r--  candle-transformers/src/lib.rs  |  1
-rw-r--r--  candle-transformers/src/models/bert.rs (renamed from candle-examples/examples/bert/model.rs)  |  0
-rw-r--r--  candle-transformers/src/models/bigcode.rs (renamed from candle-examples/examples/bigcode/model.rs)  |  0
-rw-r--r--  candle-transformers/src/models/dinov2.rs  |  279
-rw-r--r--  candle-transformers/src/models/efficientnet.rs  |  331
-rw-r--r--  candle-transformers/src/models/falcon.rs (renamed from candle-examples/examples/falcon/model.rs)  |  11
-rw-r--r--  candle-transformers/src/models/llama.rs (renamed from candle-examples/examples/llama/model.rs)  |  2
-rw-r--r--  candle-transformers/src/models/mod.rs  |  14
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs (renamed from candle-examples/examples/quantized/model.rs)  |  2
-rw-r--r--  candle-transformers/src/models/quantized_t5.rs  |  884
-rw-r--r--  candle-transformers/src/models/segment_anything/image_encoder.rs  |  483
-rw-r--r--  candle-transformers/src/models/segment_anything/mask_decoder.rs  |  239
-rw-r--r--  candle-transformers/src/models/segment_anything/mod.rs  |  100
-rw-r--r--  candle-transformers/src/models/segment_anything/prompt_encoder.rs  |  239
-rw-r--r--  candle-transformers/src/models/segment_anything/sam.rs  |  433
-rw-r--r--  candle-transformers/src/models/segment_anything/tiny_vit.rs  |  633
-rw-r--r--  candle-transformers/src/models/segment_anything/transformer.rs  |  221
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs (renamed from candle-examples/examples/stable-diffusion/attention.rs)  |  16
-rw-r--r--  candle-transformers/src/models/stable_diffusion/clip.rs (renamed from candle-examples/examples/stable-diffusion/clip.rs)  |  66
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddim.rs (renamed from candle-examples/examples/stable-diffusion/ddim.rs)  |  17
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddpm.rs  |  205
-rw-r--r--  candle-transformers/src/models/stable_diffusion/embeddings.rs (renamed from candle-examples/examples/stable-diffusion/embeddings.rs)  |  8
-rw-r--r--  candle-transformers/src/models/stable_diffusion/mod.rs (renamed from candle-examples/examples/stable-diffusion/stable_diffusion.rs)  |  22
-rw-r--r--  candle-transformers/src/models/stable_diffusion/resnet.rs (renamed from candle-examples/examples/stable-diffusion/resnet.rs)  |  2
-rw-r--r--  candle-transformers/src/models/stable_diffusion/schedulers.rs (renamed from candle-examples/examples/stable-diffusion/schedulers.rs)  |  0
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d.rs (renamed from candle-examples/examples/stable-diffusion/unet_2d.rs)  |  6
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs (renamed from candle-examples/examples/stable-diffusion/unet_2d_blocks.rs)  |  20
-rw-r--r--  candle-transformers/src/models/stable_diffusion/utils.rs (renamed from candle-examples/examples/stable-diffusion/utils.rs)  |  0
-rw-r--r--  candle-transformers/src/models/stable_diffusion/vae.rs (renamed from candle-examples/examples/stable-diffusion/vae.rs)  |  17
-rw-r--r--  candle-transformers/src/models/t5.rs  |  841
-rw-r--r--  candle-transformers/src/models/whisper/audio.rs (renamed from candle-examples/examples/whisper/audio.rs)  |  10
-rw-r--r--  candle-transformers/src/models/whisper/mod.rs  |  26
-rw-r--r--  candle-transformers/src/models/whisper/model.rs (renamed from candle-examples/examples/whisper/model.rs)  |  4
-rw-r--r--  candle-transformers/src/models/wuerstchen/attention_processor.rs  |  118
-rw-r--r--  candle-transformers/src/models/wuerstchen/common.rs  |  203
-rw-r--r--  candle-transformers/src/models/wuerstchen/ddpm.rs  |  103
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs  |  396
-rw-r--r--  candle-transformers/src/models/wuerstchen/mod.rs  |  6
-rw-r--r--  candle-transformers/src/models/wuerstchen/paella_vq.rs  |  211
-rw-r--r--  candle-transformers/src/models/wuerstchen/prior.rs  |  103
-rw-r--r--  candle-transformers/src/object_detection.rs (renamed from candle-examples/src/object_detection.rs)  |  8
-rw-r--r--  candle-transformers/tests/generation_tests.rs  |  29
-rw-r--r--  candle-wasm-examples/bert/Cargo.toml  |  33
-rw-r--r--  candle-wasm-examples/bert/README.md  |  26
-rw-r--r--  candle-wasm-examples/bert/bertWorker.js  |  77
-rw-r--r--  candle-wasm-examples/bert/build-lib.sh  |  2
-rw-r--r--  candle-wasm-examples/bert/lib-example.html  |  368
-rw-r--r--  candle-wasm-examples/bert/src/bin/m.rs  |  92
-rw-r--r--  candle-wasm-examples/bert/src/lib.rs  |  20
-rw-r--r--  candle-wasm-examples/bert/utils.js  |  99
-rw-r--r--  candle-wasm-examples/llama2-c/Cargo.toml  |  6
-rw-r--r--  candle-wasm-examples/llama2-c/README.md  |  47
-rw-r--r--  candle-wasm-examples/llama2-c/build-lib.sh  |  2
-rw-r--r--  candle-wasm-examples/llama2-c/lib-example.html  |  359
-rw-r--r--  candle-wasm-examples/llama2-c/llama2cWorker.js  |  106
-rw-r--r--  candle-wasm-examples/llama2-c/src/app.rs  |  23
-rw-r--r--  candle-wasm-examples/llama2-c/src/bin/m.rs  |  18
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs  |  18
-rw-r--r--  candle-wasm-examples/segment-anything/Cargo.toml  |  30
-rw-r--r--  candle-wasm-examples/segment-anything/README.md  |  26
-rw-r--r--  candle-wasm-examples/segment-anything/build-lib.sh  |  2
-rw-r--r--  candle-wasm-examples/segment-anything/lib-example.html  |  407
-rw-r--r--  candle-wasm-examples/segment-anything/samWorker.js  |  155
-rw-r--r--  candle-wasm-examples/segment-anything/src/bin/m.rs  |  140
-rw-r--r--  candle-wasm-examples/segment-anything/src/lib.rs  |  19
-rw-r--r--  candle-wasm-examples/whisper/Cargo.toml  |  4
-rw-r--r--  candle-wasm-examples/whisper/lib-example.html  |  26
-rw-r--r--  candle-wasm-examples/whisper/whisperWorker.js  |  19
-rw-r--r--  candle-wasm-examples/yolo/Cargo.toml  |  4
-rw-r--r--  candle-wasm-examples/yolo/lib-example.html  |  95
-rw-r--r--  candle-wasm-examples/yolo/yoloWorker.js  |  17
195 files changed, 14760 insertions(+), 1790 deletions(-)
diff --git a/.gitignore b/.gitignore
index 2748d37e..d0a8c320 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ flamegraph.svg
*.dylib
*.so
*.swp
+*.swo
trace-*.json
candle-wasm-examples/*/build
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a52429cf..df9574d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,58 @@
# Changelog
This documents the main changes to the `candle` crate.
-## v0.2.1 - Unreleased
+## v0.2.3 - Unreleased
### Added
### Modified
+
+## v0.2.2 - 2023-09-18
+
+### Added
+- Support for `top_p` sampling
+ [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+ [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+ [839](https://github.com/huggingface/candle/pull/839).
+
+### Modified
+- Bugfix for conv2d
+ [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+ [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+ [674](https://github.com/huggingface/candle/pull/674),
+ [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+ [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+ [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+ [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+ [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment anything example
+ [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+ [770](https://github.com/huggingface/candle/pull/770).
+
+### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+ [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+ [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+ [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+ [796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30
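
For readers who have not used candle's indexing before, the `.i` calls mentioned in the v0.2.2 notes come from the `IndexOp` trait. A minimal sketch, not part of this commit, using only the pre-existing index forms (PR 842 extends the same `.i` call to accept index tensors):

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 12., &Device::Cpu)?.reshape((3, 4))?;
    // Index with a plain usize: second row, shape (4,).
    let row = t.i(1)?;
    // Mix a full range and an index: third column, shape (3,).
    let col = t.i((.., 2))?;
    println!("{row}\n{col}");
    Ok(())
}
```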
diff --git a/Cargo.toml b/Cargo.toml
index ce41876a..6cbbf00f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,17 +8,16 @@ members = [
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
+ "candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
+ "candle-wasm-examples/bert",
]
-exclude = [
- "candle-flash-attn",
- "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -33,7 +32,7 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
diff --git a/README.md b/README.md
index 140382c7..93a47082 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,9 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
## Get started
@@ -45,37 +47,54 @@ For more advanced examples, please have a look at the following section.
## Check out our examples
-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+ object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
+
+We also provide some command line based examples using state of the art models:
-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
- image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
- using self-supervision (can be used for imagenet classification, depth
- evaluation, segmentation).
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
+
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+ image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
+
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+ image generative model.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
+
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
-Run them using the following commands:
-```
-cargo run --example whisper --release
-cargo run --example llama --release
-cargo run --example falcon --release
-cargo run --example bert --release
-cargo run --example bigcode --release
-cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
-cargo run --example dinov2 --release -- --image path/to/myinput.jpg
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
+- [segment-anything](./candle-examples/examples/segment-anything/): image
+  segmentation model with prompts.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
+
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+ using self-supervision (can be used for imagenet classification, depth
+ evaluation, segmentation).
+
+Run them using commands like:
+```
cargo run --example quantized --release
-cargo run --example yolo-v3 --release -- myimage.jpg
-cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
```
In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -85,7 +104,8 @@ There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For LLaMA2, run the following command to retrieve the weight files and start a
test server:
@@ -98,6 +118,15 @@ trunk serve --release --port 8081
And then head over to
[http://localhost:8081/](http://localhost:8081/).
+<!--- ANCHOR: useful_libraries --->
+
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+
+<!--- ANCHOR_END: useful_libraries --->
+
<!--- ANCHOR: features --->
## Features
@@ -110,10 +139,21 @@ And then head over to
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+ - Language Models.
+ - LLaMA v1 and v2.
+ - Falcon.
+ - StarCoder.
+ - T5.
+ - Bert.
- Whisper (multi-lingual support).
- - Stable Diffusion.
- - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+ - Stable Diffusion v1.5, v2.1, XL v1.0.
+ - Wurstchen v2.
+ - Computer Vision Models.
+ - DINOv2.
+ - EfficientNet.
+ - yolo-v3.
+ - yolo-v8.
+ - Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
@@ -243,6 +283,35 @@ authentication token. See issue
git submodule update --init
```
+#### Compiling with flash-attention fails
+
+```
+/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
+```
+
+This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a different, supported gcc version, for example gcc-10, and specify the path to the compiler in the `CANDLE_NVCC_CCBIN` environment variable.
+```
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+```
+
+#### Linking error on windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
+error: linking with `link.exe` failed: exit code: 1181
+//very long chain of linking
+ = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
+```
+
+Make sure you link all native libraries that might be located outside the project target; e.g., to run the mdbook tests, you should run:
+
+```
+mdbook test candle-book -L .\target\debug\deps\ `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
+```
+
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index 320fb887..8ec92e87 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 1d05568a..59831af2 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -10,10 +10,10 @@
# Reference Guide
-- [Running a model](inference/README.md)
+- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
-- [Training](training/README.md)
+- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
diff --git a/candle-book/src/error_manage.md b/candle-book/src/error_manage.md
index c1a16bd9..0623e0e3 100644
--- a/candle-book/src/error_manage.md
+++ b/candle-book/src/error_manage.md
@@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
-Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+Not super pretty at the moment, but we can see the error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
diff --git a/candle-book/src/guide/hello_world.md b/candle-book/src/guide/hello_world.md
index fc4af0e1..b5b8d7b4 100644
--- a/candle-book/src/guide/hello_world.md
+++ b/candle-book/src/guide/hello_world.md
@@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
@@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
- let first = Tensor::zeros((784, 100), DType::F32, &device)?;
- let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+ let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust
# extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
@@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust
# extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
@@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
- let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
- let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear{weight, bias};
- let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
- let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
@@ -146,7 +146,7 @@ And rewrite our examples using it
```rust
# extern crate candle_core;
# extern crate candle_nn;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
@@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
- let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
- let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
- let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
- let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+ let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
- let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics:
-- [For PyTorch users](./guide/cheatsheet.md)
-- [Running existing models](./inference/README.md)
-- [Training models](./training/README.md)
+- [For PyTorch users](../guide/cheatsheet.md)
+- [Running existing models](../inference/inference.md)
+- [Training models](../training/training.md)
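
The hello-world snippets above swap `Tensor::zeros` for `Tensor::randn`, presumably so the dummy forward pass produces non-degenerate outputs. A minimal side-by-side sketch of the two constructors, for illustration only:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // All-zero weights: every matmul result is zero.
    let zeros = Tensor::zeros((2, 3), DType::F32, &device)?;
    // Normally distributed weights with mean 0 and standard deviation 1.
    let randn = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    println!("{zeros}\n{randn}");
    Ok(())
}
```

Note that `randn` takes the mean and standard deviation directly, so the explicit `DType` argument disappears: the dtype is inferred from the `0f32` literal.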
diff --git a/candle-book/src/inference/README.md b/candle-book/src/inference/inference.md
index 1b75a310..1b75a310 100644
--- a/candle-book/src/inference/README.md
+++ b/candle-book/src/inference/inference.md
diff --git a/candle-book/src/training/README.md b/candle-book/src/training/training.md
index d68a917e..d68a917e 100644
--- a/candle-book/src/training/README.md
+++ b/candle-book/src/training/training.md
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index e7213919..7af9b6fa 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
deleted file mode 100644
index 13175ac1..00000000
--- a/candle-core/examples/cpu_benchmarks.rs
+++ /dev/null
@@ -1,166 +0,0 @@
-/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use candle_core::quantized::GgmlType;
-use candle_core::{Device, Result, Tensor, D};
-use clap::{Parser, Subcommand};
-
-fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
- let dim = dim.to_index(xs.shape(), "softmax")?;
- let max = xs.max_keepdim(dim)?;
- let diff = xs.broadcast_sub(&max)?;
- let num = diff.exp()?;
- let den = num.sum_keepdim(dim)?;
- num.broadcast_div(&den)
-}
-
-trait Benchmark {
- type PreProcessData;
- type RunResult;
-
- fn preprocess() -> Result<Self::PreProcessData>;
- fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
-
- const ITERS: usize;
-}
-
-// Conv1d example as used in whisper.
-struct Conv1d;
-impl Benchmark for Conv1d {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
- let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
- Ok((inp, w))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1, 1, 1)
- }
-
- const ITERS: usize = 5;
-}
-
-// Conv2d example as used in stable-diffusion.
-struct Conv2d;
-impl Benchmark for Conv2d {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
-
- fn preprocess() -> Result<Self::PreProcessData> {
- let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
- let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
- Ok((inp, w))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1, 1, 1)
- }
-
- const ITERS: usize = 1;
-}
-
-struct Matmul;
-impl Benchmark for Matmul {
- type PreProcessData = (Tensor, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
- let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
- Ok((lhs, rhs))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.matmul(&d.1)
- }
-
- const ITERS: usize = 100;
-}
-
-// This benchmark is similar to:
-// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
-struct QMatMul;
-impl Benchmark for QMatMul {
- type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
- let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
- let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
- let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
- Ok((mm, arg))
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.forward(&d.1)
- }
-
- const ITERS: usize = 100;
-}
-
-struct Softmax;
-impl Benchmark for Softmax {
- type PreProcessData = Tensor;
- type RunResult = Tensor;
- fn preprocess() -> Result<Self::PreProcessData> {
- // Typical whisper tiny size.
- let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
- Ok(x)
- }
-
- fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- softmax(d, D::Minus1)
- }
-
- const ITERS: usize = 100;
-}
-
-fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
- use std::hint::black_box;
-
- let iters = iters.unwrap_or(B::ITERS);
- let d = B::preprocess()?;
- let start = std::time::Instant::now();
- for _iter in 0..iters {
- let _res = black_box(B::run_one(black_box(&d))?);
- }
- println!("{:?}", start.elapsed() / iters as u32);
- Ok(())
-}
-
-#[derive(Subcommand, Debug, Clone)]
-enum Task {
- Conv1d,
- Conv2d,
- Matmul,
- Qmatmul,
- Softmax,
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-pub struct Args {
- /// The benchmark to be run.
- #[command(subcommand)]
- task: Task,
-
- #[arg(long)]
- iters: Option<usize>,
-}
-
-fn main() -> Result<()> {
- let args = Args::parse();
- match args.task {
- Task::Conv1d => run::<Conv1d>(args.iters)?,
- Task::Conv2d => run::<Conv2d>(args.iters)?,
- Task::Matmul => run::<Matmul>(args.iters)?,
- Task::Softmax => run::<Softmax>(args.iters)?,
- Task::Qmatmul => run::<QMatMul>(args.iters)?,
- }
- Ok(())
-}
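
The deleted benchmark carried its own max-shifted softmax helper. Per the diffstat, the benchmarks move to `candle-nn/examples/cpu_benchmarks.rs`, and the changelog above mentions a faster softmax in that crate. A minimal sketch of the library call that replaces the hand-rolled version, assuming the `candle_nn::ops::softmax` touched by the `candle-nn/src/ops.rs` changes:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1., (4, 10), &Device::Cpu)?;
    // Numerically stable softmax over the last dimension, equivalent to
    // the deleted helper's max / sub / exp / sum sequence.
    let ys = candle_nn::ops::softmax(&xs, D::Minus1)?;
    println!("{ys}");
    Ok(())
}
```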
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 2bc1fa2e..c3459004 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
Ok(())
}
+fn run_quantize_safetensors(
+ in_file: std::path::PathBuf,
+ out_file: std::path::PathBuf,
+ q: Quantization,
+) -> Result<()> {
+ let mut out_file = std::fs::File::create(out_file)?;
+ let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+ println!("tensors: {}", tensors.len());
+
+ let quantize_fn = match q {
+ Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+ Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+ Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+ Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+ Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+ Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+ Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+ Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+ Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+ Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+ Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+ Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+ Quantization::F16 => QTensor::quantize::<half::f16>,
+ Quantization::F32 => QTensor::quantize::<f32>,
+ };
+
+ let qtensors = tensors
+ .into_par_iter()
+ .map(|(name, tensor)| {
+ println!(" quantizing {name} {tensor:?}");
+ let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+ let tensor = if should_quantize {
+ quantize_fn(&tensor)?
+ } else {
+ QTensor::quantize::<f32>(&tensor)?
+ };
+ Ok((name, tensor))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let qtensors = qtensors
+ .iter()
+ .map(|(k, v)| (k.as_str(), v))
+ .collect::<Vec<_>>();
+ gguf_file::write(&mut out_file, &[], &qtensors)?;
+ Ok(())
+}
+
fn run_quantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
+ if let Some(extension) = in_file.extension() {
+ if extension == "safetensors" {
+ return run_quantize_safetensors(in_file, out_file, q);
+ }
+ }
+
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_file)?;
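
In rough terms, the new `run_quantize_safetensors` applies a per-tensor rule: only matrices whose leading dimension fits the 256-wide quantization blocks get block-quantized, everything else stays `f32`. A standalone sketch using the same `QTensor::quantize` API shown in the hunk above (the `main` wrapper is just for illustration):

```rust
use candle_core::quantized::{k_quants, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::randn(0f32, 1., (256, 512), &Device::Cpu)?;
    // Same rule as run_quantize_safetensors: only 2-d tensors whose first
    // dimension is a multiple of 256 are block-quantized.
    let qt = if t.rank() == 2 && t.dim(0)? % 256 == 0 {
        QTensor::quantize::<k_quants::BlockQ4_0>(&t)?
    } else {
        // Everything else is stored as plain f32.
        QTensor::quantize::<f32>(&t)?
    };
    println!("{:?}", qt.shape());
    Ok(())
}
```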
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index 87e0ee8d..1cb34e19 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+ unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+ unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vs_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vd_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
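
The two-pass structure above first writes the tanh argument into `ys`, applies Accelerate's in-place `vvtanhf`/`vvtanh`, then rescales. It vectorizes the usual tanh approximation of GELU; a scalar reference, not part of the commit:

```rust
/// gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
fn gelu_scalar(x: f32) -> f32 {
    let inner = (2.0f32 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}

fn main() {
    for x in [-2.0f32, -0.5, 0.0, 0.5, 2.0] {
        println!("gelu({x}) = {}", gelu_scalar(x));
    }
}
```

The vectorized code's `sqrt(2/pi) * v * (1 + 0.044715 * v * v)` is algebraically the same inner term as `sqrt(2/pi) * (x + 0.044715 * x^3)` here, just factored to suit the element-wise loop.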
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 67a08714..03a07434 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
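
The new `upsample_nearest1d` backend hook mirrors the 2-d version for (batch, channel, length) tensors. Assuming the tensor-level wrapper follows the same shape conventions as `upsample_nearest2d`, usage would look like:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 8., &Device::Cpu)?.reshape((1, 2, 4))?;
    // Nearest-neighbour upsampling along the last dimension: 4 -> 8.
    let up = t.upsample_nearest1d(8)?;
    assert_eq!(up.dims(), &[1, 2, 8]);
    Ok(())
}
```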
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d2099df7..a2548198 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -91,13 +91,14 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
- | Op::Reduce(node, _, _)
+ | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -111,6 +112,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
+ Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@@ -262,6 +264,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
+ Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest1d",
+ })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
@@ -437,6 +442,10 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+ Op::Unary(_, UnaryOp::GeluErf) => {
+ Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+ }
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@@ -517,6 +526,7 @@ impl Tensor {
}
}
+#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
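
The `Op::Reduce` split above distinguishes reductions by differentiability: `Min`, `Max`, and `Sum` keep propagating gradients, while `ArgMin` and `ArgMax` return integer indices, a piecewise-constant function of the input with no useful gradient, so their arguments are no longer walked. A small illustration, not from the commit:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let x = Tensor::randn(0f32, 1., (2, 4), &Device::Cpu)?;
    // Index of the maximum along the last dim; the result holds u32
    // indices, so backprop now treats this node as a graph leaf.
    let idx = x.argmax(D::Minus1)?;
    println!("{idx}");
    Ok(())
}
```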
diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs
new file mode 100644
index 00000000..ca6be53f
--- /dev/null
+++ b/candle-core/src/cpu/erf.rs
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+ //! Provides functions that don't have a numerical solution and must
+ //! be solved computationally (e.g. evaluation of a polynomial)
+
+    /// evaluates a polynomial at `z` where `coeff` are the coefficients
+    /// of a polynomial of order `k`, where `k` is the length of `coeff` and the
+    /// coefficient
+    /// for the `k`th power is the `k`th element in `coeff`, e.g. `[3,-1,2]` equates to
+    /// `2z^2 - z + 3`
+ ///
+ /// # Remarks
+ ///
+ /// Returns 0 for a 0 length coefficient slice
+ pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+ let n = coeff.len();
+ if n == 0 {
+ return 0.0;
+ }
+
+ let mut sum = *coeff.last().unwrap();
+ for c in coeff[0..n - 1].iter().rev() {
+ sum = *c + z * sum;
+ }
+ sum
+ }
+}
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x >= 0.0 && x.is_infinite() {
+ 1.0
+ } else if x <= 0.0 && x.is_infinite() {
+ -1.0
+ } else if x == 0. {
+ 0.0
+ } else {
+ erf_impl(x, false)
+ }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+ if x == 0.0 {
+ 0.0
+ } else if x >= 1.0 {
+ f64::INFINITY
+ } else if x <= -1.0 {
+ f64::NEG_INFINITY
+ } else if x < 0.0 {
+ erf_inv_impl(-x, 1.0 + x, -1.0)
+ } else {
+ erf_inv_impl(x, 1.0 - x, 1.0)
+ }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x == f64::INFINITY {
+ 0.0
+ } else if x == f64::NEG_INFINITY {
+ 2.0
+ } else {
+ erf_impl(x, true)
+ }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+ if x <= 0.0 {
+ f64::INFINITY
+ } else if x >= 2.0 {
+ f64::NEG_INFINITY
+ } else if x > 1.0 {
+ erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+ } else {
+ erf_inv_impl(1.0 - x, x, 1.0)
+ }
+}
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[
+ 0.00337916709551257388990745,
+ -0.00073695653048167948530905,
+ -0.374732337392919607868241,
+ 0.0817442448733587196071743,
+ -0.0421089319936548595203468,
+ 0.0070165709512095756344528,
+ -0.00495091255982435110337458,
+ 0.000871646599037922480317225,
+];
+
+/// Polynomial coefficients for a denominator of `erf_impl`
+/// in the interval [1e-10, 0.5]
+const ERF_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.218088218087924645390535,
+ 0.412542972725442099083918,
+ -0.0841891147873106755410271,
+ 0.0655338856400241519690695,
+ -0.0120019604454941768171266,
+ 0.00408165558926174048329689,
+ -0.000615900721557769691924509,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BN: &[f64] = &[
+ -0.0361790390718262471360258,
+ 0.292251883444882683221149,
+ 0.281447041797604512774415,
+ 0.125610208862766947294894,
+ 0.0274135028268930549240776,
+ 0.00250839672168065762786937,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BD: &[f64] = &[
+ 1.0,
+ 1.8545005897903486499845,
+ 1.43575803037831418074962,
+ 0.582827658753036572454135,
+ 0.124810476932949746447682,
+ 0.0113724176546353285778481,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CN: &[f64] = &[
+ -0.0397876892611136856954425,
+ 0.153165212467878293257683,
+ 0.191260295600936245503129,
+ 0.10276327061989304213645,
+ 0.029637090615738836726027,
+ 0.0046093486780275489468812,
+ 0.000307607820348680180548455,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CD: &[f64] = &[
+ 1.0,
+ 1.95520072987627704987886,
+ 1.64762317199384860109595,
+ 0.768238607022126250082483,
+ 0.209793185936509782784315,
+ 0.0319569316899913392596356,
+ 0.00213363160895785378615014,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DN: &[f64] = &[
+ -0.0300838560557949717328341,
+ 0.0538578829844454508530552,
+ 0.0726211541651914182692959,
+ 0.0367628469888049348429018,
+ 0.00964629015572527529605267,
+ 0.00133453480075291076745275,
+ 0.778087599782504251917881e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.75967098147167528287343,
+ 1.32883571437961120556307,
+ 0.552528596508757581287907,
+ 0.133793056941332861912279,
+ 0.0179509645176280768640766,
+ 0.00104712440019937356634038,
+ -0.106640381820357337177643e-7,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_EN: &[f64] = &[
+ -0.0117907570137227847827732,
+ 0.014262132090538809896674,
+ 0.0202234435902960820020765,
+ 0.00930668299990432009042239,
+ 0.00213357802422065994322516,
+ 0.00025022987386460102395382,
+ 0.120534912219588189822126e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_ED: &[f64] = &[
+ 1.0,
+ 1.50376225203620482047419,
+ 0.965397786204462896346934,
+ 0.339265230476796681555511,
+ 0.0689740649541569716897427,
+ 0.00771060262491768307365526,
+ 0.000371421101531069302990367,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FN: &[f64] = &[
+ -0.00546954795538729307482955,
+ 0.00404190278731707110245394,
+ 0.0054963369553161170521356,
+ 0.00212616472603945399437862,
+ 0.000394984014495083900689956,
+ 0.365565477064442377259271e-4,
+ 0.135485897109932323253786e-5,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FD: &[f64] = &[
+ 1.0,
+ 1.21019697773630784832251,
+ 0.620914668221143886601045,
+ 0.173038430661142762569515,
+ 0.0276550813773432047594539,
+ 0.00240625974424309709745382,
+ 0.891811817251336577241006e-4,
+ -0.465528836283382684461025e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GN: &[f64] = &[
+ -0.00270722535905778347999196,
+ 0.0013187563425029400461378,
+ 0.00119925933261002333923989,
+ 0.00027849619811344664248235,
+ 0.267822988218331849989363e-4,
+ 0.923043672315028197865066e-6,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.814632808543141591118279,
+ 0.268901665856299542168425,
+ 0.0449877216103041118694989,
+ 0.00381759663320248459168994,
+ 0.000131571897888596914350697,
+ 0.404815359675764138445257e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HN: &[f64] = &[
+ -0.00109946720691742196814323,
+ 0.000406425442750422675169153,
+ 0.000274499489416900707787024,
+ 0.465293770646659383436343e-4,
+ 0.320955425395767463401993e-5,
+ 0.778286018145020892261936e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HD: &[f64] = &[
+ 1.0,
+ 0.588173710611846046373373,
+ 0.139363331289409746077541,
+ 0.0166329340417083678763028,
+ 0.00100023921310234908642639,
+ 0.24254837521587225125068e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_IN: &[f64] = &[
+ -0.00056907993601094962855594,
+ 0.000169498540373762264416984,
+ 0.518472354581100890120501e-4,
+ 0.382819312231928859704678e-5,
+ 0.824989931281894431781794e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_ID: &[f64] = &[
+ 1.0,
+ 0.339637250051139347430323,
+ 0.043472647870310663055044,
+ 0.00248549335224637114641629,
+ 0.535633305337152900549536e-4,
+ -0.117490944405459578783846e-12,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JN: &[f64] = &[
+ -0.000241313599483991337479091,
+ 0.574224975202501512365975e-4,
+ 0.115998962927383778460557e-4,
+ 0.581762134402593739370875e-6,
+ 0.853971555085673614607418e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JD: &[f64] = &[
+ 1.0,
+ 0.233044138299687841018015,
+ 0.0204186940546440312625597,
+ 0.000797185647564398289151125,
+ 0.117019281670172327758019e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KN: &[f64] = &[
+ -0.000146674699277760365803642,
+ 0.162666552112280519955647e-4,
+ 0.269116248509165239294897e-5,
+ 0.979584479468091935086972e-7,
+ 0.101994647625723465722285e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KD: &[f64] = &[
+ 1.0,
+ 0.165907812944847226546036,
+ 0.0103361716191505884359634,
+ 0.000286593026373868366935721,
+ 0.298401570840900340874568e-5,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LN: &[f64] = &[
+ -0.583905797629771786720406e-4,
+ 0.412510325105496173512992e-5,
+ 0.431790922420250949096906e-6,
+ 0.993365155590013193345569e-8,
+ 0.653480510020104699270084e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LD: &[f64] = &[
+ 1.0,
+ 0.105077086072039915406159,
+ 0.00414278428675475620830226,
+ 0.726338754644523769144108e-4,
+ 0.477818471047398785369849e-6,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MN: &[f64] = &[
+ -0.196457797609229579459841e-4,
+ 0.157243887666800692441195e-5,
+ 0.543902511192700878690335e-7,
+ 0.317472492369117710852685e-9,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MD: &[f64] = &[
+ 1.0,
+ 0.052803989240957632204885,
+ 0.000926876069151753290378112,
+ 0.541011723226630257077328e-5,
+ 0.535093845803642394908747e-15,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_NN: &[f64] = &[
+ -0.789224703978722689089794e-5,
+ 0.622088451660986955124162e-6,
+ 0.145728445676882396797184e-7,
+ 0.603715505542715364529243e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_ND: &[f64] = &[
+ 1.0,
+ 0.0375328846356293715248719,
+ 0.000467919535974625308126054,
+ 0.193847039275845656900547e-5,
+];
+
+// **********************************************************
+// ********** Coefficients for erf_inv_impl polynomial ******
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AN: &[f64] = &[
+ -0.000508781949658280665617,
+ -0.00836874819741736770379,
+ 0.0334806625409744615033,
+ -0.0126926147662974029034,
+ -0.0365637971411762664006,
+ 0.0219878681111168899165,
+ 0.00822687874676915743155,
+ -0.00538772965071242932965,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.970005043303290640362,
+ -1.56574558234175846809,
+ 1.56221558398423026363,
+ 0.662328840472002992063,
+ -0.71228902341542847553,
+ -0.0527396382340099713954,
+ 0.0795283687341571680018,
+ -0.00233393759374190016776,
+ 0.000886216390456424707504,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BN: &[f64] = &[
+ -0.202433508355938759655,
+ 0.105264680699391713268,
+ 8.37050328343119927838,
+ 17.6447298408374015486,
+ -18.8510648058714251895,
+ -44.6382324441786960818,
+ 17.445385985570866523,
+ 21.1294655448340526258,
+ -3.67192254707729348546,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BD: &[f64] = &[
+ 1.0,
+ 6.24264124854247537712,
+ 3.9713437953343869095,
+ -28.6608180499800029974,
+ -20.1432634680485188801,
+ 48.5609213108739935468,
+ 10.8268667355460159008,
+ -22.6436933413139721736,
+ 1.72114765761200282724,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CN: &[f64] = &[
+ -0.131102781679951906451,
+ -0.163794047193317060787,
+ 0.117030156341995252019,
+ 0.387079738972604337464,
+ 0.337785538912035898924,
+ 0.142869534408157156766,
+ 0.0290157910005329060432,
+ 0.00214558995388805277169,
+ -0.679465575181126350155e-6,
+ 0.285225331782217055858e-7,
+ -0.681149956853776992068e-9,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CD: &[f64] = &[
+ 1.0,
+ 3.46625407242567245975,
+ 5.38168345707006855425,
+ 4.77846592945843778382,
+ 2.59301921623620271374,
+ 0.848854343457902036425,
+ 0.152264338295331783612,
+ 0.01105924229346489121,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DN: &[f64] = &[
+ -0.0350353787183177984712,
+ -0.00222426529213447927281,
+ 0.0185573306514231072324,
+ 0.00950804701325919603619,
+ 0.00187123492819559223345,
+ 0.000157544617424960554631,
+ 0.460469890584317994083e-5,
+ -0.230404776911882601748e-9,
+ 0.266339227425782031962e-11,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.3653349817554063097,
+ 0.762059164553623404043,
+ 0.220091105764131249824,
+ 0.0341589143670947727934,
+ 0.00263861676657015992959,
+ 0.764675292302794483503e-4,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_EN: &[f64] = &[
+ -0.0167431005076633737133,
+ -0.00112951438745580278863,
+ 0.00105628862152492910091,
+ 0.000209386317487588078668,
+ 0.149624783758342370182e-4,
+ 0.449696789927706453732e-6,
+ 0.462596163522878599135e-8,
+ -0.281128735628831791805e-13,
+ 0.99055709973310326855e-16,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_ED: &[f64] = &[
+ 1.0,
+ 0.591429344886417493481,
+ 0.138151865749083321638,
+ 0.0160746087093676504695,
+ 0.000964011807005165528527,
+ 0.275335474764726041141e-4,
+ 0.282243172016108031869e-6,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FN: &[f64] = &[
+ -0.0024978212791898131227,
+ -0.779190719229053954292e-5,
+ 0.254723037413027451751e-4,
+ 0.162397777342510920873e-5,
+ 0.396341011304801168516e-7,
+ 0.411632831190944208473e-9,
+ 0.145596286718675035587e-11,
+ -0.116765012397184275695e-17,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FD: &[f64] = &[
+ 1.0,
+ 0.207123112214422517181,
+ 0.0169410838120975906478,
+ 0.000690538265622684595676,
+ 0.145007359818232637924e-4,
+ 0.144437756628144157666e-6,
+ 0.509761276599778486139e-9,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GN: &[f64] = &[
+ -0.000539042911019078575891,
+ -0.28398759004727721098e-6,
+ 0.899465114892291446442e-6,
+ 0.229345859265920864296e-7,
+ 0.225561444863500149219e-9,
+ 0.947846627503022684216e-12,
+ 0.135880130108924861008e-14,
+ -0.348890393399948882918e-21,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.0845746234001899436914,
+ 0.00282092984726264681981,
+ 0.468292921940894236786e-4,
+ 0.399968812193862100054e-6,
+ 0.161809290887904476097e-8,
+ 0.231558608310259605225e-11,
+];
+
+/// `erf_impl` computes the error function at `z`.
+/// If `inv` is true, `1 - erf(z)` is computed instead of `erf(z)`.
+fn erf_impl(z: f64, inv: bool) -> f64 {
+ if z < 0.0 {
+ if !inv {
+ return -erf_impl(-z, false);
+ }
+ if z < -0.5 {
+ return 2.0 - erf_impl(-z, true);
+ }
+ return 1.0 + erf_impl(-z, false);
+ }
+
+ let result = if z < 0.5 {
+ if z < 1e-10 {
+ z * 1.125 + z * 0.003379167095512573896158903121545171688
+ } else {
+ z * 1.125
+ + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
+ }
+ } else if z < 110.0 {
+ let (r, b) = if z < 0.75 {
+ (
+ evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
+ / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
+ 0.3440242112,
+ )
+ } else if z < 1.25 {
+ (
+ evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
+ / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
+ 0.419990927,
+ )
+ } else if z < 2.25 {
+ (
+ evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
+ / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
+ 0.4898625016,
+ )
+ } else if z < 3.5 {
+ (
+ evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
+ / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
+ 0.5317370892,
+ )
+ } else if z < 5.25 {
+ (
+ evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
+ / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
+ 0.5489973426,
+ )
+ } else if z < 8.0 {
+ (
+ evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
+ / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
+ 0.5571740866,
+ )
+ } else if z < 11.5 {
+ (
+ evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
+ / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
+ 0.5609807968,
+ )
+ } else if z < 17.0 {
+ (
+ evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
+ / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
+ 0.5626493692,
+ )
+ } else if z < 24.0 {
+ (
+ evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
+ / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
+ 0.5634598136,
+ )
+ } else if z < 38.0 {
+ (
+ evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
+ / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
+ 0.5638477802,
+ )
+ } else if z < 60.0 {
+ (
+ evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
+ / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
+ 0.5640528202,
+ )
+ } else if z < 85.0 {
+ (
+ evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
+ / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
+ 0.5641309023,
+ )
+ } else {
+ (
+ evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
+ / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
+ 0.5641584396,
+ )
+ };
+ let g = (-z * z).exp() / z;
+ g * b + g * r
+ } else {
+ 0.0
+ };
+
+ if inv && z >= 0.5 {
+ result
+ } else if z >= 0.5 || inv {
+ 1.0 - result
+ } else {
+ result
+ }
+}
+
+// `erf_inv_impl` computes the inverse error function, where `p`, `q`, and `s`
+// are the first, second, and third intermediate parameters respectively.
+fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
+ let result = if p <= 0.5 {
+ let y = 0.0891314744949340820313;
+ let g = p * (p + 10.0);
+ let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
+ g * y + g * r
+ } else if q >= 0.25 {
+ let y = 2.249481201171875;
+ let g = (-2.0 * q.ln()).sqrt();
+ let xs = q - 0.25;
+ let r =
+ evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
+ g / (y + r)
+ } else {
+ let x = (-q.ln()).sqrt();
+ if x < 3.0 {
+ let y = 0.807220458984375;
+ let xs = x - 1.125;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
+ y * x + r * x
+ } else if x < 6.0 {
+ let y = 0.93995571136474609375;
+ let xs = x - 3.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
+ y * x + r * x
+ } else if x < 18.0 {
+ let y = 0.98362827301025390625;
+ let xs = x - 6.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
+ y * x + r * x
+ } else if x < 44.0 {
+ let y = 0.99714565277099609375;
+ let xs = x - 18.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
+ y * x + r * x
+ } else {
+ let y = 0.99941349029541015625;
+ let xs = x - 44.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
+ y * x + r * x
+ }
+ };
+ s * result
+}
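+
+// For reference, a sketch of how the public wrappers (presumably defined
+// earlier in this file, in the usual statrs-style layout) dispatch to the two
+// implementations above; the exact special-case handling is an assumption:
+//
+//     pub fn erf(x: f64) -> f64 {
+//         if x.is_nan() { f64::NAN } else { erf_impl(x, false) }
+//     }
+//
+//     pub fn erf_inv(x: f64) -> f64 {
+//         // p = |x|, q = 1 - |x|, s = sign(x)
+//         if x < 0.0 { erf_inv_impl(-x, 1.0 + x, -1.0) } else { erf_inv_impl(x, 1.0 - x, 1.0) }
+//     }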
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 97e195ef..527646d6 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+ fn min(self, rhs: Self) -> Self;
+ fn max(self, rhs: Self) -> Self;
+
/// Dot-product of two vectors.
///
/// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x > *res {
- *res = x
- }
+ *res = (*res).max(*xs.add(i))
}
}
@@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x < *res {
- *res = x
- }
+ *res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
}
@@ -76,6 +83,16 @@ impl VecOps for f32 {
impl VecOps for half::f16 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
}
}
-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for half::bf16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for u8 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for u32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
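+
+// Illustrative (not part of this diff): routing the reductions through the
+// per-type `min`/`max` keeps `Ord::max` for the integer types and picks up
+// `f32::max` semantics for floats, which changes how NaN propagates:
+//
+//     let xs = [1f32, f32::NAN, 3.0];
+//     let mut m = 0f32;
+//     unsafe { <f32 as VecOps>::vec_reduce_max(xs.as_ptr(), &mut m, xs.len()) };
+//     assert_eq!(m, 3.0); // f32::max skips the NaN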
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs
index 9a8e6317..50afb30f 100644
--- a/candle-core/src/cpu/mod.rs
+++ b/candle-core/src/cpu/mod.rs
@@ -1,3 +1,4 @@
+pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index ed3dd3fc..4e808b34 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
+use rayon::prelude::*;
+
+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
}
// This function maps over two strided index sequences.
-fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
}
// Similar to binary_map but with vectorized variants.
-fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
}
}
+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*sz?
+ let dst_sz = self.0;
+ let (b_sz, c, src_sz) = layout.shape().dims3()?;
+ let stride = layout.stride();
+ let stride_sz = stride[2];
+ let src_index = layout.start_offset();
+ let scale_sz = src_sz as f64 / dst_sz as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+ let src_idxs = (0..dst_sz)
+ .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_sz..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_sz..];
+ let src_index = src_index + c_idx * stride[1];
+ for (idx, src_idx) in src_idxs.iter().enumerate() {
+ dst[idx] = src[src_index + src_idx * stride_sz]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
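+// For `UpsampleNearest1D` above: going from src_sz = 2 to dst_sz = 4 gives
+// scale_sz = 0.5 and src_idxs = [0, 0, 1, 1], i.e. each source element is
+// repeated twice.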
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset in 0..p.k_size {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
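+// The `Im2Col*` helpers below unroll each receptive field into one row of a
+// matrix so that the convolution reduces to a single matmul; for the 1d case
+// the shapes work out roughly as follows (illustrative, see `conv1d` further
+// down):
+//
+//     // x: (b, c_in, l)            -> col: (b, l_out, c_in * l_k)
+//     // kernel: (c_out, c_in, l_k) viewed as (c_in * l_k, c_out)
+//     // col.matmul(kernel): (b, l_out, c_out), transposed to (b, c_out, l_out)
+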
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ l_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, l) = layout.shape().dims3()?;
+ let l_out = self.l_out(l);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * l_out * c * l_k];
+ let (src_s0, src_s1, src_s2) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - l_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * l_out * c * l_k;
+ for l_idx in 0..l_out {
+ let dst_idx = dst_idx + l_idx * c * l_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * l_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for l_k_idx in 0..l_k {
+ let src_l = l_idx * stride + l_k_idx * dilation;
+ if padding != 0 && (src_l < padding || src_l >= l + padding) {
+ continue;
+ }
+ let src_l = src_l - padding;
+ let src_idx = src_idx + src_l * src_s2;
+ let dst_idx = dst_idx + l_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ h_k,
+ w_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, h, w) = layout.shape().dims4()?;
+ let (h_out, w_out) = self.hw_out(h, w);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+ let (src_s0, src_s1, src_s2, src_s3) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2], s[3])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - h_k = w_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+ for h_idx in 0..h_out {
+ let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+ for w_idx in 0..w_out {
+ let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * h_k * w_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for h_k_idx in 0..h_k {
+ let src_h = h_idx * stride + h_k_idx * dilation;
+ if padding != 0 && (src_h < padding || src_h >= h + padding) {
+ continue;
+ }
+ let src_h = src_h - padding;
+ let src_idx = src_idx + src_h * src_s2;
+ let dst_idx = dst_idx + h_k_idx * w_k;
+ for w_k_idx in 0..w_k {
+ let src_w = w_idx * stride + w_k_idx * dilation;
+ if padding != 0 && (src_w < padding || src_w >= w + padding) {
+ continue;
+ }
+ let src_w = src_w - padding;
+ let src_idx = src_idx + src_w * src_s3;
+ let dst_idx = dst_idx + w_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
- let num_threads = crate::utils::get_num_threads();
-
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
}
}
}
- let num_threads = crate::utils::get_num_threads();
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
- crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
- if T::DTYPE == DType::BF16 {
- return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;
@@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout)
}
+ fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ UpsampleNearest1D(sz).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
- Conv1D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV1D {
+ return Conv1D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col1D {
+ l_k: params.k_size,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let l_out = params.l_out();
+ let k = op.l_k * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The copy starts at offset 0, so use a contiguous layout from there and
+ // matmul against the copied kernel.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv2d(
@@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
- Conv2D(params).map(self, l, kernel, kernel_l)
+ if !USE_IM2COL_CONV2D {
+ return Conv2D(params).map(self, l, kernel, kernel_l);
+ }
+ let op = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ };
+ let col = op.map(self, l)?;
+ let b = params.b_size;
+ let n = params.c_out;
+ let (h_out, w_out) = (params.out_h(), params.out_w());
+ let k = op.h_k * op.w_k * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The copy starts at offset 0, so use a contiguous layout from there and
+ // matmul against the copied kernel.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv_transpose2d(
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 663f2319..00fd1d04 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
+ // curand can only generate an even number of values.
+ // https://github.com/huggingface/candle/issues/734
+ let elem_count_round = if elem_count % 2 == 1 {
+ elem_count + 1
+ } else {
+ elem_count
+ };
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
@@ -394,7 +401,7 @@ enum CudaStorageSlice {
}
type S = CudaStorageSlice;
-trait Map1 {
+pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@@ -416,7 +423,7 @@ trait Map1 {
}
}
-trait Map2 {
+pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -441,7 +448,7 @@ trait Map2 {
}
}
-trait Map2InPlace {
+pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@@ -472,7 +479,7 @@ trait Map2InPlace {
}
}
-trait Map1Any {
+pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@@ -495,7 +502,7 @@ trait Map1Any {
}
}
-trait Map2Any {
+pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@@ -532,7 +539,7 @@ impl Map1 for Clone {
}
}
-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@@ -593,6 +600,105 @@ impl Map1 for Elu {
}
}
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let l_out = self.l_out(dims[2]);
+ let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ l_out,
+ self.l_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+ let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (
+ dst_el,
+ h_out,
+ w_out,
+ self.h_k,
+ self.w_k,
+ self.stride,
+ self.padding,
+ self.dilation,
+ &ds,
+ src,
+ &dst,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+}
+
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
@@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
- slice: CudaStorageSlice,
- device: CudaDevice,
+ pub slice: CudaStorageSlice,
+ pub device: CudaDevice,
}
pub trait CudaDType: Sized {
@@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
+ const USE_IM2COL_CONV1D: bool = true;
+
let device = self.device().clone();
- let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV1D {
+ let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col1D {
+ l_k: params.k_size,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let l_out = params.l_out();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_size * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The copy starts at offset 0, so use a contiguous layout from there and
+ // matmul against the copied kernel.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(not(feature = "cudnn"))]
@@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
+ const USE_IM2COL_CONV2D: bool = true;
+
let device = self.device().clone();
- let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
- Ok(Self { slice, device })
+ if !USE_IM2COL_CONV2D {
+ let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ return Ok(Self { slice, device });
+ }
+
+ let col = Im2Col {
+ h_k: params.k_h,
+ w_k: params.k_w,
+ stride: params.stride,
+ dilation: params.dilation,
+ padding: params.padding,
+ }
+ .map(&self.slice, &device, l)?;
+ let col = Self { slice: col, device };
+ let h_out = params.out_h();
+ let w_out = params.out_w();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_h * params.k_w * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if it is not already.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ // The copy starts at offset 0, so use a contiguous layout from there and
+ // matmul against the copied kernel.
+ let kernel_l = Layout::contiguous((1, n, k))
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, n))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
#[cfg(feature = "cudnn")]
@@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
+ crate::bail!("upsample-nearest1d is not supported on cuda")
+ }
+
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 235ad6e3..dd466ba2 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
- params.i_w as i32,
params.i_h as i32,
+ params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
- params.k_w as i32,
params.k_h as i32,
+ params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
- [params.b_size as i32, params.c_out as i32, w_out, h_out],
+ [params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index adfc4a3c..c7a1567f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,15 +1,24 @@
+//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
+/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
+ // Unsigned 8-bit integer.
U8,
+ // Unsigned 32-bit integer.
U32,
+ // Signed 64-bit integer.
I64,
+ // Brain floating-point using half precision (16 bits).
BF16,
+ // Floating-point using half precision (16 bits).
F16,
+ // Floating-point using single precision (32 bits).
F32,
+ // Floating-point using double precision (64 bits).
F64,
}
@@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
}
impl DType {
+ /// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
@@ -45,6 +55,7 @@ impl DType {
}
}
+ /// The size used by each element in bytes, e.g. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 6c896653..5cc9c6d8 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 1cf20a84..be8f7b07 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
- #[error("{op}: dimension index {dim} out of range for {shape:?}")]
+ #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Wrapped(Box::new(err))
+ Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
- Self::Msg(err.to_string())
+ Self::Msg(err.to_string()).bt()
}
pub fn bt(self) -> Self {
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 2b6d694b..7b84d316 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
+ TensorIndexer::IndexSelect(indexes) => {
+ if indexes.rank() != 1 {
+ crate::bail!("multi-dimensional tensor indexing is not supported")
+ }
+ let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+ current_dim += 1;
+ out
+ }
+ TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
+ /// Indexing via a 1d tensor
+ IndexSelect(Tensor),
+ Err(Error),
}
impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
+impl From<&[u32]> for TensorIndexer {
+ fn from(index: &[u32]) -> Self {
+ match Tensor::new(index, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+ fn from(index: Vec<u32>) -> Self {
+ let len = index.len();
+ match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<&Tensor> for TensorIndexer {
+ fn from(tensor: &Tensor) -> Self {
+ TensorIndexer::IndexSelect(tensor.clone())
+ }
+}
+
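+// Illustrative usage, assuming the `IndexOp::i` helper defined in this module:
+//
+//     let t = Tensor::new(&[[0u32, 1], [2, 3], [4, 5]], &Device::Cpu)?;
+//     let r = t.i(&[2u32, 0][..])?; // index_select along dim 0 -> shape (2, 2)
+//     let r = t.i(vec![2u32, 0])?;  // same, from an owned Vec<u32>
+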
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a0347416..52effdcf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -59,6 +59,7 @@ mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
+pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
@@ -109,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with a forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
- /// Change the module to use training mode vs eval mode.
- ///
- /// The default implementation does nothing as this is only used for a couple modules such as
- /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
@@ -124,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self(xs)
+ }
+}
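+
+// The blanket impl above lets a plain closure act as a `Module`, e.g.
+// (illustrative):
+//
+//     let activation = |xs: &Tensor| xs.relu();
+//     let ys = activation.forward(&xs)?;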
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbfc9c1a..4882a205 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -58,6 +58,8 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
+ GeluErf,
+ Erf,
Relu,
Tanh,
}
@@ -116,6 +118,7 @@ pub enum Op {
stride: (usize, usize),
},
+ UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@@ -324,6 +327,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@@ -600,6 +605,92 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
+
+ #[cfg(feature = "accelerate")]
+ const F32_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+ crate::accelerate::vs_gelu(xs, ys)
+ }
+
+ #[cfg(feature = "accelerate")]
+ const F64_VEC: bool = true;
+
+ #[cfg(feature = "accelerate")]
+ #[inline(always)]
+ fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+ crate::accelerate::vd_gelu(xs, ys)
+ }
+}
+
+impl UnaryOpT for Erf {
+ const NAME: &'static str = "erf";
+ const KERNEL: &'static str = "uerf";
+ const V: Self = Erf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ crate::cpu::erf::erf(v)
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
+}
+
+impl UnaryOpT for GeluErf {
+ const NAME: &'static str = "gelu_erf";
+ const KERNEL: &'static str = "ugelu_erf";
+ const V: Self = GeluErf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
}
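+
+// Note: `GeluErf` is the exact formulation 0.5 * x * (1 + erf(x / sqrt(2))),
+// while the pre-existing `Gelu` op keeps the faster tanh-based approximation.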
impl UnaryOpT for Relu {
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 65fd6a6e..a0fe455c 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -85,7 +85,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
pub struct BlockQ8_1 {
pub(crate) d: f16,
pub(crate) s: f16,
- pub(crate) qs: [u8; QK8_1],
+ pub(crate) qs: [i8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
@@ -278,6 +278,7 @@ impl GgmlType for BlockQ4_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -471,6 +472,7 @@ impl GgmlType for BlockQ5_1 {
}
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
+ + f16::to_f32(xs.m) * f16::to_f32(ys.s)
}
Ok(sumf)
}
@@ -652,8 +654,8 @@ impl GgmlType for BlockQ8_1 {
for j in 0..Self::BLCK_SIZE / 2 {
let v0 = xs[j] * id;
let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
- ys.qs[j] = f32::round(v0) as u8;
- ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as u8;
+ ys.qs[j] = f32::round(v0) as i8;
+ ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
}
ys.s = f16::from_f32(sum as f32) * ys.d;
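+ // With `qs` signed, `s` is d * sum(qs); the `m * s` correction terms added
+ // to the Q4_1/Q5_1 dot products above rely on this, as in the ggml
+ // reference implementation.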
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 5c2bb2b2..f627f0f6 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -229,7 +229,7 @@ impl QTensor {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index f37bb8ef..d588ea67 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
- pub fn save_safetensors<P: AsRef<std::path::Path>>(
- &self,
- name: &str,
- filename: P,
- ) -> Result<()> {
+ pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@@ -267,7 +263,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs
new file mode 100644
index 00000000..43e1f4c8
--- /dev/null
+++ b/candle-core/src/scalar.rs
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+ Tensor(Tensor),
+ Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+ fn to_tensor_scalar(self) -> Result<TensorScalar>;
+}
+
+impl TensorOrScalar for &Tensor {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ Ok(TensorScalar::Tensor(self.clone()))
+ }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+ Ok(TensorScalar::Scalar(scalar))
+ }
+}
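+
+// Illustrative usage (not part of this file): ops taking a `TensorOrScalar`,
+// such as the `maximum` method on tensors, accept both forms:
+//
+//     let a = Tensor::new(&[0f32, 1., 2.], &Device::Cpu)?;
+//     let b = a.maximum(1f32)?;       // scalar rhs, broadcast to a's shape
+//     let c = a.maximum(&a.sqr()?)?;  // tensor rhs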
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index aea8b887..4d500e7f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -1,3 +1,4 @@
+//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+ fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+ Self(vec![
+ d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+ ])
+ }
+}
+
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@@ -119,6 +128,7 @@ impl Shape {
Self(dims.to_vec())
}
+ /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@@ -127,10 +137,12 @@ impl Shape {
self.0
}
+ /// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
+ /// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@@ -182,6 +194,8 @@ impl Shape {
true
}
+ /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+ /// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
@@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
}
}
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4])
+ }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ let d5 = self.5.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4, d5])
+ }
+}
+
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@@ -457,3 +494,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
+
+pub trait ShapeWithOneHole {
+ fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+ fn into_shape(self, _el_count: usize) -> Result<Shape> {
+ Ok(self.into())
+ }
+}
+
+impl ShapeWithOneHole for ((),) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ Ok(el_count.into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((el_count / d1, d1).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, ()) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((d1, el_count / d1).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, ()) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, ()) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, (), d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, d4, ()) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, d4, el_count / d).into())
+ }
+}
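+
+// Illustrative: `(2, (), 4).into_shape(24)` resolves the hole to 24 / (2 * 4) = 3
+// and yields a (2, 3, 4) shape; a non-divisible element count bails with an error.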
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 8bd14ea9..9bd1fed6 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -369,6 +369,19 @@ impl Storage {
}
}
+ pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e181f240..9dccf2b5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,8 +1,10 @@
+//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
+use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -103,6 +105,28 @@ macro_rules! binary_op {
};
}
+macro_rules! binary_op_scalar {
+ ($fn_name:ident, $op_name:ident) => {
+ pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+ let storage = self.storage().binary_impl::<crate::op::$op_name>(
+ &*rhs.storage(),
+ self.layout(),
+ rhs.layout(),
+ )?;
+ let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+ Ok(from_storage(storage, shape.clone(), op, false))
+ }
+ };
+}
+
macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@@ -445,8 +469,8 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
- binary_op!(maximum, Maximum);
- binary_op!(minimum, Minimum);
+ binary_op_scalar!(maximum, Maximum);
+ binary_op_scalar!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
@@ -465,6 +489,8 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
+ unary_op!(gelu_erf, GeluErf);
+ unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
@@ -642,7 +668,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
- let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+ let op = match op {
+ ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+ BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+ }
+ ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+ };
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@@ -775,8 +806,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -785,45 +823,68 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
+ pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
- /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
+ /// Clamp the tensor values to be between `min` and `max`.
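+ ///
+ /// A small usage sketch (`to_vec1` comes from elsewhere in this crate):
+ ///
+ /// ```rust
+ /// # use candle_core::{Tensor, Device};
+ /// let t = Tensor::new(&[-1.5f32, 0.0, 2.5], &Device::Cpu)?;
+ /// assert_eq!(t.clamp(0f32, 1f32)?.to_vec1::<f32>()?, [0.0, 0.0, 1.0]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```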
+ pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
+ self.maximum(min)?.minimum(max)
+ }
+
+ /// Interpolate the input tensor to size `target_size`, taking the value of the nearest element.
+ ///
+ /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
+ /// tensor also has three dimensions, `(batch, channels, target_size)`.
+ pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
+ let (n, c, _l) = self.dims3()?;
+ let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
+ let storage = self
+ .storage()
+ .upsample_nearest1d(self.layout(), target_size)?;
+ Ok(from_storage(storage, (n, c, target_size), op, false))
+ }
+
+ /// Alias for `interpolate1d`.
+ pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
+ self.interpolate1d(target_size)
+ }
+
+ /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
- pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@@ -832,6 +893,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
+ /// Alias for `interpolate2d`.
+ pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ self.interpolate2d(target_h, target_w)
+ }
+
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@@ -1684,12 +1750,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
- // TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
+ /// The shape can be specified using a tuple of `usize` and at most one `()`, in which case
+ /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
+ /// as to match the number of elements in the tensor.
+ ///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1699,10 +1768,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
+ ///
+ /// let c = a.reshape((2, (), 1))?;
+ /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+ ///
/// # Ok::<(), candle_core::Error>(())
/// ```
- pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
- let shape = shape.into();
+ pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+ let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
@@ -1836,6 +1909,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg0.rank() != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: arg0.rank(),
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx != dim && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ }
if dim == 0 {
Self::cat0(args)
} else {
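
The tensor API changes above generalize the comparison helpers to accept either a tensor or a plain scalar (via `TensorOrScalar`), add `clamp`, and let `reshape` infer one dimension from a `()` hole. A minimal sketch of the resulting call shapes, assuming plain `f64` scalars implement `TensorOrScalar` as the clamp test in the updated test suite suggests:

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    // Scalar right-hand sides are converted to the tensor's dtype and
    // broadcast to its shape before the comparison runs.
    let mask = t.gt(2.0)?;
    assert_eq!(mask.to_vec2::<u8>()?, [[1u8, 0, 1], [0, 1, 1]]);
    // clamp is maximum(min) followed by minimum(max).
    let clamped = t.clamp(1.5, 6.2)?;
    // `()` plays the role of PyTorch's -1: its size is inferred so that the
    // element count is preserved.
    let reshaped = clamped.reshape((2, (), 1))?;
    assert_eq!(reshaped.dims(), &[2, 3, 1]);
    Ok(())
}
```
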
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6af43196..edd0bd79 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,4 +1,4 @@
-use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -33,6 +33,44 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}
+fn clamp(device: &Device) -> Result<()> {
+ let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+ let tensor = Tensor::new(data, device)?;
+ let tensor = tensor.clamp(1.5, 6.2)?;
+ assert_eq!(
+ tensor.to_vec2::<f32>()?,
+ [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
+ );
+ Ok(())
+}
+
+fn unary_op(device: &Device) -> Result<()> {
+ let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
+ let tensor = Tensor::new(data, device)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
+ [
+ [-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
+ [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
+ [
+ [-0.004, 0.8413, 3.9999, -0.046, 0.3457],
+ [2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.erf()?, 4)?,
+ [
+ [-1.0, 0.8427, 1.0, -0.1125, 0.5205],
+ [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
+ ]
+ );
+ Ok(())
+}
+
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
@@ -877,6 +915,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
+fn randn(device: &Device) -> Result<()> {
+ let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ Ok(())
+}
+
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@@ -889,6 +935,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
@@ -899,6 +946,8 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index d69318e1..316f31c5 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 30b0d01f..2dac883c 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
-fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
- let mut b = vec![0u8; 4];
- reader.read_exact(&mut b)?;
- let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
- (s + basis * u64::from(x), basis * 256)
- });
- Ok(result as u32)
+fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
+ use byteorder::ReadBytesExt;
+ reader.read_u32::<byteorder::BigEndian>()
}
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
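
For reference, a minimal sketch of what the `byteorder`-based reader does; the magic value used here is the standard big-endian MNIST image-file header:

```rust
use byteorder::{BigEndian, ReadBytesExt};
use std::io::Cursor;

fn main() -> std::io::Result<()> {
    // MNIST headers are stored big-endian; 2051 (0x00000803) marks an
    // image file, 2049 a label file.
    let mut reader = Cursor::new(vec![0x00, 0x00, 0x08, 0x03]);
    let magic = reader.read_u32::<BigEndian>()?;
    assert_eq!(magic, 2051);
    Ok(())
}
```
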
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 9035eae0..0e2e8093 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -11,19 +11,19 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
-safetensors = { workspace = true }
-serde = { workspace = true }
-serde_json = { workspace = true }
-num-traits = { workspace = true }
-intel-mkl-src = { workspace = true, optional = true }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -50,7 +50,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md
new file mode 100644
index 00000000..82ca5f40
--- /dev/null
+++ b/candle-examples/examples/bert/README.md
@@ -0,0 +1,44 @@
+# candle-bert
+
+Bert is a general large language model. In this example, it can be used for two
+different tasks:
+- Compute sentence embeddings for a prompt.
+- Compute similarities between a set of sentences.
+
+
+## Sentence embeddings
+
+Bert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example bert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
+> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
+> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
+> ...
+> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
+> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
+> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
+> Tensor[[1, 7, 384], f32]
+```
+
+## Similarities
+
+In this example, Bert is used to compute the sentence embeddings for a set of
+sentences (hardcoded in the example). Cosine similarities are then computed for
+each sentence pair and reported in decreasing order, so the first reported pair
+contains the two sentences with the highest similarity score. The sentence
+embeddings are computed by average pooling over all the sentence tokens,
+including any padding.
+
+```bash
+cargo run --example bert --release
+
+> score: 0.85 'The new movie is awesome' 'The new movie is so great'
+> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
+> score: 0.52 'I love pasta' 'Do you like pizza?'
+> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
+> score: 0.22 'I love pasta' 'The new movie is awesome'
+```
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6cee66ee..9d0eccdf 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-mod model;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
diff --git a/candle-examples/examples/bigcode/README.md b/candle-examples/examples/bigcode/README.md
new file mode 100644
index 00000000..cb4e79b1
--- /dev/null
+++ b/candle-examples/examples/bigcode/README.md
@@ -0,0 +1,19 @@
+# candle-starcoder: code generation model
+
+[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
+specialized in code generation. The initial model was trained on 80
+programming languages.
+
+## Running an example
+
+```bash
+cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "
+
+> fn fact(n: u64) -> u64 {
+> if n == 0 {
+> 1
+> } else {
+> n * fact(n - 1)
+> }
+> }
+```
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 652cd47f..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-mod model;
-use model::{Config, GPTBigCode};
+use candle_transformers::models::bigcode::{Config, GPTBigCode};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
+ top_p: Option<f64>,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@@ -95,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -150,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ &device,
+ );
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
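
As the diff shows, `LogitsProcessor::new` now takes an optional nucleus-sampling cutoff in addition to the temperature. A sketch of the new call shape, with illustrative values:

```rust
use candle::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_next(logits: &Tensor) -> Result<u32> {
    // temp = None gives greedy decoding; top_p = Some(0.9) samples from the
    // smallest token set whose cumulative probability reaches 0.9.
    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.9));
    processor.sample(logits)
}
```
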
diff --git a/candle-examples/examples/dinov2/README.md b/candle-examples/examples/dinov2/README.md
new file mode 100644
index 00000000..10d4ac1f
--- /dev/null
+++ b/candle-examples/examples/dinov2/README.md
@@ -0,0 +1,19 @@
+# candle-dinov2
+
+[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
+In this example, it is used as an ImageNet classifier: the model returns the
+probability of the image belonging to each of the 1000 ImageNet categories.
+
+## Running an example
+
+```bash
+cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
+
+> mountain bike, all-terrain bike, off-roader: 43.67%
+> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
+> crash helmet : 13.23%
+> unicycle, monocycle : 2.44%
+> maillot : 2.42%
+```
+
+![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index e80c81e2..d3adb37c 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2;
-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- qkv: Linear,
- proj: Linear,
- num_heads: usize,
- scale: f64,
-}
-
-impl Attention {
- fn new(
- vb: VarBuilder,
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- proj_bias: bool,
- ) -> Result<Self> {
- let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
- let scale = 1. / ((dim / num_heads) as f64).sqrt();
- Ok(Self {
- qkv,
- proj,
- num_heads,
- scale,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (b, n, c) = xs.dims3()?;
- let qkv = self
- .qkv
- .forward(xs)?
- .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)? // 02134
- .transpose(0, 1)? // 20134
- .transpose(2, 3)?; // 20314
- let q = (qkv.i(0)? * self.scale)?;
- let k = qkv.i(1)?;
- let v = qkv.i(2)?;
- let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
- let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
- self.proj.forward(&attn)
- }
-}
-
-#[derive(Debug)]
-struct LayerScale {
- gamma: Tensor,
-}
-
-impl LayerScale {
- fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
- let gamma = vb.get(dim, "gamma")?;
- Ok(Self { gamma })
- }
-}
-
-impl Module for LayerScale {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.broadcast_mul(&self.gamma)
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- fc1: Linear,
- fc2: Linear,
-}
-
-impl Mlp {
- fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
- let out_features = in_features;
- let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
- let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.fc1.forward(xs)?.gelu()?;
- self.fc2.forward(&xs)
- }
-}
-
-#[derive(Debug)]
-struct Block {
- norm1: LayerNorm,
- attn: Attention,
- ls1: LayerScale,
- norm2: LayerNorm,
- mlp: Mlp,
- ls2: LayerScale,
-}
-
-impl Block {
- fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
- let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
- let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
- let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
- let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
- let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
- let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
- Ok(Self {
- norm1,
- attn,
- ls1,
- norm2,
- mlp,
- ls2,
- })
- }
-}
-
-impl Module for Block {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- let xs = self
- .ls1
- .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
- let xs = (xs + residual)?;
- let residual = &xs;
- let xs = self
- .ls2
- .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
- xs + residual
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- proj: candle_nn::Conv2d,
- patch_size: (usize, usize),
- num_patches: usize,
-}
-
-impl PatchEmbed {
- fn new(
- vb: VarBuilder,
- img_size: usize,
- patch_size: usize,
- in_chans: usize,
- embed_dim: usize,
- ) -> Result<Self> {
- let config = candle_nn::Conv2dConfig {
- stride: patch_size,
- ..Default::default()
- };
- let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
- let num_patches = (img_size / patch_size) * (img_size / patch_size);
- Ok(Self {
- proj,
- patch_size: (patch_size, patch_size),
- num_patches,
- })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _c, h, w) = xs.dims4()?;
- let (patch_h, patch_w) = self.patch_size;
- if (h % patch_h) != 0 {
- candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
- }
- if (w % patch_w) != 0 {
- candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
- }
- let xs = self.proj.forward(xs)?;
- let (b, c, h, w) = xs.dims4()?;
- // flatten embeddings.
- xs.reshape((b, c, h * w))?.transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-pub struct DinoVisionTransformer {
- patch_embed: PatchEmbed,
- cls_token: Tensor,
- pos_embed: Tensor,
- blocks: Vec<Block>,
- norm: LayerNorm,
- head: Linear,
-}
-
-impl DinoVisionTransformer {
- pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
- let patch_embed =
- PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
- let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
- let num_tokens = 1;
- let pos_embed = vb.get(
- (1, patch_embed.num_patches + num_tokens, embed_dim),
- "pos_embed",
- )?;
- let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
- let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
- let vb_b = vb.pp("blocks");
- let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- patch_embed,
- cls_token,
- pos_embed,
- blocks,
- norm,
- head,
- })
- }
-
- fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
- let npatch = xs.dim(1)? - 1;
- let n = self.pos_embed.dim(1)? - 1;
- let sqrt_n = (n as f64).sqrt();
- if npatch == n && w == h {
- return Ok(xs.clone());
- }
- let class_pos_embed = self.pos_embed.i((.., ..1))?;
- let patch_pos_embed = self.pos_embed.i((.., 1..))?;
- let dim = xs.dim(D::Minus1)?;
- let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
- let patch_pos_embed = patch_pos_embed
- .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
- .transpose(2, 3)?
- .transpose(1, 2)?;
- // This uses bicubic interpolation in the original implementation.
- let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
- let el_count = patch_pos_embed.shape().elem_count();
- let patch_pos_embed =
- patch_pos_embed
- .transpose(1, 2)?
- .transpose(2, 3)?
- .reshape((1, el_count / dim, dim))?;
- Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
- }
-
- fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _nc, w, h) = xs.dims4()?;
- let xs = self.patch_embed.forward(xs)?;
- let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
- &xs + &self.interpolate_pos_encoding(&xs, w, h)?
- }
-}
-
-impl Module for DinoVisionTransformer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.prepare_tokens_with_mask(xs)?;
- for blk in self.blocks.iter() {
- xs = blk.forward(&xs)?
- }
- let xs = self.norm.forward(&xs)?;
- let xs_norm_clstoken = xs.i((.., 0))?;
- let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
- let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
- self.head.forward(&xs)
- }
-}
-
-pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
- DinoVisionTransformer::new(vb, 12, 384, 6)
-}
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
- let model = vit_small(vb)?;
+ let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
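
The remainder of the example (truncated in this hunk) turns the softmax output into the top-5 report shown in the README above. A hedged sketch of that post-processing, not the repository's exact code:

```rust
use candle::{IndexOp, Result, Tensor};

// `prs` is the (1, 1000) tensor of softmax probabilities.
fn print_top5(prs: &Tensor) -> Result<()> {
    let prs = prs.i(0)?.to_vec1::<f32>()?;
    let mut indexed: Vec<(usize, f32)> = prs.into_iter().enumerate().collect();
    // Sort classes by descending probability and keep the five best.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    for (class_idx, p) in indexed.iter().take(5) {
        println!("class {class_idx}: {:.2}%", 100. * p);
    }
    Ok(())
}
```
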
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index cbe2c90a..1e45e301 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn as nn;
-use nn::{Module, VarBuilder};
-
-// Based on the Python version from torchvision.
-// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
-#[derive(Debug, Clone, Copy)]
-pub struct MBConvConfig {
- expand_ratio: f64,
- kernel: usize,
- stride: usize,
- input_channels: usize,
- out_channels: usize,
- num_layers: usize,
-}
-
-fn make_divisible(v: f64, divisor: usize) -> usize {
- let min_value = divisor;
- let new_v = usize::max(
- min_value,
- (v + divisor as f64 * 0.5) as usize / divisor * divisor,
- );
- if (new_v as f64) < 0.9 * v {
- new_v + divisor
- } else {
- new_v
- }
-}
-
-fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
- let bneck_conf = |e, k, s, i, o, n| {
- let input_channels = make_divisible(i as f64 * width_mult, 8);
- let out_channels = make_divisible(o as f64 * width_mult, 8);
- let num_layers = (n as f64 * depth_mult).ceil() as usize;
- MBConvConfig {
- expand_ratio: e,
- kernel: k,
- stride: s,
- input_channels,
- out_channels,
- num_layers,
- }
- };
- vec![
- bneck_conf(1., 3, 1, 32, 16, 1),
- bneck_conf(6., 3, 2, 16, 24, 2),
- bneck_conf(6., 5, 2, 24, 40, 2),
- bneck_conf(6., 3, 2, 40, 80, 3),
- bneck_conf(6., 5, 1, 80, 112, 3),
- bneck_conf(6., 5, 2, 112, 192, 4),
- bneck_conf(6., 3, 1, 192, 320, 1),
- ]
-}
-
-impl MBConvConfig {
- fn b0() -> Vec<Self> {
- bneck_confs(1.0, 1.0)
- }
- fn b1() -> Vec<Self> {
- bneck_confs(1.0, 1.1)
- }
- fn b2() -> Vec<Self> {
- bneck_confs(1.1, 1.2)
- }
- fn b3() -> Vec<Self> {
- bneck_confs(1.2, 1.4)
- }
- fn b4() -> Vec<Self> {
- bneck_confs(1.4, 1.8)
- }
- fn b5() -> Vec<Self> {
- bneck_confs(1.6, 2.2)
- }
- fn b6() -> Vec<Self> {
- bneck_confs(1.8, 2.6)
- }
- fn b7() -> Vec<Self> {
- bneck_confs(2.0, 3.1)
- }
-}
-
-/// Conv2D with same padding.
-#[derive(Debug)]
-struct Conv2DSame {
- conv2d: nn::Conv2d,
- s: usize,
- k: usize,
-}
-
-impl Conv2DSame {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- bias: bool,
- ) -> Result<Self> {
- let conv_config = nn::Conv2dConfig {
- stride,
- groups,
- ..Default::default()
- };
- let conv2d = if bias {
- nn::conv2d(i, o, k, conv_config, vb)?
- } else {
- nn::conv2d_no_bias(i, o, k, conv_config, vb)?
- };
- Ok(Self {
- conv2d,
- s: stride,
- k,
- })
- }
-}
-
-impl Module for Conv2DSame {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let s = self.s;
- let k = self.k;
- let (_, _, ih, iw) = xs.dims4()?;
- let oh = (ih + s - 1) / s;
- let ow = (iw + s - 1) / s;
- let pad_h = usize::max((oh - 1) * s + k - ih, 0);
- let pad_w = usize::max((ow - 1) * s + k - iw, 0);
- if pad_h > 0 || pad_w > 0 {
- let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
- let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
- self.conv2d.forward(&xs)
- } else {
- self.conv2d.forward(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct ConvNormActivation {
- conv2d: Conv2DSame,
- bn2d: nn::BatchNorm,
- activation: bool,
-}
-
-impl ConvNormActivation {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- ) -> Result<Self> {
- let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
- let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
- Ok(Self {
- conv2d,
- bn2d,
- activation: true,
- })
- }
-
- fn no_activation(self) -> Self {
- Self {
- activation: false,
- ..self
- }
- }
-}
-
-impl Module for ConvNormActivation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.conv2d.forward(xs)?;
- let xs = self.bn2d.forward(&xs)?;
- if self.activation {
- swish(&xs)
- } else {
- Ok(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct SqueezeExcitation {
- fc1: Conv2DSame,
- fc2: Conv2DSame,
-}
-
-impl SqueezeExcitation {
- fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
- let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
- let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for SqueezeExcitation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- // equivalent to adaptive_avg_pool2d([1, 1])
- let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = swish(&xs)?;
- let xs = self.fc2.forward(&xs)?;
- let xs = nn::ops::sigmoid(&xs)?;
- residual.broadcast_mul(&xs)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- expand_cna: Option<ConvNormActivation>,
- depthwise_cna: ConvNormActivation,
- squeeze_excitation: SqueezeExcitation,
- project_cna: ConvNormActivation,
- config: MBConvConfig,
-}
-
-impl MBConv {
- fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
- let vb = vb.pp("block");
- let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
- let expand_cna = if exp != c.input_channels {
- Some(ConvNormActivation::new(
- vb.pp("0"),
- c.input_channels,
- exp,
- 1,
- 1,
- 1,
- )?)
- } else {
- None
- };
- let start_index = if expand_cna.is_some() { 1 } else { 0 };
- let depthwise_cna =
- ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
- let squeeze_channels = usize::max(1, c.input_channels / 4);
- let squeeze_excitation =
- SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
- let project_cna =
- ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
- .no_activation();
- Ok(Self {
- expand_cna,
- depthwise_cna,
- squeeze_excitation,
- project_cna,
- config: c,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let use_res_connect =
- self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
- let ys = match &self.expand_cna {
- Some(expand_cna) => expand_cna.forward(xs)?,
- None => xs.clone(),
- };
- let ys = self.depthwise_cna.forward(&ys)?;
- let ys = self.squeeze_excitation.forward(&ys)?;
- let ys = self.project_cna.forward(&ys)?;
- if use_res_connect {
- ys + xs
- } else {
- Ok(ys)
- }
- }
-}
-
-fn swish(s: &Tensor) -> Result<Tensor> {
- s * nn::ops::sigmoid(s)?
-}
-
-#[derive(Debug)]
-struct EfficientNet {
- init_cna: ConvNormActivation,
- blocks: Vec<MBConv>,
- final_cna: ConvNormActivation,
- classifier: nn::Linear,
-}
-
-impl EfficientNet {
- fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
- let f_p = p.pp("features");
- let first_in_c = configs[0].input_channels;
- let last_out_c = configs.last().unwrap().out_channels;
- let final_out_c = 4 * last_out_c;
- let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
- let nconfigs = configs.len();
- let mut blocks = vec![];
- for (index, cnf) in configs.into_iter().enumerate() {
- let f_p = f_p.pp(index + 1);
- for r_index in 0..cnf.num_layers {
- let cnf = if r_index == 0 {
- cnf
- } else {
- MBConvConfig {
- input_channels: cnf.out_channels,
- stride: 1,
- ..cnf
- }
- };
- blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
- }
- }
- let final_cna =
- ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
- let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
- Ok(Self {
- init_cna,
- blocks,
- final_cna,
- classifier,
- })
- }
-}
-
-impl Module for EfficientNet {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.init_cna.forward(xs)?;
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- let xs = self.final_cna.forward(&xs)?;
- // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
- let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
- self.classifier.forward(&xs)
- }
-}
-
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
diff --git a/candle-examples/examples/falcon/README.md b/candle-examples/examples/falcon/README.md
new file mode 100644
index 00000000..267c78c2
--- /dev/null
+++ b/candle-examples/examples/falcon/README.md
@@ -0,0 +1,3 @@
+# candle-falcon
+
+Falcon is a general large language model.
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 05507f08..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
-mod model;
-use model::{Config, Falcon};
+use candle_transformers::models::falcon::{Config, Falcon};
struct TextGeneration {
model: Falcon,
@@ -26,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}
+struct GenerationOptions {
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
+ generation_options: GenerationOptions,
seed: u64,
- temp: Option<f64>,
device: &Device,
- repeat_penalty: f32,
- repeat_last_n: usize,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor =
+ LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+ let repeat_penalty = generation_options.repeat_penalty;
+ let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@@ -119,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -186,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- &device,
- args.repeat_penalty,
- args.repeat_last_n,
- );
+ let generation_options = GenerationOptions {
+ temp: args.temperature,
+ top_p: args.top_p,
+ repeat_penalty: args.repeat_penalty,
+ repeat_last_n: args.repeat_last_n,
+ };
+ let mut pipeline =
+ TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
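
Grouping the sampling parameters into `GenerationOptions` keeps `TextGeneration::new` at a manageable arity as options accumulate. A sketch of the resulting call shape, using the types from this file and illustrative default values:

```rust
fn build_pipeline(model: Falcon, tokenizer: Tokenizer, device: &Device) -> TextGeneration {
    // Illustrative defaults; the binary reads these from its CLI flags.
    let generation_options = GenerationOptions {
        temp: Some(0.8),
        top_p: None,
        repeat_penalty: 1.1,
        repeat_last_n: 64,
    };
    TextGeneration::new(model, tokenizer, generation_options, 299792458, device)
}
```
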
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6f8766d4..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
-mod model;
+use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)]
@@ -43,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -194,7 +197,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
#[arg(long, default_value = "")]
prompt: String,
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
+ top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 17dc90e2..8a13ce6c 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -89,6 +89,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -222,7 +226,7 @@ fn main() -> Result<()> {
.to_vec();
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 3794c22d..0fae67b5 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -13,7 +13,6 @@ extern crate accelerate_src;
mod encodec_model;
mod musicgen_model;
mod nn;
-mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
@@ -78,7 +77,7 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
- let model = MusicgenForConditionalGeneration::load(vb, config)?;
+ let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 7e272fd7..d6d8ae15 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,9 +1,10 @@
-use crate::{encodec_model, t5_model};
+use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
+use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
- pub text_encoder: crate::t5_model::T5EncoderModel,
+ pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
@@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
musicgen: Config,
- t5: crate::t5_model::Config,
+ t5: t5::Config,
encodec: crate::encodec_model::Config,
}
@@ -387,7 +388,7 @@ impl GenConfig {
pub fn small() -> Self {
Self {
musicgen: Config::musicgen_small(),
- t5: t5_model::Config::musicgen_small(),
+ t5: t5::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(),
}
}
@@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
}
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
- let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
+ let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
deleted file mode 100644
index 607b5c93..00000000
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ /dev/null
@@ -1,397 +0,0 @@
-// T5 Text Encoder
-// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-
-use candle::{DType, Result, Tensor, D};
-use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
-use std::sync::Arc;
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct Config {
- vocab_size: usize,
- d_model: usize,
- d_kv: usize,
- d_ff: usize,
- num_layers: usize,
- num_decoder_layers: Option<usize>,
- num_heads: usize,
- relative_attention_num_buckets: usize,
- relative_attention_max_distance: usize,
- dropout_rate: f64,
- layer_norm_epsilon: f64,
- initializer_factor: f64,
- feed_forward_proj: Activation,
- is_decoder: bool,
- is_encoder_decoder: bool,
- use_cache: bool,
- pad_token_id: usize,
- eos_token_id: usize,
-}
-
-impl Default for Config {
- fn default() -> Self {
- Self {
- vocab_size: 32128,
- d_model: 512,
- d_kv: 64,
- d_ff: 2048,
- num_layers: 6,
- num_decoder_layers: None,
- num_heads: 8,
- relative_attention_num_buckets: 32,
- relative_attention_max_distance: 128,
- dropout_rate: 0.1,
- layer_norm_epsilon: 1e-6,
- initializer_factor: 1.0,
- feed_forward_proj: Activation::Relu,
- is_decoder: false,
- is_encoder_decoder: true,
- use_cache: true,
- pad_token_id: 0,
- eos_token_id: 1,
- }
- }
-}
-
-impl Config {
- // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
- pub fn musicgen_small() -> Self {
- Self {
- d_ff: 3072,
- d_kv: 64,
- d_model: 768,
- dropout_rate: 0.1,
- eos_token_id: 1,
- feed_forward_proj: Activation::Relu,
- initializer_factor: 1.0,
- is_decoder: false,
- is_encoder_decoder: true,
- layer_norm_epsilon: 1e-6,
- num_decoder_layers: Some(12),
- num_heads: 12,
- num_layers: 12,
- pad_token_id: 0,
- relative_attention_max_distance: 128,
- relative_attention_num_buckets: 32,
- use_cache: true,
- vocab_size: 32128,
- }
- }
-}
-
-#[derive(Debug)]
-struct T5LayerNorm {
- weight: Tensor,
- variance_epsilon: f64,
-}
-
-impl T5LayerNorm {
- fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let weight = vb.get(h, "weight")?;
- Ok(Self {
- weight,
- variance_epsilon: eps,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let dtype = xs.dtype();
- let xs_f32 = xs.to_dtype(DType::F32)?;
- // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
- let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
- let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
- let xs = xs.to_dtype(dtype)?;
- let xs = xs.broadcast_mul(&self.weight)?;
- Ok(xs)
- }
-}
-
-#[derive(Debug)]
-struct T5DenseActDense {
- wi: Linear,
- wo: Linear,
- act: Activation,
-}
-
-impl T5DenseActDense {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
- let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
- Ok(Self {
- wi,
- wo,
- act: Activation::Relu,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.wi.forward(xs)?;
- let xs = self.act.forward(&xs)?;
- let xs = self.wo.forward(&xs)?;
- Ok(xs)
- }
-}
-
-#[derive(Debug)]
-struct T5LayerFF {
- dense_relu_dense: T5DenseActDense,
- layer_norm: T5LayerNorm,
-}
-
-impl T5LayerFF {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- // is_gated_act is not supported.
- let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
- let layer_norm =
- T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
- Ok(Self {
- dense_relu_dense,
- layer_norm,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let ys = self.layer_norm.forward(xs)?;
- let ys = self.dense_relu_dense.forward(&ys)?;
- let xs = (xs + ys)?;
- Ok(xs)
- }
-}
-
-#[derive(Debug)]
-struct T5Attention {
- q: Linear,
- k: Linear,
- v: Linear,
- o: Linear,
- n_heads: usize,
- d_kv: usize,
- relative_attention_bias: Option<Embedding>,
- relative_attention_num_buckets: usize,
- relative_attention_max_distance: usize,
- inner_dim: usize,
-}
-
-impl T5Attention {
- fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let inner_dim = cfg.num_heads * cfg.d_kv;
- let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
- let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
- let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
- let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
- let relative_attention_bias = if h {
- let emb = embedding(
- cfg.relative_attention_num_buckets,
- cfg.num_heads,
- vb.pp("relative_attention_bias"),
- )?;
- Some(emb)
- } else {
- None
- };
- Ok(Self {
- q,
- k,
- v,
- o,
- n_heads: cfg.num_heads,
- d_kv: cfg.d_kv,
- relative_attention_bias,
- relative_attention_num_buckets: cfg.relative_attention_num_buckets,
- relative_attention_max_distance: cfg.relative_attention_max_distance,
- inner_dim,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- // TODO: Apply the mask(s)?
- // TODO: kv caching.
- let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
- let q = self.q.forward(xs)?;
- let k = self.k.forward(xs)?;
- let v = self.v.forward(xs)?;
- let q = q
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?
- .contiguous()?;
- let k = k
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?
- .contiguous()?;
- let v = v
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?
- .contiguous()?;
- let scores = q.matmul(&k.t()?)?;
-
- let scores = match &self.relative_attention_bias {
- None => scores,
- Some(relative_attention_bias) => {
- let query_length = seq_len;
- let key_length = seq_len;
- // This only handles the bidirectional case.
- let num_buckets = self.relative_attention_num_buckets / 2;
- let relative_position = (0..query_length as u32)
- .map(|i| {
- (0..key_length as u32)
- .map(|j| {
- if i < j {
- j - i + num_buckets as u32
- } else {
- i - j
- }
- })
- .collect::<Vec<u32>>()
- })
- .collect::<Vec<Vec<_>>>();
- let relative_buckets = Tensor::new(relative_position, q.device())?;
- let position_bias = relative_attention_bias
- .forward(&relative_buckets)?
- .permute((2, 0, 1))?
- .unsqueeze(0)?;
- (scores + position_bias)?
- // TODO: position_bias_masked?
- }
- };
-
- let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
- let attn_output = attn_weights.matmul(&v)?;
- let attn_output = attn_output
- .transpose(1, 2)?
- .reshape((b_sz, seq_len, self.inner_dim))?;
- let attn_output = self.o.forward(&attn_output)?;
- Ok(attn_output)
- }
-}
-
-#[derive(Debug)]
-struct T5LayerSelfAttention {
- self_attention: T5Attention,
- layer_norm: T5LayerNorm,
-}
-
-impl T5LayerSelfAttention {
- fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
- let layer_norm =
- T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
- Ok(Self {
- self_attention,
- layer_norm,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let normed_xs = self.layer_norm.forward(xs)?;
- let ys = self.self_attention.forward(&normed_xs)?;
- let ys = (xs + ys)?;
- Ok(ys)
- }
-}
-
-#[derive(Debug)]
-struct T5LayerCrossAttention {}
-
-impl T5LayerCrossAttention {
- fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
- todo!()
- }
-
- fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
- todo!()
- }
-}
-
-#[derive(Debug)]
-struct T5Block {
- self_attn: T5LayerSelfAttention,
- cross_attn: Option<T5LayerCrossAttention>,
- ff: T5LayerFF,
-}
-
-impl T5Block {
- fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let vb = vb.pp("layer");
- let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
- let cross_attn = if cfg.is_decoder {
- Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
- } else {
- None
- };
- let ff_i = if cross_attn.is_some() { 2 } else { 1 };
- let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
- Ok(Self {
- self_attn,
- cross_attn,
- ff,
- })
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.self_attn.forward(xs)?;
- // TODO: clamp for f16?
- if let Some(cross_attn) = &self.cross_attn {
- xs = cross_attn.forward(&xs)?;
- // TODO: clamp for f16?
- }
- let xs = self.ff.forward(&xs)?;
- // TODO: clamp for f16?
- Ok(xs)
- }
-}
-
-#[derive(Debug)]
-struct T5Stack {
- block: Vec<T5Block>,
- shared: Arc<Embedding>,
- final_layer_norm: T5LayerNorm,
-}
-
-impl T5Stack {
- fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
- let block = (0..cfg.num_layers)
- .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
- .collect::<Result<Vec<_>>>()?;
- let final_layer_norm = T5LayerNorm::load(
- cfg.d_model,
- cfg.layer_norm_epsilon,
- vb.pp("final_layer_norm"),
- )?;
- Ok(Self {
- block,
- shared: shared.clone(),
- final_layer_norm,
- })
- }
-
- fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
- let input_embeds = self.shared.as_ref().forward(input_ids)?;
- let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
-
- let mut hidden_states = input_embeds;
- for block in self.block.iter() {
- hidden_states = block.forward(&hidden_states)?
- }
- let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
- Ok(hidden_states)
- }
-}
-
-#[derive(Debug)]
-pub struct T5EncoderModel {
- shared: Arc<Embedding>,
- encoder: T5Stack,
-}
-
-impl T5EncoderModel {
- pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
- let shared = Arc::new(shared);
- let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
- Ok(Self { shared, encoder })
- }
-
- pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
- let encoder_outputs = self.encoder.forward(input_ids)?;
- Ok(encoder_outputs)
- }
-}
diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md
new file mode 100644
index 00000000..1f6b99eb
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/README.md
@@ -0,0 +1,17 @@
+# candle-quantized-t5
+
+This example uses a quantized version of the T5 model.
+
+```bash
+$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
+...
+ Eine schöne Kerze.
+```
+
+The weight file is automatically retrieved from the hub. It is also possible to
+generate quantized weight files from the original safetensors file by using the
+`tensor-tools` command-line utility:
+
+```bash
+cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+```
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
new file mode 100644
index 00000000..93a86309
--- /dev/null
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -0,0 +1,214 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
+use candle_transformers::models::quantized_t5 as t5;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+ T5Small,
+ FlanT5Small,
+ FlanT5Base,
+ FlanT5Large,
+ FlanT5Xl,
+ FlanT5Xxl,
+}
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// The model repository to use on the HuggingFace hub.
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ #[arg(long)]
+ weight_file: Option<String>,
+
+ /// Disable the key-value cache used during decoding.
+ #[arg(long, default_value = "false")]
+ disable_cache: bool,
+
+ /// The prompt to be used for the generation.
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long, default_value_t = 0.8)]
+ temperature: f64,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+
+ /// The model size to use.
+ #[arg(long, default_value = "t5-small")]
+ which: Which,
+}
+
+struct T5ModelBuilder {
+ device: Device,
+ config: t5::Config,
+ weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+ pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+ let device = Device::Cpu;
+ let default_model = "lmz/candle-quantized-t5".to_string();
+ let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+ (Some(model_id), Some(revision)) => (model_id, revision),
+ (Some(model_id), None) => (model_id, "main".to_string()),
+ (None, Some(revision)) => (default_model, revision),
+ (None, None) => (default_model, "main".to_string()),
+ };
+
+ let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+ let api = Api::new()?;
+ let api = api.repo(repo);
+ let config_filename = match args.which {
+ Which::T5Small => api.get("config.json")?,
+ Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
+ Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
+ Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
+ Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
+ Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
+ };
+ let tokenizer_filename = api.get("tokenizer.json")?;
+ let weights_filename = match &args.weight_file {
+ Some(filename) => std::path::PathBuf::from(filename),
+ None => match args.which {
+ Which::T5Small => api.get("model.gguf")?,
+ Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
+ Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
+ Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
+ Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
+ Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
+ },
+ };
+ let config = std::fs::read_to_string(config_filename)?;
+ let mut config: t5::Config = serde_json::from_str(&config)?;
+ config.use_cache = !args.disable_cache;
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ Ok((
+ Self {
+ device,
+ config,
+ weights_filename,
+ },
+ tokenizer,
+ ))
+ }
+
+ pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
+ let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+ Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
+ }
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
+ let device = &builder.device;
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
+ let tokens = tokenizer
+ .encode(args.prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ let mut model = builder.build_model()?;
+ let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ let temperature = if args.temperature <= 0. {
+ None
+ } else {
+ Some(args.temperature)
+ };
+ let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
+ let encoder_output = model.encode(&input_token_ids)?;
+ let start = std::time::Instant::now();
+
+ for index in 0.. {
+ if output_token_ids.len() > 512 {
+ break;
+ }
+ let decoder_token_ids = if index == 0 || !builder.config.use_cache {
+ Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+ } else {
+ let last_token = *output_token_ids.last().unwrap();
+ Tensor::new(&[last_token], device)?.unsqueeze(0)?
+ };
+ let logits = model
+ .decode(&decoder_token_ids, &encoder_output)?
+ .squeeze(0)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &output_token_ids[start_at..],
+ )?
+ };
+
+ let next_token_id = logits_processor.sample(&logits)?;
+ if next_token_id as usize == builder.config.eos_token_id {
+ break;
+ }
+ output_token_ids.push(next_token_id);
+ if let Some(text) = tokenizer.id_to_token(next_token_id) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ print!("{text}");
+ std::io::stdout().flush()?;
+ }
+ }
+ let dt = start.elapsed();
+ println!(
+ "\n{} tokens generated ({:.2} token/s)\n",
+ output_token_ids.len(),
+ output_token_ids.len() as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+}
diff --git a/candle-examples/examples/quantized/README.md b/candle-examples/examples/quantized/README.md
new file mode 100644
index 00000000..bed09243
--- /dev/null
+++ b/candle-examples/examples/quantized/README.md
@@ -0,0 +1,37 @@
+# candle-quantized-llama: Fast Inference of quantized LLaMA models
+
+This example provides a quantized LLaMA model similar to
+[llama.cpp](https://github.com/ggerganov/llama.cpp), based on candle's
+built-in quantization methods. Supported features include:
+
+- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization.
+- SIMD optimizations on Apple Silicon and x86.
+- Support for the `gguf` and `ggml` file formats.
+
+The weights are automatically downloaded for you from the [HuggingFace
+Hub](https://huggingface.co/) on the first run. There are various command line
+flags to use local files instead, run with `--help` to learn about them.
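+
+For reference, loading a local `gguf` file programmatically boils down to the
+sketch below. This mirrors the example's own loading path, though the exact
+signatures may differ across candle versions and the function name is
+hypothetical:
+
+```rust
+use candle_transformers::models::quantized_llama::ModelWeights;
+
+fn load_gguf_model(path: &str) -> anyhow::Result<ModelWeights> {
+    let mut file = std::fs::File::open(path)?;
+    // Parse the gguf container, then build the quantized weights from its tensors.
+    let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+    Ok(ModelWeights::from_gguf(content, &mut file)?)
+}
+```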
+
+![Axiom of Choice](./assets/aoc.gif)
+
+## Running an example
+
+```bash
+cargo run --example quantized --release -- --prompt "The best thing about coding in rust is "
+
+> avx: true, neon: false, simd128: false, f16c: true
+> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
+> loaded 291 tensors (3.79GB) in 2.17s
+> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 }
+> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
+```
+
+## Command-line flags
+
+Run with `--help` to see all options.
+
+- `--which`: specify the model to use, e.g. `7b`, `13b-chat`, `7b-code`.
+- `--prompt interactive`: interactive mode where multiple prompts can be
+ entered.
+- `--model mymodelfile.gguf`: use a local model file rather than getting one
+ from the hub.
diff --git a/candle-examples/examples/quantized/assets/aoc.gif b/candle-examples/examples/quantized/assets/aoc.gif
new file mode 100644
index 00000000..686074af
--- /dev/null
+++ b/candle-examples/examples/quantized/assets/aoc.gif
Binary files differ
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a3f98d8e..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
-mod model;
+use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@@ -71,6 +71,10 @@ struct Args {
#[arg(long, default_value_t = 0.8)]
temperature: f64,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
- let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md
new file mode 100644
index 00000000..3c5b034f
--- /dev/null
+++ b/candle-examples/examples/segment-anything/README.md
@@ -0,0 +1,40 @@
+# candle-segment-anything: Segment-Anything Model
+
+This example is based on Meta AI [Segment-Anything
+Model](https://github.com/facebookresearch/segment-anything). This model
+provides a robust and fast image segmentation pipeline that can be steered via
+prompts: points required to be in the target mask, points required to be part
+of the background and thus _not_ in the target mask, or a bounding box.
+
+The default backbone can be replaced by the smaller and faster TinyViT model
+based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
+
+## Running an example
+
+```bash
+cargo run --example segment-anything --release -- \
+ --image candle-examples/examples/yolo-v8/assets/bike.jpg \
+ --use-tiny \
+ --point-x 0.4 \
+ --point-y 0.3
+```
+
+Running this command generates a `sam_merged.jpg` file containing the original
+image with a blue overlay of the selected mask. The red dot represents the prompt
+specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
+of the target mask.
+
+The values used for `--point-x` and `--point-y` should be between 0 and 1 and
+are proportional to the image dimensions, e.g. use 0.5 for the image center.
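+
+For reference, this is the normalized-to-pixel mapping the example applies when
+drawing the red dot; a tiny sketch (the helper name is hypothetical):
+
+```rust
+/// Convert normalized point coordinates (in 0..1) to pixel coordinates.
+fn to_pixel_coords(point_x: f64, point_y: f64, width: u32, height: u32) -> (i32, i32) {
+    (
+        (point_x * width as f64) as i32,
+        (point_y * height as f64) as i32,
+    )
+}
+```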
+
+![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
+
+![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
+
+### Command-line flags
+- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
+ one.
+- `--point-x`, `--point-y`: specifies the location of the target point.
+- `--threshold`: sets the threshold value for a point to be part of the mask;
+ a negative value results in a larger mask, e.g. `--threshold=-1.2`.
diff --git a/candle-examples/examples/segment-anything/assets/sam_merged.jpg b/candle-examples/examples/segment-anything/assets/sam_merged.jpg
new file mode 100644
index 00000000..a5f64e5e
--- /dev/null
+++ b/candle-examples/examples/segment-anything/assets/sam_merged.jpg
Binary files differ
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
new file mode 100644
index 00000000..3d9898b6
--- /dev/null
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -0,0 +1,164 @@
+//! SAM: Segment Anything Model
+//! https://github.com/facebookresearch/segment-anything
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segment_anything::sam;
+use clap::Parser;
+
+#[derive(Parser)]
+struct Args {
+ #[arg(long)]
+ model: Option<String>,
+
+ #[arg(long)]
+ image: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ #[arg(long)]
+ generate_masks: bool,
+
+ /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
+ #[arg(long, default_value_t = 0.5)]
+ point_x: f64,
+
+ /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
+ #[arg(long, default_value_t = 0.5)]
+ point_y: f64,
+
+ /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
+ /// mask, positive makes the mask more selective.
+ #[arg(long, default_value_t = 0.)]
+ threshold: f32,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// Use the TinyViT based models from MobileSAM
+ #[arg(long)]
+ use_tiny: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let (image, initial_h, initial_w) =
+ candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
+ let image = image.to_device(&device)?;
+ println!("loaded image {image:?}");
+
+ let model = match args.model {
+ Some(model) => std::path::PathBuf::from(model),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("lmz/candle-sam".to_string());
+ let filename = if args.use_tiny {
+ "mobile_sam-tiny-vitt.safetensors"
+ } else {
+ "sam_vit_b_01ec64.safetensors"
+ };
+ api.get(filename)?
+ }
+ };
+ let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
+ let weights = weights.deserialize()?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let sam = if args.use_tiny {
+ sam::Sam::new_tiny(vb)? // tiny vit_t
+ } else {
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ };
+
+ if args.generate_masks {
+ // Default options similar to the Python version.
+ let bboxes = sam.generate_masks(
+ &image,
+ /* points_per_side */ 32,
+ /* crop_n_layer */ 0,
+ /* crop_overlap_ratio */ 512. / 1500.,
+ /* crop_n_points_downscale_factor */ 1,
+ )?;
+ for (idx, bbox) in bboxes.iter().enumerate() {
+ println!("{idx} {bbox:?}");
+ let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
+ let (h, w) = mask.dims2()?;
+ let mask = mask.broadcast_as((3, h, w))?;
+ candle_examples::save_image_resize(
+ &mask,
+ format!("sam_mask{idx}.png"),
+ initial_h,
+ initial_w,
+ )?;
+ }
+ } else {
+ let point = Some((args.point_x, args.point_y));
+ let start_time = std::time::Instant::now();
+ let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+ println!(
+ "mask generated in {:.2}s",
+ start_time.elapsed().as_secs_f32()
+ );
+ println!("mask:\n{mask}");
+ println!("iou_predictions: {iou_predictions:?}");
+
+ let mask = (mask.ge(args.threshold)? * 255.)?;
+ let (_one, h, w) = mask.dims3()?;
+ let mask = mask.expand((3, h, w))?;
+
+ let mut img = image::io::Reader::open(&args.image)?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
+ let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
+ Some(image) => image,
+ None => anyhow::bail!("error saving merged image"),
+ };
+ let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
+ img.width(),
+ img.height(),
+ image::imageops::FilterType::CatmullRom,
+ );
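+ // Tint the masked pixels blue by boosting the blue channel and halving red and green.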
+ for x in 0..img.width() {
+ for y in 0..img.height() {
+ let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
+ if mask_p.0[0] > 100 {
+ let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
+ img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
+ img_p.0[1] /= 2;
+ img_p.0[0] /= 2;
+ imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
+ }
+ }
+ }
+ let (x, y) = (
+ (args.point_x * img.width() as f64) as i32,
+ (args.point_y * img.height() as f64) as i32,
+ );
+ imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
+ .save("sam_merged.jpg")?
+ }
+ Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md
new file mode 100644
index 00000000..ee83b3f9
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/README.md
@@ -0,0 +1,63 @@
+# candle-stable-diffusion: A Diffusers API in Rust/Candle
+
+![rusty robot holding a candle](./assets/stable-diffusion-xl.jpg)
+
+_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion
+XL using Rust and [candle](https://github.com/huggingface/candle).
+
+The `stable-diffusion` example is a conversion of
+[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
+rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
+as well as Stable Diffusion XL 1.0.
+
+## Getting the weights
+
+The weights are automatically downloaded for you from the [HuggingFace
+Hub](https://huggingface.co/) on the first run. There are various command line
+flags to use local files instead, run with `--help` to learn about them.
+
+## Running an example
+
+```bash
+cargo run --example stable-diffusion --release --features=cuda,cudnn \
+ -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
+```
+
+The final image is named `sd_final.png` by default.
+The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
+original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
+
+### Command-line flags
+
+- `--prompt`: the prompt to be used to generate the image.
+- `--uncond-prompt`: the optional unconditional prompt.
+- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
+ `xl`.
+- `--cpu`: use the CPU rather than the GPU (much slower).
+- `--height`, `--width`: set the height and width for the generated image.
+- `--n-steps`: the number of steps to be used in the diffusion process.
+- `--num-samples`: the number of samples to generate.
+- `--final-image`: the filename for the generated image(s).
+
+### Using flash-attention
+
+Using flash attention makes image generation a lot faster and uses less memory.
+The downside is a longer compilation time. You can set the
+`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
+`/home/user/.candle` to ensure that the compilation artifacts are properly
+cached.
+
+Enabling flash-attention requires both the `--features flash-attn` build flag
+and the `--use-flash-attn` command-line flag.
+
+## Image to Image Pipeline
+The example also supports an image-to-image mode: pass an initial image via
+`--img2img` and control how much of it gets transformed with
+`--img2img-strength`, a value between 0 and 1 where 1 discards the initial
+image information entirely.
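+
+Internally the strength maps to a starting timestep: the first denoising steps
+are skipped so that only the remaining ones transform the image. A minimal
+sketch of that computation, mirroring the code in this example's `main.rs`
+(the helper name is hypothetical):
+
+```rust
+/// Number of denoising steps skipped for a given img2img strength.
+fn start_step(n_steps: usize, img2img_strength: f64) -> usize {
+    // With 30 steps and a strength of 0.8, the first 6 steps are skipped and
+    // the remaining 24 run on top of the noised initial-image latents.
+    n_steps - (n_steps as f64 * img2img_strength) as usize
+}
+```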
+
+## FAQ
+
+### Memory Issues
+
+This example requires a GPU with more than 8GB of memory; as a fallback, the
+CPU version can be used with the `--cpu` flag but is much slower.
+Alternatively, reducing the height and width with the `--height` and `--width`
+flags is likely to reduce memory usage significantly.
diff --git a/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg b/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
new file mode 100644
index 00000000..a6f7b6c6
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Binary files differ
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 8372edcd..c8b771a0 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -4,20 +4,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-mod attention;
-mod clip;
-mod ddim;
-mod embeddings;
-mod resnet;
-mod schedulers;
-mod stable_diffusion;
-mod unet_2d;
-mod unet_2d_blocks;
-mod utils;
-mod vae;
+use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@@ -96,6 +86,15 @@ struct Args {
#[arg(long)]
use_f16: bool,
+
+ #[arg(long, value_name = "FILE")]
+ img2img: Option<String>,
+
+ /// The strength, indicates how much to transform the initial image. The
+ /// value must be between 0 and 1, a value of 1 discards the initial image
+ /// information.
+ #[arg(long, default_value_t = 0.8)]
+ img2img_strength: f64,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -306,6 +305,26 @@ fn text_embeddings(
Ok(text_embeddings)
}
+fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+ let img = image::io::Reader::open(path)?.decode()?;
+ let (height, width) = (img.height() as usize, img.width() as usize);
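+ // Round the dimensions down to a multiple of 32, as expected by the diffusion model.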
+ let height = height - height % 32;
+ let width = width - width % 32;
+ let img = img.resize_to_fill(
+ width as u32,
+ height as u32,
+ image::imageops::FilterType::CatmullRom,
+ );
+ let img = img.to_rgb8();
+ let img = img.into_raw();
+ let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+ .permute((2, 0, 1))?
+ .to_dtype(DType::F32)?
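+ // Map pixel values from [0, 255] to [-1, 1].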
+ .affine(2. / 255., -1.)?
+ .unsqueeze(0)?;
+ Ok(img)
+}
+
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@@ -328,9 +347,15 @@ fn run(args: Args) -> Result<()> {
tracing,
use_f16,
use_flash_attn,
+ img2img,
+ img2img_strength,
..
} = args;
+ if !(0. ..=1.).contains(&img2img_strength) {
+ anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
+ }
+
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@@ -382,25 +407,53 @@ fn run(args: Args) -> Result<()> {
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+ let init_latent_dist = match &img2img {
+ None => None,
+ Some(image) => {
+ let image = image_preprocess(image)?.to_device(&device)?;
+ Some(vae.encode(&image)?)
+ }
+ };
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
+ let t_start = if img2img.is_some() {
+ n_steps - (n_steps as f64 * img2img_strength) as usize
+ } else {
+ 0
+ };
let bsize = 1;
for idx in 0..num_samples {
- let mut latents = Tensor::randn(
- 0f32,
- 1f32,
- (bsize, 4, sd_config.height / 8, sd_config.width / 8),
- &device,
- )?
- .to_dtype(dtype)?;
-
- // scale the initial noise by the standard deviation required by the scheduler
- latents = (latents * scheduler.init_noise_sigma())?;
+ let timesteps = scheduler.timesteps();
+ let latents = match &init_latent_dist {
+ Some(init_latent_dist) => {
+ let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+ if t_start < timesteps.len() {
+ let noise = latents.randn_like(0f64, 1f64)?;
+ scheduler.add_noise(&latents, noise, timesteps[t_start])?
+ } else {
+ latents
+ }
+ }
+ None => {
+ let latents = Tensor::randn(
+ 0f32,
+ 1f32,
+ (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+ &device,
+ )?;
+ // scale the initial noise by the standard deviation required by the scheduler
+ (latents * scheduler.init_noise_sigma())?
+ }
+ };
+ let mut latents = latents.to_dtype(dtype)?;
println!("starting sampling");
- for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+ for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+ if timestep_index < t_start {
+ continue;
+ }
let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
new file mode 100644
index 00000000..6a406467
--- /dev/null
+++ b/candle-examples/examples/t5/README.md
@@ -0,0 +1,25 @@
+# candle-t5
+
+## Encoder-decoder example:
+
+```bash
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+...
+Running on CPU, to run on GPU, build this example with `--features cuda`
+ Eine schöne Kerze.
+9 tokens generated (2.42 token/s)
+```
+
+## Sentence embedding example:
+
+```bash
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
+...
+[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
+ [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
+ [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
+ [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
+ [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
+Tensor[[1, 5, 512], f32]
+Took 303.766583ms
+```
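+
+The sentence-similarity mode (run without `--prompt`) mean-pools these token
+embeddings and ranks sentence pairs by cosine similarity. A scalar sketch of
+the scoring, mirroring the tensor version in `main.rs` (the function name is
+hypothetical):
+
+```rust
+/// Cosine similarity between two pooled embedding vectors.
+fn cosine_similarity(e_i: &[f32], e_j: &[f32]) -> f32 {
+    let dot: f32 = e_i.iter().zip(e_j.iter()).map(|(a, b)| a * b).sum();
+    let norm_i: f32 = e_i.iter().map(|a| a * a).sum();
+    let norm_j: f32 = e_j.iter().map(|b| b * b).sum();
+    dot / (norm_i * norm_j).sqrt()
+}
+```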
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
new file mode 100644
index 00000000..55929c33
--- /dev/null
+++ b/candle-examples/examples/t5/main.rs
@@ -0,0 +1,314 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
+use candle_transformers::models::t5;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+const DTYPE: DType = DType::F32;
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// The model repository to use on the HuggingFace hub.
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ /// Enable decoding.
+ #[arg(long)]
+ decode: bool,
+
+ /// Disable the key-value cache.
+ #[arg(long, default_value = "false")]
+ disable_cache: bool,
+
+ /// Use this prompt, otherwise compute sentence similarities.
+ #[arg(long)]
+ prompt: Option<String>,
+
+ /// If set along with --decode, will use this prompt to initialize the decoder.
+ #[arg(long)]
+ decoder_prompt: Option<String>,
+
+ /// L2 normalization for embeddings.
+ #[arg(long, default_value = "true")]
+ normalize_embeddings: bool,
+
+ /// The temperature used to generate samples.
+ #[arg(long, default_value_t = 0.8)]
+ temperature: f64,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+struct T5ModelBuilder {
+ device: Device,
+ config: t5::Config,
+ weights_filename: Vec<PathBuf>,
+}
+
+impl T5ModelBuilder {
+ pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+ let device = candle_examples::device(args.cpu)?;
+ let default_model = "t5-small".to_string();
+ let default_revision = "refs/pr/15".to_string();
+ let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
+ (Some(model_id), Some(revision)) => (model_id, revision),
+ (Some(model_id), None) => (model_id, "main".to_string()),
+ (None, Some(revision)) => (default_model, revision),
+ (None, None) => (default_model, default_revision),
+ };
+
+ let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
+ let api = Api::new()?;
+ let api = api.repo(repo);
+ let config_filename = api.get("config.json")?;
+ let tokenizer_filename = api.get("tokenizer.json")?;
+ let weights_filename = if model_id == "google/flan-t5-xxl" {
+ vec![
+ api.get("model-00001-of-00005.safetensors")?,
+ api.get("model-00002-of-00005.safetensors")?,
+ api.get("model-00003-of-00005.safetensors")?,
+ api.get("model-00004-of-00005.safetensors")?,
+ api.get("model-00005-of-00005.safetensors")?,
+ ]
+ } else {
+ vec![api.get("model.safetensors")?]
+ };
+ let config = std::fs::read_to_string(config_filename)?;
+ let mut config: t5::Config = serde_json::from_str(&config)?;
+ config.use_cache = !args.disable_cache;
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ Ok((
+ Self {
+ device,
+ config,
+ weights_filename,
+ },
+ tokenizer,
+ ))
+ }
+
+ pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
+ let weights = self
+ .weights_filename
+ .iter()
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
+ .collect::<candle::Result<Vec<_>>>()?;
+ let weights = weights
+ .iter()
+ .map(|w| w.deserialize())
+ .collect::<candle::Result<Vec<_>>>()?;
+ let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ Ok(t5::T5EncoderModel::load(vb, &self.config)?)
+ }
+
+ pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
+ let weights = self
+ .weights_filename
+ .iter()
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
+ .collect::<candle::Result<Vec<_>>>()?;
+ let weights = weights
+ .iter()
+ .map(|w| w.deserialize())
+ .collect::<candle::Result<Vec<_>>>()?;
+ let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
+ }
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
+ let device = &builder.device;
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
+ match args.prompt {
+ Some(prompt) => {
+ let tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ if !args.decode {
+ let mut model = builder.build_encoder()?;
+ let start = std::time::Instant::now();
+ let ys = model.forward(&input_token_ids)?;
+ println!("{ys}");
+ println!("Took {:?}", start.elapsed());
+ } else {
+ let mut model = builder.build_conditional_generation()?;
+ let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ if let Some(decoder_prompt) = &args.decoder_prompt {
+ print!("{decoder_prompt}");
+ output_token_ids.extend(
+ tokenizer
+ .encode(decoder_prompt.to_string(), false)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec(),
+ );
+ }
+ let temperature = if args.temperature <= 0. {
+ None
+ } else {
+ Some(args.temperature)
+ };
+ let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
+ let encoder_output = model.encode(&input_token_ids)?;
+ let start = std::time::Instant::now();
+
+ for index in 0.. {
+ if output_token_ids.len() > 512 {
+ break;
+ }
+ let decoder_token_ids = if index == 0 || !builder.config.use_cache {
+ Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+ } else {
+ let last_token = *output_token_ids.last().unwrap();
+ Tensor::new(&[last_token], device)?.unsqueeze(0)?
+ };
+ let logits = model
+ .decode(&decoder_token_ids, &encoder_output)?
+ .squeeze(0)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &output_token_ids[start_at..],
+ )?
+ };
+
+ let next_token_id = logits_processor.sample(&logits)?;
+ if next_token_id as usize == builder.config.eos_token_id {
+ break;
+ }
+ output_token_ids.push(next_token_id);
+ if let Some(text) = tokenizer.id_to_token(next_token_id) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ print!("{text}");
+ std::io::stdout().flush()?;
+ }
+ }
+ let dt = start.elapsed();
+ println!(
+ "\n{} tokens generated ({:.2} token/s)\n",
+ output_token_ids.len(),
+ output_token_ids.len() as f64 / dt.as_secs_f64(),
+ );
+ }
+ }
+ None => {
+ let mut model = builder.build_encoder()?;
+ let sentences = [
+ "The cat sits outside",
+ "A man is playing guitar",
+ "I love pasta",
+ "The new movie is awesome",
+ "The cat plays in the garden",
+ "A woman watches TV",
+ "The new movie is so great",
+ "Do you like pizza?",
+ ];
+ let n_sentences = sentences.len();
+ let mut all_embeddings = Vec::with_capacity(n_sentences);
+ for sentence in sentences {
+ let tokens = tokenizer
+ .encode(sentence, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
+ let embeddings = model.forward(&token_ids)?;
+ println!("generated embeddings {:?}", embeddings.shape());
+ // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+ let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+ let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+ let embeddings = if args.normalize_embeddings {
+ normalize_l2(&embeddings)?
+ } else {
+ embeddings
+ };
+ println!("pooled embeddings {:?}", embeddings.shape());
+ all_embeddings.push(embeddings)
+ }
+
+ let mut similarities = vec![];
+ for (i, e_i) in all_embeddings.iter().enumerate() {
+ for (j, e_j) in all_embeddings
+ .iter()
+ .enumerate()
+ .take(n_sentences)
+ .skip(i + 1)
+ {
+ let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
+ let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+ similarities.push((cosine_similarity, i, j))
+ }
+ }
+ similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+ for &(score, i, j) in similarities[..5].iter() {
+ println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+ }
+ }
+ }
+ Ok(())
+}
+
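+/// L2-normalize each row of the tensor along dimension 1.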
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-examples/examples/whisper/README.md b/candle-examples/examples/whisper/README.md
new file mode 100644
index 00000000..124cd182
--- /dev/null
+++ b/candle-examples/examples/whisper/README.md
@@ -0,0 +1,39 @@
+# candle-whisper: speech recognition
+
+An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
+candle. Whisper is a general purpose speech recognition model; it can be used to
+convert audio files (in the `.wav` format) to text. Supported features include
+language detection as well as multilingual speech recognition.
+
+## Running an example
+
+If no audio file is passed as input, a [sample
+file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
+from the hub.
+
+```bash
+cargo run --example whisper --release
+
+> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
+> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
+> pcm data loaded 176000
+> loaded mel: [1, 80, 3000]
+> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
+```
+
+In order to use the multilingual mode, specify a multilingual model via the
+`--model` flag; see the details below.
+
+## Command-line flags
+
+- `--input`: the audio file to be converted to text, in wav format.
+- `--language`: force the language to some specific value rather than being
+ detected, e.g. `en`.
+- `--task`: the task to be performed, can be `transcribe` (return the text data
+ in the original language) or `translate` (translate the text to English).
+- `--timestamps`: enable the timestamp mode where timestamps are reported for
+ each recognized audio extract.
+- `--model`: the model to be used. Models whose names end in `.en` are
+ English-only models; the others are multilingual. The supported models are
+ `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
+ `medium.en`, `large`, and `large-v2`.
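+
+For reference, the transcription pipeline first converts the 16kHz pcm samples
+into a log-mel spectrogram before running the model. A minimal sketch of that
+step, mirroring `main.rs` (the helper name is hypothetical and signatures may
+differ across candle versions):
+
+```rust
+use candle::{Device, Result, Tensor};
+use candle_transformers::models::whisper::{self as m, audio};
+
+/// Turn raw 16 kHz pcm samples into the (1, N_MELS, n_frames) model input.
+fn pcm_to_mel_tensor(pcm: &[f32], mel_filters: &[f32], device: &Device) -> Result<Tensor> {
+    let mel = audio::pcm_to_mel(pcm, mel_filters);
+    let n_frames = mel.len() / m::N_MELS;
+    Tensor::from_vec(mel, (1, m::N_MELS, n_frames), device)
+}
+```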
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 5dd8ee20..c71d562a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
-mod audio;
-mod model;
-use model::{Config, Whisper};
mod multilingual;
-
-const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};
#[allow(dead_code)]
#[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
timestamps: bool,
verbose: bool,
) -> Result<Self> {
- let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+ let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
- let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
- let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
- let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
- let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
- let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+ let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+ let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+ let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+ let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+ let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
- for (i, &t) in TEMPERATURES.iter().enumerate() {
+ for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
- if i == TEMPERATURES.len() - 1 {
+ if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
- let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
- || dr.avg_logprob < LOGPROB_THRESHOLD;
- if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+ let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
@@ -248,13 +223,13 @@ impl Decoder {
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
- let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
- let segment_size = usize::min(content_frames - seek, N_FRAMES);
+ let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
- let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
- if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+ if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
@@ -431,7 +406,6 @@ fn main() -> Result<()> {
let args = Args::parse();
let _guard = if args.tracing {
- println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
@@ -493,8 +467,8 @@ fn main() -> Result<()> {
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
- if header.sampling_rate != SAMPLE_RATE as u32 {
- anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+ if header.sampling_rate != m::SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -502,14 +476,14 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
- let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
- let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+ let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index bc0bae1f..a82b09ef 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
- let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+ let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
diff --git a/candle-examples/examples/wuerstchen/README.md b/candle-examples/examples/wuerstchen/README.md
new file mode 100644
index 00000000..1b8accd1
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/README.md
@@ -0,0 +1,27 @@
+# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models
+
+![anthropomorphic cat dressed as a fire fighter](./assets/cat.jpg)
+
+The `wuerstchen` example is a port of the [diffusers
+implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
+The candle implementation reproduces the same structure/files for models and
+pipelines. Useful resources:
+
+- [Official implementation](https://github.com/dome272/Wuerstchen).
+- [Arxiv paper](https://arxiv.org/abs/2306.00637).
+- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).
+
+## Getting the weights
+
+The weights are automatically downloaded for you from the [HuggingFace
+Hub](https://huggingface.co/) on the first run. There are various command line
+flags to use local files instead, run with `--help` to learn about them.
+
+## Running an example
+
+```bash
+cargo run --example wuerstchen --release --features cuda,cudnn -- \
+ --prompt "Anthropomorphic cat dressed as a fire fighter"
+```
+
+The final image is named `sd_final.png` by default.
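+
+Würstchen's efficiency comes from working in a strongly compressed latent
+space: the prior operates on latents roughly 42x smaller than the output
+resolution, which the decoder stage then re-expands. A sketch of the dimension
+arithmetic used in this example (the helper name is hypothetical):
+
+```rust
+const RESOLUTION_MULTIPLE: f64 = 42.67;
+const LATENT_DIM_SCALE: f64 = 10.67;
+
+/// For a 1024x1024 image, the prior works on 24x24 latents, which the decoder
+/// stage expands to 256x256 before the VQGAN produces the final pixels.
+fn latent_dims(height: usize, width: usize) -> ((usize, usize), (usize, usize)) {
+    let prior_h = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let prior_w = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let dec_h = (prior_h as f64 * LATENT_DIM_SCALE) as usize;
+    let dec_w = (prior_w as f64 * LATENT_DIM_SCALE) as usize;
+    ((prior_h, prior_w), (dec_h, dec_w))
+}
+```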
diff --git a/candle-examples/examples/wuerstchen/assets/cat.jpg b/candle-examples/examples/wuerstchen/assets/cat.jpg
new file mode 100644
index 00000000..9ff67183
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/assets/cat.jpg
Binary files differ
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
new file mode 100644
index 00000000..95f3b8f4
--- /dev/null
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -0,0 +1,396 @@
+#![allow(unused)]
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use candle_transformers::models::stable_diffusion;
+use candle_transformers::models::wuerstchen;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
+use clap::Parser;
+use tokenizers::Tokenizer;
+
+const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
+const RESOLUTION_MULTIPLE: f64 = 42.67;
+const LATENT_DIM_SCALE: f64 = 10.67;
+const PRIOR_CIN: usize = 16;
+const DECODER_CIN: usize = 4;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// The prompt to be used for image generation.
+ #[arg(
+ long,
+ default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+ )]
+ prompt: String,
+
+ #[arg(long, default_value = "")]
+ uncond_prompt: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ use_flash_attn: bool,
+
+ /// The height in pixels of the generated image.
+ #[arg(long)]
+ height: Option<usize>,
+
+ /// The width in pixels of the generated image.
+ #[arg(long)]
+ width: Option<usize>,
+
+ /// The decoder weight file, in .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ decoder_weights: Option<String>,
+
+ /// The CLIP weight file, in .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ clip_weights: Option<String>,
+
+ /// The CLIP weight file used by the prior model, in .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ prior_clip_weights: Option<String>,
+
+ /// The prior weight file, in .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ prior_weights: Option<String>,
+
+ /// The VQGAN weight file, in .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ vqgan_weights: Option<String>,
+
+ #[arg(long, value_name = "FILE")]
+ /// The file specifying the tokenizer to be used for tokenization.
+ tokenizer: Option<String>,
+
+ #[arg(long, value_name = "FILE")]
+ /// The file specifying the tokenizer to be used for prior tokenization.
+ prior_tokenizer: Option<String>,
+
+ /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
+ #[arg(long)]
+ sliced_attention_size: Option<usize>,
+
+ /// The number of steps to run the diffusion for.
+ #[arg(long, default_value_t = 30)]
+ n_steps: usize,
+
+ /// The number of samples to generate.
+ #[arg(long, default_value_t = 1)]
+ num_samples: i64,
+
+ /// The name of the final image to generate.
+ #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
+ final_image: String,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ModelFile {
+ Tokenizer,
+ PriorTokenizer,
+ Clip,
+ PriorClip,
+ Decoder,
+ VqGan,
+ Prior,
+}
+
+impl ModelFile {
+ fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> {
+ use hf_hub::api::sync::Api;
+ match filename {
+ Some(filename) => Ok(std::path::PathBuf::from(filename)),
+ None => {
+ let repo_main = "warp-ai/wuerstchen";
+ let repo_prior = "warp-ai/wuerstchen-prior";
+ let (repo, path) = match self {
+ Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
+ Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
+ Self::Clip => (repo_main, "text_encoder/model.safetensors"),
+ Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
+ Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
+ Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
+ Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
+ };
+ let filename = Api::new()?.model(repo.to_string()).get(path)?;
+ Ok(filename)
+ }
+ }
+ }
+}
+
+fn output_filename(
+ basename: &str,
+ sample_idx: i64,
+ num_samples: i64,
+ timestep_idx: Option<usize>,
+) -> String {
+ let filename = if num_samples > 1 {
+ match basename.rsplit_once('.') {
+ None => format!("{basename}.{sample_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}.{sample_idx}.{extension}")
+ }
+ }
+ } else {
+ basename.to_string()
+ };
+ match timestep_idx {
+ None => filename,
+ Some(timestep_idx) => match filename.rsplit_once('.') {
+ None => format!("{filename}-{timestep_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}-{timestep_idx}.{extension}")
+ }
+ },
+ }
+}
+
+fn encode_prompt(
+ prompt: &str,
+ uncond_prompt: Option<&str>,
+ tokenizer: std::path::PathBuf,
+ clip_weights: std::path::PathBuf,
+ clip_config: stable_diffusion::clip::Config,
+ device: &Device,
+) -> Result<Tensor> {
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &clip_config.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let tokens_len = tokens.len();
+ while tokens.len() < clip_config.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ println!("Building the clip transformer.");
+ let text_model =
+ stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
+ let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
+ match uncond_prompt {
+ None => Ok(text_embeddings),
+ Some(uncond_prompt) => {
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let uncond_tokens_len = uncond_tokens.len();
+ while uncond_tokens.len() < clip_config.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ let uncond_embeddings =
+ text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+ let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
+ Ok(text_embeddings)
+ }
+ }
+}
+
+fn run(args: Args) -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let Args {
+ prompt,
+ uncond_prompt,
+ cpu,
+ height,
+ width,
+ n_steps,
+ tokenizer,
+ final_image,
+ sliced_attention_size,
+ num_samples,
+ clip_weights,
+ prior_weights,
+ vqgan_weights,
+ decoder_weights,
+ tracing,
+ ..
+ } = args;
+
+ let _guard = if tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ let device = candle_examples::device(cpu)?;
+ let height = height.unwrap_or(1024);
+ let width = width.unwrap_or(1024);
+
+ let prior_text_embeddings = {
+ let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+ let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+ encode_prompt(
+ &prompt,
+ Some(&uncond_prompt),
+ tokenizer.clone(),
+ weights,
+ stable_diffusion::clip::Config::wuerstchen_prior(),
+ &device,
+ )?
+ };
+ println!("generated prior text embeddings {prior_text_embeddings:?}");
+
+ let text_embeddings = {
+ let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
+ let weights = ModelFile::Clip.get(clip_weights)?;
+ encode_prompt(
+ &prompt,
+ None,
+ tokenizer.clone(),
+ weights,
+ stable_diffusion::clip::Config::wuerstchen(),
+ &device,
+ )?
+ };
+ println!("generated text embeddings {text_embeddings:?}");
+
+ println!("Building the prior.");
+ let b_size = 1;
+ let image_embeddings = {
+ // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
+ let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+ let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+ let mut latents = Tensor::randn(
+ 0f32,
+ 1f32,
+ (b_size, PRIOR_CIN, latent_height, latent_width),
+ &device,
+ )?;
+
+ let prior = {
+ let prior_weights = ModelFile::Prior.get(prior_weights)?;
+ let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
+ let weights = weights.deserialize()?;
+ let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ wuerstchen::prior::WPrior::new(
+ /* c_in */ PRIOR_CIN,
+ /* c */ 1536,
+ /* c_cond */ 1280,
+ /* c_r */ 64,
+ /* depth */ 32,
+ /* nhead */ 24,
+ args.use_flash_attn,
+ vb,
+ )?
+ };
+ let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+ let timesteps = prior_scheduler.timesteps();
+ let timesteps = &timesteps[..timesteps.len() - 1];
+ println!("prior denoising");
+ for (index, &t) in timesteps.iter().enumerate() {
+ let start_time = std::time::Instant::now();
+ let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+ let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
+ let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
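+ // Classifier-free guidance: push the prediction away from the unconditional output.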
+ let noise_pred = (noise_pred_uncond
+ + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
+ latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+ let dt = start_time.elapsed().as_secs_f32();
+ println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+ }
+ ((latents * 42.)? - 1.)?
+ };
+
+ println!("Building the vqgan.");
+ let vqgan = {
+ let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
+ let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
+ let weights = weights.deserialize()?;
+ let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ wuerstchen::paella_vq::PaellaVQ::new(vb)?
+ };
+
+ println!("Building the decoder.");
+
+ // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
+ let decoder = {
+ let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
+ let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
+ let weights = weights.deserialize()?;
+ let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ wuerstchen::diffnext::WDiffNeXt::new(
+ /* c_in */ DECODER_CIN,
+ /* c_out */ DECODER_CIN,
+ /* c_r */ 64,
+ /* c_cond */ 1024,
+ /* clip_embd */ 1024,
+ /* patch_size */ 2,
+ args.use_flash_attn,
+ vb,
+ )?
+ };
+
+ for idx in 0..num_samples {
+ // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
+ let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
+ let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
+
+ let mut latents = Tensor::randn(
+ 0f32,
+ 1f32,
+ (b_size, DECODER_CIN, latent_height, latent_width),
+ &device,
+ )?;
+
+ println!("diffusion process with prior {image_embeddings:?}");
+ let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
+ let timesteps = scheduler.timesteps();
+ let timesteps = &timesteps[..timesteps.len() - 1];
+ for (index, &t) in timesteps.iter().enumerate() {
+ let start_time = std::time::Instant::now();
+ let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
+ let noise_pred =
+ decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
+ latents = scheduler.step(&noise_pred, t, &latents)?;
+ let dt = start_time.elapsed().as_secs_f32();
+ println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+ }
+ println!(
+ "Generating the final image for sample {}/{}.",
+ idx + 1,
+ num_samples
+ );
+ let image = vqgan.decode(&(&latents * 0.3764)?)?;
+ // TODO: Add the clamping between 0 and 1.
+ let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
+ let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+ candle_examples::save_image(&image, image_filename)?
+ }
+ Ok(())
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ run(args)
+}
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 5e388921..ecf75bdf 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_examples::object_detection::{non_maximum_suppression, Bbox};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;
@@ -46,7 +46,7 @@ pub fn report(
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
- let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+ let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
@@ -65,7 +65,7 @@ pub fn report(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints: vec![],
+ data: (),
};
bboxes[class_index].push(bbox)
}
diff --git a/candle-examples/examples/yolo-v8/README.md b/candle-examples/examples/yolo-v8/README.md
new file mode 100644
index 00000000..938dea13
--- /dev/null
+++ b/candle-examples/examples/yolo-v8/README.md
@@ -0,0 +1,47 @@
+# candle-yolo-v8: Object Detection and Pose Estimation
+
+This is a port of [Ultralytics
+YOLOv8](https://github.com/ultralytics/ultralytics). The implementation is based
+on the [tinygrad
+version](https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py)
+and on the model architecture described in this
+[issue](https://github.com/ultralytics/ultralytics/issues/189). The supported
+tasks are object detection and pose estimation.
+
+You can try this model online on the [Candle YOLOv8
+Space](https://huggingface.co/spaces/lmz/candle-yolo). The model then runs fully
+in your browser using WebAssembly; if you use a custom image, it never leaves
+your phone/computer!
+
+## Running some examples
+
+### Object Detection
+```bash
+cargo run --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
+```
+
+This prints details about the detected objects and generates a `bike.pp.jpg` file.
+
+![Leading group, Giro d'Italia 2021](./assets/bike.jpg)
+
+Image source:
+[wikimedia](https://commons.wikimedia.org/wiki/File:Leading_group,_Giro_d%27Italia_2021,_Stage_15.jpg).
+
+![Leading group, Giro d'Italia 2021](./assets/bike.od.jpg)
+
+### Pose Estimation
+```bash
+cargo run --example yolo-v8 --release -- \
+ candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
+```
+
+![Leading group, Giro d'Italia 2021](./assets/bike.pose.jpg)
+
+### Command-line flags
+
+- `--which`: select the model variant to be used: `n`, `s`, `m`, `l`, or `x`,
+ in increasing size and quality.
+- `--task`: `detect` for object detection and `pose` for pose estimation.
+- `--legend-size`: the size of the characters to print.
+- `--model`: use a local model file rather than downloading it from the hub.
+
diff --git a/candle-examples/examples/yolo-v8/assets/bike.jpg b/candle-examples/examples/yolo-v8/assets/bike.jpg
new file mode 100644
index 00000000..05d1faaf
--- /dev/null
+++ b/candle-examples/examples/yolo-v8/assets/bike.jpg
Binary files differ
diff --git a/candle-examples/examples/yolo-v8/assets/bike.od.jpg b/candle-examples/examples/yolo-v8/assets/bike.od.jpg
new file mode 100644
index 00000000..111b9286
--- /dev/null
+++ b/candle-examples/examples/yolo-v8/assets/bike.od.jpg
Binary files differ
diff --git a/candle-examples/examples/yolo-v8/assets/bike.pose.jpg b/candle-examples/examples/yolo-v8/assets/bike.pose.jpg
new file mode 100644
index 00000000..e660f65b
--- /dev/null
+++ b/candle-examples/examples/yolo-v8/assets/bike.pose.jpg
Binary files differ
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index d5c5ac1c..d48bac35 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
-use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;
@@ -64,7 +64,7 @@ pub fn report_detect(
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
- let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+ let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@@ -83,7 +83,7 @@ pub fn report_detect(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints: vec![],
+ data: vec![],
};
bboxes[class_index].push(bbox)
}
@@ -176,7 +176,7 @@ pub fn report_pose(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints,
+ data: keypoints,
};
bboxes.push(bbox)
}
@@ -204,7 +204,7 @@ pub fn report_pose(
image::Rgb([255, 0, 0]),
);
}
- for kp in b.keypoints.iter() {
+ for kp in b.data.iter() {
if kp.mask < 0.6 {
continue;
}
@@ -219,8 +219,8 @@ pub fn report_pose(
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
- let kp1 = &b.keypoints[idx1];
- let kp2 = &b.keypoints[idx2];
+ let kp1 = &b.data[idx1];
+ let kp2 = &b.data[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue;
}
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 395162eb..5e0b44fb 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1,6 +1,5 @@
pub mod coco_classes;
pub mod imagenet;
-pub mod object_detection;
use candle::{Device, Result, Tensor};
@@ -16,6 +15,36 @@ pub fn device(cpu: bool) -> Result<Device> {
}
}
+pub fn load_image<P: AsRef<std::path::Path>>(
+ p: P,
+ resize_longest: Option<usize>,
+) -> Result<(Tensor, usize, usize)> {
+ let img = image::io::Reader::open(p)?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
+ let img = match resize_longest {
+ None => img,
+ Some(resize_longest) => {
+ let (height, width) = (img.height(), img.width());
+ let resize_longest = resize_longest as u32;
+ let (height, width) = if height < width {
+ let h = (resize_longest * height) / width;
+ (h, resize_longest)
+ } else {
+ let w = (resize_longest * width) / height;
+ (resize_longest, w)
+ };
+ img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
+ }
+ };
+ let (height, width) = (img.height() as usize, img.width() as usize);
+ let img = img.to_rgb8();
+ let data = img.into_raw();
+ let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+ Ok((data, initial_h, initial_w))
+}
+
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,
@@ -35,20 +64,44 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
}
/// Saves an image to disk using the image crate, this expects an input with shape
-/// (c, width, height).
+/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
let p = p.as_ref();
- let (channel, width, height) = img.dims3()?;
+ let (channel, height, width) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, height, width)")
+ }
+ let img = img.permute((1, 2, 0))?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
+
+pub fn save_image_resize<P: AsRef<std::path::Path>>(
+ img: &Tensor,
+ p: P,
+ h: usize,
+ w: usize,
+) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, height, width) = img.dims3()?;
if channel != 3 {
- candle::bail!("save_image expects an input of shape (3, width, height)")
+ candle::bail!("save_image expects an input of shape (3, height, width)")
}
- let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+ let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
+ let image = image::DynamicImage::from(image);
+ let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
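A hedged usage sketch for the helpers above: `load_image` yields `(3, h, w)` u8 data plus the pre-resize dimensions, and `save_image_resize` expects that same CHW layout. The file paths are illustrative.

```rust
// Round-trip an image through the new helpers, resizing on load and
// restoring the original size on save.
use candle::Result;

fn roundtrip() -> Result<()> {
    // Resize so that the longest side is 640 pixels.
    let (img, initial_h, initial_w) = candle_examples::load_image("bike.jpg", Some(640))?;
    println!("loaded {initial_h}x{initial_w} -> {:?}", img.dims());
    // Write back at the original size via the resizing variant added above.
    candle_examples::save_image_resize(&img, "bike.copy.jpg", initial_h, initial_w)?;
    Ok(())
}
```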
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 0d130519..808e0070 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.2.3", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.2.1", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.2.3", features = ["cuda"] }
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 773c5638..64275fda 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
"flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu",
- // "flash_fwd_hdim128_bf16_sm80.cu",
- // "flash_fwd_hdim160_bf16_sm80.cu",
- // "flash_fwd_hdim192_bf16_sm80.cu",
- // "flash_fwd_hdim224_bf16_sm80.cu",
- // "flash_fwd_hdim256_bf16_sm80.cu",
- // "flash_fwd_hdim32_bf16_sm80.cu",
- // "flash_fwd_hdim64_bf16_sm80.cu",
- // "flash_fwd_hdim96_bf16_sm80.cu",
+ "flash_fwd_hdim128_bf16_sm80.cu",
+ "flash_fwd_hdim160_bf16_sm80.cu",
+ "flash_fwd_hdim192_bf16_sm80.cu",
+ "flash_fwd_hdim224_bf16_sm80.cu",
+ "flash_fwd_hdim256_bf16_sm80.cu",
+ "flash_fwd_hdim32_bf16_sm80.cu",
+ "flash_fwd_hdim64_bf16_sm80.cu",
+ "flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
@@ -57,9 +57,20 @@ fn main() -> Result<()> {
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
- Ok(build_dir) => PathBuf::from(build_dir),
+ Ok(build_dir) => {
+ let path = PathBuf::from(build_dir);
+ path.canonicalize().expect(&format!(
+ "Directory doesn't exists: {} (the current directory is {})",
+ &path.display(),
+ std::env::current_dir()?.display()
+ ))
+ }
};
set_cuda_include_dir()?;
+
+ let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+ println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
+
let compute_cap = compute_cap()?;
let out_file = build_dir.join("libflashattention.a");
@@ -95,14 +106,21 @@ fn main() -> Result<()> {
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
- .arg(cu_file);
+ .arg("--verbose");
+ if let Ok(ccbin_path) = &ccbin_env {
+ command
+ .arg("-allow-unsupported-compiler")
+ .args(["-ccbin", ccbin_path]);
+ }
+ command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
- "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ &command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
@@ -122,7 +140,8 @@ fn main() -> Result<()> {
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
- "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ &command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index d928bcb6..72991257 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,20 +1,19 @@
#include "flash_fwd_launch_template.h"
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
- FWD_HEADDIM_SWITCH(params.d, [&] {
- run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
- });
-}
-
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-// FP16_SWITCH(!params.is_bf16, [&] {
-// FWD_HEADDIM_SWITCH(params.d, [&] {
-// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-// });
+// FWD_HEADDIM_SWITCH(params.d, [&] {
+// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
// });
// }
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+ run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+ });
+ });
+}
+
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
- int is_causal
+ int is_causal,
+ int is_bf16
) {
Flash_fwd_params params;
// Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
- params.is_bf16 = 0;
+ params.is_bf16 = is_bf16;
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index ae61c405..90f34e43 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -38,6 +38,7 @@ extern "C" {
seqlen_k_rounded: u32,
is_causal: c_int,
+ is_bf16: c_int,
);
}
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 3c5fd455..61980a58 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
pub struct FlashAttn {
pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
(x + m - 1) / m * m
}
-impl candle::CustomOp3 for FlashAttn {
- fn name(&self) -> &'static str {
- "flash-attn"
- }
-
- fn cpu_fwd(
- &self,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- ) -> Result<(CpuStorage, Shape)> {
- candle::bail!("no cpu support for flash-attn")
- }
-
- fn cuda_fwd(
+impl FlashAttn {
+ fn cuda_fwd_t<
+ T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+ >(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
@@ -40,15 +26,16 @@ impl candle::CustomOp3 for FlashAttn {
k_l: &Layout,
v: &candle::CudaStorage,
v_l: &Layout,
+ is_bf16: bool,
) -> Result<(candle::CudaStorage, Shape)> {
// https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
let dev = q.device();
let out_shape = q_l.shape().clone();
let out_l = Layout::contiguous(&out_shape);
- let q = q.as_cuda_slice::<f16>()?;
- let k = k.as_cuda_slice::<f16>()?;
- let v = v.as_cuda_slice::<f16>()?;
+ let q = q.as_cuda_slice::<T>()?;
+ let k = k.as_cuda_slice::<T>()?;
+ let v = v.as_cuda_slice::<T>()?;
let q = q.slice(q_l.start_offset()..);
let k = k.slice(k_l.start_offset()..);
let v = v.slice(v_l.start_offset()..);
@@ -104,10 +91,11 @@ impl candle::CustomOp3 for FlashAttn {
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
- let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+ let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let causal = if self.causal { 1 } else { 0 };
+ let is_bf16 = if is_bf16 { 1 } else { 0 };
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -146,6 +134,7 @@ impl candle::CustomOp3 for FlashAttn {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
+ /* is_bf16 */ is_bf16,
)
}
@@ -154,6 +143,40 @@ impl candle::CustomOp3 for FlashAttn {
}
}
+impl candle::CustomOp3 for FlashAttn {
+ fn name(&self) -> &'static str {
+ "flash-attn"
+ }
+
+ fn cpu_fwd(
+ &self,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ candle::bail!("no cpu support for flash-attn")
+ }
+
+ fn cuda_fwd(
+ &self,
+ q: &candle::CudaStorage,
+ q_l: &Layout,
+ k: &candle::CudaStorage,
+ k_l: &Layout,
+ v: &candle::CudaStorage,
+ v_l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ match q.dtype() {
+ candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+ candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
+ dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+ }
+ }
+}
+
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
@@ -190,24 +213,10 @@ struct FlashAttnVarLen {
seqlens_k: Tensor,
}
-impl candle::CustomOp3 for FlashAttnVarLen {
- fn name(&self) -> &'static str {
- "flash-attn-varlen"
- }
-
- fn cpu_fwd(
- &self,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- ) -> Result<(CpuStorage, Shape)> {
- candle::bail!("no cpu support for flash-attn")
- }
-
- fn cuda_fwd(
+impl FlashAttnVarLen {
+ fn cuda_fwd_t<
+ T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+ >(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
@@ -215,6 +224,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
k_l: &Layout,
v: &candle::CudaStorage,
v_l: &Layout,
+ is_bf16: bool,
) -> Result<(candle::CudaStorage, Shape)> {
// https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
let dev = q.device();
@@ -314,6 +324,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
.w()?;
let causal = if self.causal { 1 } else { 0 };
+ let is_bf16 = if is_bf16 { 1 } else { 0 };
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -354,6 +365,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
+ /* is_bf16 */ is_bf16,
)
}
@@ -362,6 +374,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
}
}
+impl candle::CustomOp3 for FlashAttnVarLen {
+ fn name(&self) -> &'static str {
+ "flash-attn-varlen"
+ }
+
+ fn cpu_fwd(
+ &self,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ candle::bail!("no cpu support for flash-attn")
+ }
+
+ fn cuda_fwd(
+ &self,
+ q: &candle::CudaStorage,
+ q_l: &Layout,
+ k: &candle::CudaStorage,
+ k_l: &Layout,
+ v: &candle::CudaStorage,
+ v_l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ match q.dtype() {
+ candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+ candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
+ dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+ }
+ }
+}
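With the dtype dispatch above in place, bf16 tensors can be fed straight through the crate's public entry point. A hedged usage sketch, assuming the usual `flash_attn(q, k, v, softmax_scale, causal)` signature and `(batch, seqlen, nheads, headdim)` layout, on a CUDA device:

```rust
// bf16 flash attention via the dispatch added in this diff; `dev` must be a
// CUDA device, since the cpu_fwd path bails out.
use candle::{DType, Device, Result, Tensor};

fn bf16_attention(dev: &Device) -> Result<Tensor> {
    let q = Tensor::randn(0f32, 1., (1, 128, 8, 64), dev)?.to_dtype(DType::BF16)?;
    let (k, v) = (q.clone(), q.clone());
    let scale = 1f32 / (64f32).sqrt(); // 1 / sqrt(head_dim)
    candle_flash_attn::flash_attn(&q, &k, &v, scale, /* causal */ true)
}
```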
+
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index 576c52ea..80b6aaab 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
-version = "0.2.1"
+version = "0.2.3"
edition = "2021"
description = "CUDA kernels for Candle"
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index 3c8e96a9..ad084671 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -164,6 +164,8 @@ mod cuda {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+ let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+ println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
@@ -188,8 +190,13 @@ mod cuda {
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
- .args(&include_options)
- .arg(p);
+ .args(&include_options);
+ if let Ok(ccbin_path) = &ccbin_env {
+ command
+ .arg("-allow-unsupported-compiler")
+ .args(["-ccbin", ccbin_path]);
+ }
+ command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}})
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index ab2045a3..ee20fe5f 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -77,20 +77,30 @@ CAST_OP(double, __half, cast_f64_f16)
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
+CAST_OP(uint32_t, int64_t, cast_u32_i64 )
CAST_OP(uint32_t, float, cast_u32_f32)
CAST_OP(uint32_t, double, cast_u32_f64)
CAST_OP(uint8_t, uint32_t, cast_u8_u32)
CAST_OP(uint8_t, uint8_t, cast_u8_u8 )
+CAST_OP(uint8_t, int64_t, cast_u8_i64 )
CAST_OP(uint8_t, float, cast_u8_f32)
CAST_OP(uint8_t, double, cast_u8_f64)
+CAST_OP(int64_t, uint32_t, cast_i64_u32)
+CAST_OP(int64_t, uint8_t, cast_i64_u8 )
+CAST_OP(int64_t, int64_t, cast_i64_i64 )
+CAST_OP(int64_t, float, cast_i64_f32)
+CAST_OP(int64_t, double, cast_i64_f64)
+
CAST_OP(float, uint8_t, cast_f32_u8 )
CAST_OP(float, uint32_t, cast_f32_u32)
+CAST_OP(float, int64_t, cast_f32_i64 )
CAST_OP(float, float, cast_f32_f32)
CAST_OP(float, double, cast_f32_f64)
CAST_OP(double, uint8_t, cast_f64_u8 )
CAST_OP(double, uint32_t, cast_f64_u32)
+CAST_OP(double, int64_t, cast_f64_i64 )
CAST_OP(double, float, cast_f64_f32)
CAST_OP(double, double, cast_f64_f64)
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index ba2fa1ad..9c8ce00f 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -51,6 +51,118 @@ __device__ void conv1d(
dst[dst_i] = static_cast<T>(d);
}
+template <typename T>
+__device__ void im2col1d(
+ const size_t dst_numel,
+ const size_t l_out,
+ const size_t l_k,
+ const size_t stride,
+ const size_t padding,
+ const size_t dilation,
+ const size_t *info,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // dst: (b_size, l_out, c_in, l_k)
+ // src: (b_size, c_in, l_in)
+ if (dst_i >= dst_numel) {
+ return;
+ }
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 3;
+ const size_t b_in = src_dims[0];
+ const size_t c_in = src_dims[1];
+ const size_t l_in = src_dims[2];
+
+ const size_t dst_s2 = l_k;
+ const size_t dst_s1 = c_in * dst_s2;
+ const size_t dst_s0 = l_out * dst_s1;
+
+ size_t tmp_dst_i = dst_i;
+ const size_t b_idx = tmp_dst_i / dst_s0;
+ tmp_dst_i -= b_idx * dst_s0;
+ const size_t l_idx = tmp_dst_i / dst_s1;
+ tmp_dst_i -= l_idx * dst_s1;
+ const size_t c_idx = tmp_dst_i / dst_s2;
+ tmp_dst_i -= c_idx * dst_s2;
+ const size_t l_k_idx = tmp_dst_i;
+ size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
+ if (src_l_idx < padding || src_l_idx >= l_in + padding) {
+ dst[dst_i] = static_cast<T>(0);
+ }
+ else {
+ src_l_idx -= padding;
+ const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
+ dst[dst_i] = src[src_i];
+ }
+}
+
+template <typename T>
+__device__ void im2col(
+ const size_t dst_numel,
+ const size_t h_out,
+ const size_t w_out,
+ const size_t h_k,
+ const size_t w_k,
+ const size_t stride,
+ const size_t padding,
+ const size_t dilation,
+ const size_t *info,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
+ // src: (b_size, c_in, h_in, w_in)
+ if (dst_i >= dst_numel) {
+ return;
+ }
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+ const size_t b_in = src_dims[0];
+ const size_t c_in = src_dims[1];
+ const size_t h_in = src_dims[2];
+ const size_t w_in = src_dims[3];
+
+ const size_t dst_s4 = w_k;
+ const size_t dst_s3 = h_k * dst_s4;
+ const size_t dst_s2 = c_in * dst_s3;
+ const size_t dst_s1 = w_out * dst_s2;
+ const size_t dst_s0 = h_out * dst_s1;
+
+ size_t tmp_dst_i = dst_i;
+ const size_t b_idx = tmp_dst_i / dst_s0;
+ tmp_dst_i -= b_idx * dst_s0;
+ const size_t h_idx = tmp_dst_i / dst_s1;
+ tmp_dst_i -= h_idx * dst_s1;
+ const size_t w_idx = tmp_dst_i / dst_s2;
+ tmp_dst_i -= w_idx * dst_s2;
+ const size_t c_idx = tmp_dst_i / dst_s3;
+ tmp_dst_i -= c_idx * dst_s3;
+ const size_t h_k_idx = tmp_dst_i / dst_s4;
+ tmp_dst_i -= h_k_idx * dst_s4;
+ const size_t w_k_idx = tmp_dst_i;
+ size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
+ size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
+ if (src_h_idx < padding || src_h_idx >= h_in + padding) {
+ dst[dst_i] = static_cast<T>(0);
+ }
+ else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
+ dst[dst_i] = static_cast<T>(0);
+ }
+ else {
+ src_h_idx -= padding;
+ src_w_idx -= padding;
+ const size_t src_i =
+ b_idx * src_s[0]
+ + c_idx * src_s[1]
+ + src_h_idx * src_s[2]
+ + src_w_idx * src_s[3];
+ dst[dst_i] = src[src_i];
+ }
+}
+
// Naive implementation of conv2d.
template <typename T, typename A>
__device__ void conv2d(
@@ -363,6 +475,38 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
+#define IM2COL1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t dst_numel, \
+ const size_t l_out, \
+ const size_t l_k, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t dilation, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
+} \
+
+#define IM2COL_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t dst_numel, \
+ const size_t h_out, \
+ const size_t w_out, \
+ const size_t h_k, \
+ const size_t w_k, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t dilation, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
+} \
+
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@@ -428,6 +572,8 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
+IM2COL_OP(__nv_bfloat16, im2col_bf16)
+IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@@ -437,6 +583,8 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
+IM2COL_OP(__half, im2col_f16)
+IM2COL1D_OP(__half, im2col1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -468,3 +616,13 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
+
+IM2COL_OP(float, im2col_f32)
+IM2COL_OP(double, im2col_f64)
+IM2COL_OP(uint8_t, im2col_u8)
+IM2COL_OP(uint32_t, im2col_u32)
+
+IM2COL1D_OP(float, im2col1d_f32)
+IM2COL1D_OP(double, im2col1d_f64)
+IM2COL1D_OP(uint8_t, im2col1d_u8)
+IM2COL1D_OP(uint32_t, im2col1d_u32)
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 4096d2d1..8e46a07c 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
+__device__ __forceinline__ float erfg(float a) { return erff(a); }
+__device__ __forceinline__ double erfg(double a) { return erf(a); }
+__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
+__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
@@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
+__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
+__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
@@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a);
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 271502c5..fca6865e 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
+// Softmax implementation adapted from ggml.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
+template <typename T, typename ACC>
+__device__ void softmax(const T * x, T * dst, const int ncols) {
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+ const int block_size = blockDim.y;
+ const int tid = threadIdx.y;
+
+ T max_val = -INFINITY;
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const int i = row*ncols + col;
+ max_val = maxg(max_val, x[i]);
+ }
+
+ // find the max value in the block
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+ }
+
+ ACC tmp = 0.;
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const int i = row*ncols + col;
+ const T val = expg(x[i] - max_val);
+ tmp += static_cast<ACC>(val);
+ dst[i] = val;
+ }
+
+ // sum up partial sums
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ }
+
+ const ACC inv_tmp = 1. / tmp;
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const int i = row*ncols + col;
+ dst[i] *= inv_tmp;
+ }
+}
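For reference when validating the kernel above, a minimal scalar sketch of the same max-subtracted, numerically stable softmax applied row by row:

```rust
// Reference softmax over each contiguous row of `ncols` elements:
// subtract the row max, exponentiate, then normalize by the row sum.
fn softmax_rows(x: &mut [f32], ncols: usize) {
    for row in x.chunks_mut(ncols) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}
```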
+
template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
} \
}
+#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
+ extern "C" __global__ void FN_NAME( \
+ const TYPENAME *src, TYPENAME *dst, \
+ const int n_cols) { \
+ softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
+ } \
+
#if __CUDA_ARCH__ >= 800
+SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
#if __CUDA_ARCH__ >= 530
+SOFTMAX_OP(__half, float, softmax_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa
SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
+SOFTMAX_OP(float, float, softmax_f32)
+SOFTMAX_OP(double, double, softmax_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index c6142a03..105d8c3a 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -29,6 +29,11 @@ extern "C" __global__ void FN_NAME( \
} \
template<typename T>
+__device__ __forceinline__ T gelu_erf_fwd(T x) {
+ return x * normcdfg(x);
+}
+
+template<typename T>
__device__ __forceinline__ T gelu_fwd(T x) {
T x_sq = x * x;
T x_cube = x_sq * x;
@@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
+UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
+UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
+UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
@@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
+UNARY_OP(__half, uerf_f16, erfg(x))
+UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
+UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP1(__half, upowf_f16, powg(x, param))
@@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
+UNARY_OP(float, uerf_f32, erfg(x))
+UNARY_OP(double, uerf_f64, erfg(x))
+UNARY_OP(float, unormcdf_f32, normcdfg(x))
+UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
@@ -139,6 +154,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(double, usqrt_f64, sqrtg(x))
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
+UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
+UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index aa055583..a6629d33 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -11,13 +11,18 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
safetensors = { workspace = true }
+serde = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
+clap = { workspace = true }
[features]
default = []
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
new file mode 100644
index 00000000..204a7109
--- /dev/null
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -0,0 +1,302 @@
+/// This example contains some simple benchmarks so that it's easy to run them under perf etc.
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::quantized::GgmlType;
+use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
+use clap::{Parser, Subcommand};
+
+const CHECK_CONV2D: bool = false;
+
+trait Benchmark {
+ type PreProcessData;
+ type RunResult;
+
+ fn preprocess() -> Result<Self::PreProcessData>;
+ fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
+
+ const ITERS: usize;
+}
+
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
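As a sanity check, `hw_out` above is the standard convolution output-size formula. Instantiated with the shapes used by the `Conv2d` benchmark later in this file (h = w = 96, h_k = w_k = 3, stride = 1, padding = 0, dilation = 1):

$$h_\text{out} = \left\lfloor \frac{h + 2p - d\,(k - 1) - 1}{s} \right\rfloor + 1 = \left\lfloor \frac{96 + 0 - 2 - 1}{1} \right\rfloor + 1 = 94$$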
+
+impl candle::CustomOp1 for Im2Col {
+ fn name(&self) -> &'static str {
+ "im2col"
+ }
+
+ fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+ let &Self {
+ h_k,
+ w_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, h, w) = layout.shape().dims4()?;
+ let (h_out, w_out) = self.hw_out(h, w);
+ let slice = storage.as_slice::<f32>()?;
+ let src = &slice[layout.start_offset()..];
+ let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k];
+ let (src_s0, src_s1, src_s2, src_s3) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2], s[3])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - h_k = w_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+ for h_idx in 0..h_out {
+ let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+ for w_idx in 0..w_out {
+ let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * h_k * w_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for h_k_idx in 0..h_k {
+ let src_h = h_idx * stride + h_k_idx * dilation;
+ if padding != 0 && (src_h < padding || src_h >= h + padding) {
+ continue;
+ }
+ let src_h = src_h - padding;
+ let src_idx = src_idx + src_h * src_s2;
+ let dst_idx = dst_idx + h_k_idx * w_k;
+ for w_k_idx in 0..w_k {
+ let src_w = w_idx * stride + w_k_idx * dilation;
+ if padding != 0 && (src_w < padding || src_w >= w + padding) {
+ continue;
+ }
+ let src_w = src_w - padding;
+ let src_idx = src_idx + src_w * src_s3;
+ let dst_idx = dst_idx + w_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ }
+ }
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, (b * h_out * w_out, c * h_k * w_k).into()))
+ }
+}
+
+// Conv1d example as used in whisper.
+struct Conv1d;
+impl Benchmark for Conv1d {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
+ Ok((inp, w))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ d.0.conv1d(&d.1, 0, 1, 1, 1)
+ }
+
+ const ITERS: usize = 5;
+}
+
+// Conv2d example as used in stable-diffusion.
+struct Conv2d;
+impl Benchmark for Conv2d {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
+ Ok((inp, w))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ d.0.conv2d(&d.1, 0, 1, 1, 1)
+ }
+
+ const ITERS: usize = 5;
+}
+
+// Conv2d example as used in stable-diffusion, im2col implementation.
+struct Conv2dIm2Col;
+impl Benchmark for Conv2dIm2Col {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
+ Ok((inp, w))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ // d.0.conv2d(&d.1, 0, 1, 1, 1)
+ let (b, _, h, w) = d.0.dims4()?;
+ let (_, _, h_k, w_k) = d.1.dims4()?;
+ let op = Im2Col {
+ h_k,
+ w_k,
+ stride: 1,
+ dilation: 1,
+ padding: 0,
+ };
+ let (h_out, w_out) = op.hw_out(h, w);
+ let col = d.0.apply_op1_no_bwd(&op)?;
+ let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
+ let res = res
+ .reshape((b, h_out, w_out, ()))?
+ .permute((0, 3, 1, 2))?
+ .contiguous()?;
+ if CHECK_CONV2D {
+ let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1);
+ let diff = (&res - res2)?.sqr()?.mean_all()?;
+ println!("{diff}");
+ }
+ Ok(res)
+ }
+
+ const ITERS: usize = 5;
+}
+
+struct Matmul;
+impl Benchmark for Matmul {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+ let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+ Ok((lhs, rhs))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ d.0.matmul(&d.1)
+ }
+
+ const ITERS: usize = 100;
+}
+
+// This benchmark is similar to:
+// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
+struct QMatMul;
+impl Benchmark for QMatMul {
+ type PreProcessData = (candle::quantized::QMatMul, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
+ let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+ let mm = candle::quantized::QMatMul::from_qtensor(mm);
+ let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
+ Ok((mm, arg))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ d.0.forward(&d.1)
+ }
+
+ const ITERS: usize = 100;
+}
+
+struct Softmax;
+impl Benchmark for Softmax {
+ type PreProcessData = Tensor;
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ // Typical whisper tiny size.
+ let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
+ Ok(x)
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ candle_nn::ops::softmax(d, D::Minus1)
+ }
+
+ const ITERS: usize = 100;
+}
+
+struct SoftmaxLastDim;
+impl Benchmark for SoftmaxLastDim {
+ type PreProcessData = Tensor;
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ // Typical whisper tiny size.
+ let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
+ Ok(x)
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ candle_nn::ops::softmax_last_dim(d)
+ }
+
+ const ITERS: usize = 100;
+}
+
+fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
+ use std::hint::black_box;
+
+ let iters = iters.unwrap_or(B::ITERS);
+ let d = B::preprocess()?;
+ let start = std::time::Instant::now();
+ for _iter in 0..iters {
+ let _res = black_box(B::run_one(black_box(&d))?);
+ }
+ println!("{:?}", start.elapsed() / iters as u32);
+ Ok(())
+}
+
+#[derive(Subcommand, Debug, Clone)]
+enum Task {
+ Conv1d,
+ Conv2d,
+ Conv2dIm2Col,
+ Matmul,
+ Qmatmul,
+ Softmax,
+ SoftmaxLastDim,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+ /// The benchmark to be run.
+ #[command(subcommand)]
+ task: Task,
+
+ #[arg(long)]
+ iters: Option<usize>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ match args.task {
+ Task::Conv1d => run::<Conv1d>(args.iters)?,
+ Task::Conv2d => run::<Conv2d>(args.iters)?,
+ Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?,
+ Task::Matmul => run::<Matmul>(args.iters)?,
+ Task::Softmax => run::<Softmax>(args.iters)?,
+ Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
+ Task::Qmatmul => run::<QMatMul>(args.iters)?,
+ }
+ Ok(())
+}
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 0db3edc9..17467b31 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -1,18 +1,29 @@
use candle::Tensor;
+use serde::Deserialize;
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
pub enum Activation {
+ #[default]
Gelu,
+ #[serde(rename = "gated-gelu")]
+ NewGelu,
Relu,
Elu(f64),
+ LeakyRelu(f64),
}
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
+ // TODO: This is "gelu_new", not the original "gelu".
+ // There's some small numerical difference:
+ // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+ Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),
&Self::Elu(alpha) => xs.elu(alpha),
+ &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}
}
}
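The `serde` attributes above let an activation string from a model config deserialize directly into the enum. A hedged sketch; `serde_json` is used purely for illustration here and is not a dependency added by this diff:

```rust
// "gated-gelu" is explicitly renamed to NewGelu above; the other unit
// variants follow the lowercase rename rule.
use candle_nn::Activation;

fn parse() -> serde_json::Result<()> {
    let act: Activation = serde_json::from_str("\"gated-gelu\"")?;
    assert_eq!(act, Activation::NewGelu);
    let act: Activation = serde_json::from_str("\"relu\"")?;
    assert_eq!(act, Activation::Relu);
    Ok(())
}
```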
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 2dac0758..27ef15f7 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct BatchNorm {
running_mean: Tensor,
running_var: Tensor,
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index dbf23aa5..89e9f42d 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
}
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Conv1d {
weight: Tensor,
bias: Option<Tensor>,
@@ -39,6 +39,14 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
+
+ pub fn weight(&self) -> &Tensor {
+ &self.weight
+ }
+
+ pub fn bias(&self) -> Option<&Tensor> {
+ self.bias.as_ref()
+ }
}
impl crate::Module for Conv1d {
@@ -80,8 +88,7 @@ impl Default for Conv2dConfig {
}
}
-#[allow(dead_code)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Conv2d {
weight: Tensor,
bias: Option<Tensor>,
@@ -100,6 +107,14 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
+
+ pub fn weight(&self) -> &Tensor {
+ &self.weight
+ }
+
+ pub fn bias(&self) -> Option<&Tensor> {
+ self.bias.as_ref()
+ }
}
impl crate::Module for Conv2d {
@@ -122,15 +137,76 @@ impl crate::Module for Conv2d {
}
}
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ConvTranspose2dConfig {
+ pub padding: usize,
+ pub output_padding: usize,
+ pub stride: usize,
+ pub dilation: usize,
+ // TODO: support groups.
+}
+
+impl Default for ConvTranspose2dConfig {
+ fn default() -> Self {
+ Self {
+ padding: 0,
+ output_padding: 0,
+ stride: 1,
+ dilation: 1,
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ConvTranspose2d {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ config: ConvTranspose2dConfig,
+}
+
+impl ConvTranspose2d {
+ pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
+ Self {
+ weight,
+ bias,
+ config,
+ }
+ }
+
+ pub fn config(&self) -> &ConvTranspose2dConfig {
+ &self.config
+ }
+}
+
+impl crate::Module for ConvTranspose2d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.conv_transpose2d(
+ &self.weight,
+ self.config.padding,
+ self.config.output_padding,
+ self.config.stride,
+ self.config.dilation,
+ )?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1, 1))?;
+ Ok(x.broadcast_add(&bias)?)
+ }
+ }
+ }
+}
+
pub fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
- vs: crate::VarBuilder,
+ vb: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints(
+ let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
@@ -140,7 +216,7 @@ pub fn conv1d(
lo: -bound,
up: bound,
};
- let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+ let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv1d::new(ws, Some(bs), cfg))
}
@@ -149,10 +225,10 @@ pub fn conv2d(
out_channels: usize,
kernel_size: usize,
cfg: Conv2dConfig,
- vs: crate::VarBuilder,
+ vb: crate::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints(
+ let ws = vb.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
@@ -167,7 +243,7 @@ pub fn conv2d(
lo: -bound,
up: bound,
};
- let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+ let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv2d::new(ws, Some(bs), cfg))
}
@@ -176,10 +252,10 @@ pub fn conv2d_no_bias(
out_channels: usize,
kernel_size: usize,
cfg: Conv2dConfig,
- vs: crate::VarBuilder,
+ vb: crate::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
- let ws = vs.get_with_hints(
+ let ws = vb.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
@@ -191,3 +267,44 @@ pub fn conv2d_no_bias(
)?;
Ok(Conv2d::new(ws, None, cfg))
}
+
+pub fn conv_transpose2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: ConvTranspose2dConfig,
+ vb: crate::VarBuilder,
+) -> Result<ConvTranspose2d> {
+ let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
+ let init = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let ws = vb.get_with_hints(
+ (in_channels, out_channels, kernel_size, kernel_size),
+ "weight",
+ init,
+ )?;
+ let bs = vb.get_with_hints(out_channels, "bias", init)?;
+ Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
+}
+
+pub fn conv_transpose2d_no_bias(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: ConvTranspose2dConfig,
+ vb: crate::VarBuilder,
+) -> Result<ConvTranspose2d> {
+ let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
+ let init = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let ws = vb.get_with_hints(
+ (in_channels, out_channels, kernel_size, kernel_size),
+ "weight",
+ init,
+ )?;
+ Ok(ConvTranspose2d::new(ws, None, cfg))
+}
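The builders above follow the usual transposed-convolution sizing; assuming the standard PyTorch-style relation, the output height is

$$h_\text{out} = (h_\text{in} - 1)\,s - 2p + d\,(k - 1) + \text{output\_padding} + 1$$

so the common upsample-by-two setup ($h_\text{in} = 32$, $k = 4$, $s = 2$, $p = 1$, $d = 1$, output padding $0$) gives $h_\text{out} = 62 - 2 + 3 + 0 + 1 = 64$.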
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index d84f9f53..52968bc2 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -1,7 +1,7 @@
//! Embedding Layer.
use candle::{Result, Tensor};
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Embedding {
embeddings: Tensor,
hidden_size: usize,
@@ -18,6 +18,11 @@ impl Embedding {
pub fn embeddings(&self) -> &Tensor {
&self.embeddings
}
+
+ /// Get the hidden size of the embedding matrix
+ pub fn hidden_size(&self) -> usize {
+ self.hidden_size
+ }
}
impl crate::Module for Embedding {
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index eb1b889f..5b80b970 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -4,7 +4,7 @@
use candle::{DType, Result, Tensor};
// This group norm version handles both weight and bias so it removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct GroupNorm {
weight: Tensor,
bias: Tensor,
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 08e2f628..7617fc6c 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
}
// This layer norm version handles both weight and bias so it removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Option<Tensor>,
@@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
- let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+ let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = if self.remove_mean {
- let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
- let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+ let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
match &self.bias {
@@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
}
/// RmsNorm is a specialized version of the LayerNorm module.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
impl RmsNorm {
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 6e268f4e..8e5580df 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -16,7 +16,10 @@ pub mod var_map;
pub use activation::Activation;
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
-pub use conv::{conv1d, conv2d, conv2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
+pub use conv::{
+ conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
+ Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
+};
pub use embedding::{embedding, Embedding};
pub use func::{func, Func};
pub use group_norm::{group_norm, GroupNorm};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 7028f68c..94632296 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -19,7 +19,7 @@
//! ```
use candle::{Result, Tensor};
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
@@ -41,8 +41,9 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let w = match x.dims() {
- &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+ let w = match *x.dims() {
+ [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
+ [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index cddf278e..72451f83 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -1,6 +1,6 @@
use candle::{Result, Tensor};
-/// The negative loss likelihodd loss.
+/// The negative log likelihood loss.
///
/// Arguments
///
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c3b6ffa2..32de1af9 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the elements so that those on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -43,6 +44,11 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
(xs.neg()?.exp()? + 1.0)?.recip()
}
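+/// leaky_relu(x) = max(0, x) + negative_slope * min(0, x): the identity for x >= 0,
+/// and x scaled by negative_slope otherwise.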
+pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
+ let zeros = xs.zeros_like()?;
+ xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
+}
+
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
// This implementation is inefficient as it stores the full mask for the backward pass.
// Instead we could just store the seed and have a specialized kernel that would both
@@ -77,3 +83,149 @@ impl Dropout {
}
}
}
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+ fn name(&self) -> &'static str {
+ "softmax-last-dim"
+ }
+
+ fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+ fn softmax<T: candle::WithDType + num_traits::Float>(
+ src: &[T],
+ layout: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ let src = match layout.contiguous_offsets() {
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => &src[o1..o2],
+ };
+ let el_count = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1];
+ let mut dst = vec![T::zero(); el_count];
+ src.par_chunks(dim_m1)
+ .zip(dst.par_chunks_mut(dim_m1))
+ .for_each(|(src, dst)| {
+ let mut max = T::neg_infinity();
+ unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
+ for (s, d) in src.iter().zip(dst.iter_mut()) {
+ *d = (*s - max).exp();
+ }
+ let mut sum_exp = T::zero();
+ unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
+ for d in dst.iter_mut() {
+ *d /= sum_exp
+ }
+ });
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, Shape::from_dims(dims)))
+ }
+
+ match storage {
+ CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+ CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+ CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+ CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+ _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+ }
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ storage: &candle::CudaStorage,
+ layout: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+ };
+ use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ struct S;
+ impl Map1 for S {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let src = match layout.contiguous_offsets() {
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let el = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1];
+ let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+ let cfg = LaunchConfig {
+ grid_dim: (n_rows as u32, 1, 1),
+ block_dim: (1, 32, 1),
+ shared_mem_bytes: 0,
+ };
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+ let params = (src, &dst, n_cols as i32);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+ }
+
+ use candle::backend::BackendStorage;
+ let dev = storage.device();
+ let slice = S.map(&storage.slice, dev, layout)?;
+ let dst = candle::cuda_backend::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, layout.shape().clone()))
+ }
+}
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+ xs.apply_op1_no_bwd(&SoftmaxLastDim)
+}
+
+// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
+pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
+ let (b_size, c, h, w) = xs.dims4()?;
+ let out_c = c / upscale_factor / upscale_factor;
+ xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
+ .permute((0, 1, 4, 2, 5, 3))?
+ .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
+}
+
+pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
+ let (b_size, c, h, w) = xs.dims4()?;
+ let out_c = c * downscale_factor * downscale_factor;
+ xs.reshape((
+ b_size,
+ c,
+ h / downscale_factor,
+ downscale_factor,
+ w / downscale_factor,
+ downscale_factor,
+ ))?
+ .permute((0, 1, 3, 5, 2, 4))?
+ .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
+}
+
+// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
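+// Pads the spatial borders by replicating the edge rows/columns; only pad sizes 0 and 1 are handled here.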
+pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
+ match pad {
+ 0 => Ok(xs.clone()),
+ 1 => {
+ let (_b_size, _c, h, w) = xs.dims4()?;
+ let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
+ let xs = Tensor::cat(&[&first, xs, &last], 3)?;
+ let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
+ Tensor::cat(&[&first, &xs, &last], 2)
+ }
+ n => candle::bail!("replication-pad with a size of {n} is not supported"),
+ }
+}
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index d52a9082..18a4a71c 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -85,7 +85,7 @@ impl LSTMConfig {
///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct LSTM {
w_ih: Tensor,
w_hh: Tensor,
@@ -235,7 +235,7 @@ impl GRUConfig {
///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct GRU {
w_ih: Tensor,
w_hh: Tensor,
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index bf5d5b43..4ccbaf17 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -5,14 +5,14 @@ use crate::VarMap;
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
-use std::rc::Rc;
+use std::sync::Arc;
/// A structure used to retrieve variables, these variables can either come from storage or be
/// generated via some form of initialization.
///
/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
pub struct VarBuilderArgs<'a, B: Backend> {
- data: Rc<TensorData<B>>,
+ data: Arc<TensorData<B>>,
path: Vec<String>,
_phantom: std::marker::PhantomData<&'a B>,
}
@@ -43,7 +43,7 @@ struct TensorData<B: Backend> {
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
/// of the time. The main restriction is that it doesn't allow for specific args (besides
/// initialization hints).
-pub trait Backend {
+pub trait Backend: Send + Sync {
type Hints: Default;
/// Retrieve a tensor with some target shape.
@@ -59,7 +59,7 @@ pub trait Backend {
fn contains_tensor(&self, name: &str) -> bool;
}
-pub trait SimpleBackend {
+pub trait SimpleBackend: Send + Sync {
/// Retrieve a tensor based on a target name and shape.
fn get(
&self,
@@ -99,7 +99,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
device: dev.clone(),
};
Self {
- data: Rc::new(data),
+ data: Arc::new(data),
path: vec![],
_phantom: std::marker::PhantomData,
}
@@ -333,7 +333,7 @@ impl<'a> VarBuilder<'a> {
device,
};
Self {
- data: Rc::new(data),
+ data: Arc::new(data),
path: vec![],
_phantom: std::marker::PhantomData,
}
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 209fc10a..5bbaf238 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -59,8 +59,8 @@ fn batch_norm() -> Result<()> {
);
let bn2 = BatchNorm::new(
5,
- running_mean.clone(),
- running_var.clone(),
+ running_mean,
+ running_var,
Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
1e-8,
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 4ba8cfcc..5ca01b37 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
+ let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
+ assert_eq!(
+ to_vec3_round(&t2, 4)?,
+ &[
+ // (3, 1, 4) / 8, (1, 5, 9) / 15
+ [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+ // (2, 1, 7) / 10, (8, 2, 8) / 18
+ [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+ ]
+ );
Ok(())
}
diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore
new file mode 100644
index 00000000..68bc17f9
--- /dev/null
+++ b/candle-pyo3/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 97631b0a..7fd0ac28 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -12,11 +12,10 @@ readme = "README.md"
[lib]
name = "candle"
crate-type = ["cdylib"]
-doc = false
[dependencies]
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md
index 07dff468..be6d4f68 100644
--- a/candle-pyo3/README.md
+++ b/candle-pyo3/README.md
@@ -1,7 +1,26 @@
+## Installation
+
From the `candle-pyo3` directory, activate a virtual environment where you want the
candle package to be installed, then run:
```bash
-maturin develop
+maturin develop -r
python test.py
```
+
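+Once installed, a quick smoke test (a minimal sketch using the tensor API exposed by
+these bindings):
+```python
+import candle
+
+t = candle.tensor([[1.0, 2.0], [3.0, 4.0]])
+print(t.shape)          # (2, 2)
+print(t.matmul(t.t()))  # plain matrix product
+```
+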
+## Generating Stub Files for Type Hinting
+
+For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script.
+
+### Steps:
+1. Install the package using `maturin`.
+2. Generate the stub files by running:
+ ```
+ python stub.py
+ ```
+
+### Validation:
+To ensure that the stub files match the current implementation, execute:
+```
+python stub.py --check
+```
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
new file mode 100644
index 00000000..951609cc
--- /dev/null
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -0,0 +1,5 @@
+from .candle import *
+
+__doc__ = candle.__doc__
+if hasattr(candle, "__all__"):
+ __all__ = candle.__all__ \ No newline at end of file
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
new file mode 100644
index 00000000..414f0bc4
--- /dev/null
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -0,0 +1,375 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+
+class bf16(DType):
+ pass
+
+@staticmethod
+def cat(tensors: List[Tensor], dim: int) -> Tensor:
+ """
+ Concatenate the tensors across one axis.
+ """
+ pass
+
+class f16(DType):
+ pass
+
+class f32(DType):
+ pass
+
+class f64(DType):
+ pass
+
+class i64(DType):
+ pass
+
+@staticmethod
+def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor filled with ones.
+ """
+ pass
+
+@staticmethod
+def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor with random values.
+ """
+ pass
+
+@staticmethod
+def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor with random values from a normal distribution.
+ """
+ pass
+
+@staticmethod
+def stack(tensors: List[Tensor], dim: int) -> Tensor:
+ """
+ Stack the tensors along a new axis.
+ """
+ pass
+
+@staticmethod
+def tensor(data: _ArrayLike) -> Tensor:
+ """
+ Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+ """
+ pass
+
+class u32(DType):
+ pass
+
+class u8(DType):
+ pass
+
+@staticmethod
+def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor filled with zeros.
+ """
+ pass
+
+class DType:
+ """
+ A `candle` dtype.
+ """
+
+class QTensor:
+ """
+ A quantized tensor.
+ """
+
+ def dequantize(self) -> Tensor:
+ """
+ Dequantizes the tensor.
+ """
+ pass
+ @property
+ def ggml_dtype(self) -> str:
+ """
+ Gets the tensor's quantized dtype.
+ """
+ pass
+ def matmul_t(self, lhs: Tensor) -> Tensor:
+ """
+ Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
+ """
+ pass
+ @property
+ def rank(self) -> int:
+ """
+ Gets the rank of the tensor.
+ """
+ pass
+ @property
+ def shape(self) -> Tuple[int]:
+ """
+ Gets the shape of the tensor.
+ """
+ pass
+
+class Tensor:
+ """
+ A `candle` tensor.
+ """
+
+ def __init__(self, data: _ArrayLike):
+ pass
+ def argmax_keepdim(self, dim: int) -> Tensor:
+ """
+ Returns the indices of the maximum value(s) across the selected dimension.
+ """
+ pass
+ def argmin_keepdim(self, dim: int) -> Tensor:
+ """
+ Returns the indices of the minimum value(s) across the selected dimension.
+ """
+ pass
+ def broadcast_add(self, rhs: Tensor) -> Tensor:
+ """
+ Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
+ pass
+ def broadcast_as(self, shape: Sequence[int]) -> Tensor:
+ """
+ Broadcasts the tensor to the given shape.
+ """
+ pass
+ def broadcast_div(self, rhs: Tensor) -> Tensor:
+ """
+ Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
+ pass
+ def broadcast_left(self, shape: Sequence[int]) -> Tensor:
+ """
+ Broadcasts the tensor to the given shape, adding new dimensions on the left.
+ """
+ pass
+ def broadcast_mul(self, rhs: Tensor) -> Tensor:
+ """
+ Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
+ pass
+ def broadcast_sub(self, rhs: Tensor) -> Tensor:
+ """
+ Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
+ pass
+ def contiguous(self) -> Tensor:
+ """
+ Makes the tensor contiguous in memory.
+ """
+ pass
+ def copy(self) -> Tensor:
+ """
+ Returns a copy of the tensor.
+ """
+ pass
+ def cos(self) -> Tensor:
+ """
+ Performs the `cos` operation on the tensor.
+ """
+ pass
+ def detach(self) -> Tensor:
+ """
+ Detach the tensor from the computation graph.
+ """
+ pass
+ @property
+ def device(self) -> Device:
+ """
+ Gets the tensor's device.
+ """
+ pass
+ @property
+ def dtype(self) -> DType:
+ """
+ Gets the tensor's dtype.
+ """
+ pass
+ def exp(self) -> Tensor:
+ """
+ Performs the `exp` operation on the tensor.
+ """
+ pass
+ def flatten_all(self) -> Tensor:
+ """
+ Flattens the tensor into a 1D tensor.
+ """
+ pass
+ def flatten_from(self, dim: int) -> Tensor:
+ """
+ Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+ """
+ pass
+ def flatten_to(self, dim: int) -> Tensor:
+ """
+ Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+ """
+ pass
+ def get(self, index: int) -> Tensor:
+ """
+ Gets the value at the specified index.
+ """
+ pass
+ def index_select(self, rhs: Tensor, dim: int) -> Tensor:
+ """
+ Select values for the input tensor at the target indexes across the specified dimension.
+
+ The `indexes` argument is an int tensor with a single dimension.
+ The output has the same number of dimensions as the `self` input. The target dimension of
+ the output has the same length as `indexes`, and the values are taken from `self` using
+ the indices from `indexes`. Other dimensions have the same number of elements as the input
+ tensor.
+ """
+ pass
+ def is_contiguous(self) -> bool:
+ """
+ Returns true if the tensor is contiguous in C order.
+ """
+ pass
+ def is_fortran_contiguous(self) -> bool:
+ """
+ Returns true if the tensor is contiguous in Fortran order.
+ """
+ pass
+ def log(self) -> Tensor:
+ """
+ Performs the `log` operation on the tensor.
+ """
+ pass
+ def matmul(self, rhs: Tensor) -> Tensor:
+ """
+ Performs a matrix multiplication between the two tensors.
+ """
+ pass
+ def max_keepdim(self, dim: int) -> Tensor:
+ """
+ Gathers the maximum value across the selected dimension.
+ """
+ pass
+ def mean_all(self) -> Tensor:
+ """
+ Returns the mean of the tensor.
+ """
+ pass
+ def min_keepdim(self, dim: int) -> Tensor:
+ """
+ Gathers the minimum value across the selected dimension.
+ """
+ pass
+ def narrow(self, dim: int, start: int, len: int) -> Tensor:
+ """
+ Returns a new tensor that is a narrowed version of the input; the dimension `dim`
+ ranges from `start` to `start + len`.
+ """
+ pass
+ def powf(self, p: float) -> Tensor:
+ """
+ Performs the `pow` operation on the tensor with the given exponent.
+ """
+ pass
+ def quantize(self, quantized_dtype: str) -> QTensor:
+ """
+ Quantize the tensor.
+ """
+ pass
+ @property
+ def rank(self) -> int:
+ """
+ Gets the tensor's rank.
+ """
+ pass
+ def recip(self) -> Tensor:
+ """
+ Get the `recip` of the tensor.
+ """
+ pass
+ def reshape(self, shape: Sequence[int]) -> Tensor:
+ """
+ Reshapes the tensor to the given shape.
+ """
+ pass
+ @property
+ def shape(self) -> Tuple[int]:
+ """
+ Gets the tensor's shape.
+ """
+ pass
+ def sin(self) -> Tensor:
+ """
+ Performs the `sin` operation on the tensor.
+ """
+ pass
+ def sqr(self) -> Tensor:
+ """
+ Squares the tensor.
+ """
+ pass
+ def sqrt(self) -> Tensor:
+ """
+ Calculates the square root of the tensor.
+ """
+ pass
+ def squeeze(self, dim: int) -> Tensor:
+ """
+ Creates a new tensor with the specified dimension removed if its size was one.
+ """
+ pass
+ @property
+ def stride(self) -> Tuple[int]:
+ """
+ Gets the tensor's strides.
+ """
+ pass
+ def sum_all(self) -> Tensor:
+ """
+ Returns the sum of the tensor.
+ """
+ pass
+ def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
+ """
+ Returns the sum of the tensor along the given dimension(s), keeping them as dimensions of size one.
+ """
+ pass
+ def t(self) -> Tensor:
+ """
+ Transposes the tensor.
+ """
+ pass
+ def to_device(self, device: Union[str, Device]) -> Tensor:
+ """
+ Move the tensor to a new device.
+ """
+ pass
+ def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
+ """
+ Convert the tensor to a new dtype.
+ """
+ pass
+ def transpose(self, dim1: int, dim2: int) -> Tensor:
+ """
+ Returns a tensor that is a transposed version of the input; the given dimensions are swapped.
+ """
+ pass
+ def unsqueeze(self, dim: int) -> Tensor:
+ """
+ Creates a new tensor with a dimension of size one inserted at the specified position.
+ """
+ pass
+ def values(self) -> _ArrayLike:
+ """
+ Gets the tensor's data as a Python scalar or array-like object.
+ """
+ pass
+ def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
+ """
+ Returns a tensor with the same shape as the input tensor; the values are taken from
+ `on_true` where the input tensor value is not zero, and from `on_false` where it is
+ equal to zero.
+ """
+ pass
diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py
new file mode 100644
index 00000000..b8c5cfb7
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/__init__.py
@@ -0,0 +1,5 @@
+# Generated content DO NOT EDIT
+from .. import nn
+
+silu = nn.silu
+softmax = nn.softmax
diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi
new file mode 100644
index 00000000..01b30fce
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/__init__.pyi
@@ -0,0 +1,19 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+from candle import Tensor, DType, QTensor
+
+@staticmethod
+def silu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def softmax(tensor: Tensor, dim: int) -> Tensor:
+ """
+ Applies the Softmax function to a given tensor.
+ """
+ pass
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
new file mode 100644
index 00000000..ea85d2a3
--- /dev/null
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -0,0 +1,16 @@
+from typing import TypeVar, Union, Sequence
+
+_T = TypeVar("_T")
+
+_ArrayLike = Union[
+ _T,
+ Sequence[_T],
+ Sequence[Sequence[_T]],
+ Sequence[Sequence[Sequence[_T]]],
+ Sequence[Sequence[Sequence[Sequence[_T]]]],
+]
+
+CPU: str = "cpu"
+CUDA: str = "cuda"
+
+Device = TypeVar("Device", CPU, CUDA) \ No newline at end of file
diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py
new file mode 100644
index 00000000..62d85dc9
--- /dev/null
+++ b/candle-pyo3/py_src/candle/utils/__init__.py
@@ -0,0 +1,12 @@
+# Generated content DO NOT EDIT
+from .. import utils
+
+cuda_is_available = utils.cuda_is_available
+get_num_threads = utils.get_num_threads
+has_accelerate = utils.has_accelerate
+has_mkl = utils.has_mkl
+load_ggml = utils.load_ggml
+load_gguf = utils.load_gguf
+load_safetensors = utils.load_safetensors
+save_gguf = utils.save_gguf
+save_safetensors = utils.save_safetensors
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
new file mode 100644
index 00000000..61964ffc
--- /dev/null
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -0,0 +1,70 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+from candle import Tensor, DType, QTensor
+
+@staticmethod
+def cuda_is_available() -> bool:
+ """
+ Returns true if the 'cuda' backend is available.
+ """
+ pass
+
+@staticmethod
+def get_num_threads() -> int:
+ """
+ Returns the number of threads used by candle.
+ """
+ pass
+
+@staticmethod
+def has_accelerate() -> bool:
+ """
+ Returns true if candle was compiled with 'accelerate' support.
+ """
+ pass
+
+@staticmethod
+def has_mkl() -> bool:
+ """
+ Returns true if candle was compiled with MKL support.
+ """
+ pass
+
+@staticmethod
+def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+ """
+ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
+ a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
+ """
+ pass
+
+@staticmethod
+def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+ """
+ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
+ and the second maps metadata keys to metadata values.
+ """
+ pass
+
+@staticmethod
+def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
+ """
+ Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
+ """
+ pass
+
+@staticmethod
+def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
+ """
+ Save quantized tensors and metadata to a GGUF file.
+ """
+ pass
+
+@staticmethod
+def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None:
+ """
+ Saves a dictionary of tensors to a safetensors file.
+ """
+ pass
diff --git a/candle-pyo3/pyproject.toml b/candle-pyo3/pyproject.toml
new file mode 100644
index 00000000..88793493
--- /dev/null
+++ b/candle-pyo3/pyproject.toml
@@ -0,0 +1,30 @@
+[project]
+name = 'candle-nn'
+requires-python = '>=3.7'
+authors = [
+ {name = 'The Candle Team'},
+]
+
+dynamic = [
+ 'description',
+ 'license',
+ 'readme',
+]
+
+[project.urls]
+Homepage = 'https://github.com/huggingface/candle'
+Source = 'https://github.com/huggingface/candle'
+
+[build-system]
+requires = ["maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[tool.maturin]
+python-source = "py_src"
+module-name = "candle.candle"
+bindings = 'pyo3'
+features = ["pyo3/extension-module"]
+
+[tool.black]
+line-length = 119
+target-version = ['py35']
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 7d74c25e..46d9ff62 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,26 +1,28 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
+from typing import Dict, Tuple, Any
import candle
+from candle import Tensor, QTensor, utils, nn
MAX_SEQ_LEN = 4096
-def masked_fill(on_false, mask, on_true):
+def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
shape = mask.shape
on_true = candle.tensor(on_true).broadcast_as(shape)
return mask.where_cond(on_true, on_false)
class RmsNorm:
- def __init__(self, qtensor):
+ def __init__(self, qtensor:QTensor):
self.weight = qtensor.dequantize()
- def __call__(self, x):
+ def __call__(self, x:Tensor):
b_size, seq_len, hidden_size = x.shape
norm_x = x.sqr().sum_keepdim(2) / hidden_size
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
return x_normed.broadcast_mul(self.weight)
class QuantizedLayer:
- def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
+ def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
p = f"layers.{layer_idx}"
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
@@ -40,7 +42,7 @@ class QuantizedLayer:
self.cos = cos_sin[0]
self.sin = cos_sin[1]
- def __call__(self, x, mask, index_pos):
+ def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
residual = x
x = self.attn_norm(x)
attn = self.forward_attn(x, mask, index_pos)
@@ -50,11 +52,11 @@ class QuantizedLayer:
x = self.ffn_norm(x)
w1 = self.ffw1.matmul_t(x)
w3 = self.ffw3.matmul_t(x)
- mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
+ mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
return mlp + residual
- def forward_attn(self, x, mask, index_pos):
+ def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
b_size, seq_len, n_embd = x.shape
q = self.attention_wq.matmul_t(x)
k = self.attention_wk.matmul_t(x)
@@ -79,12 +81,12 @@ class QuantizedLayer:
att = q.matmul(k.t()) / self.head_dim**0.5
mask = mask.broadcast_as(att.shape)
att = masked_fill(att, mask, float("-inf"))
- att = candle.nn.softmax(att, -1)
+ att = nn.softmax(att, -1)
y = att.matmul(v.contiguous())
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
return self.attention_wo.matmul_t(y)
- def apply_rotary_emb(self, x, index_pos):
+ def apply_rotary_emb(self, x:Tensor, index_pos:int):
(b_size, n_head, seq_len, n_embd) = x.shape
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
@@ -106,17 +108,18 @@ def precompute_freqs_cis(hparams, freq_base):
return (m.cos(), m.sin())
class QuantizedLlama:
- def __init__(self, hparams, all_tensors):
+ def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
self.layers = []
- cos_sin = precompute_freqs_cis(hparams, 10000.)
+ rope_freq = hparams.get("rope_freq", 10000.)
+ cos_sin = precompute_freqs_cis(hparams, rope_freq)
for layer_idx in range(hparams["n_layer"]):
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
- def __call__(self, token, index_pos):
+ def __call__(self, token:Tensor, index_pos:int):
b_size, seq_len = token.shape
vocab_size, hidden_size = self.tok_embeddings.shape
token = token.reshape((b_size * seq_len,))
@@ -133,17 +136,47 @@ class QuantizedLlama:
x = self.output.matmul_t(x)
return x
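+# Map GGUF tensor names onto the GGML-style names used by the model code above.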
+def gguf_rename(tensor_name:str):
+ if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
+ if tensor_name == 'output_norm.weight': return 'norm.weight'
+ tensor_name = tensor_name.replace('blk.', 'layers.')
+ tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
+ tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
+ tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
+ tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
+ tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
+ tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
+ tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
+ tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
+ return tensor_name
+
def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
- all_tensors = candle.load_gguf(sys.argv[1])
- hparams = None
- vocab = None
+ all_tensors, metadata = utils.load_gguf(sys.argv[1])
+ vocab = metadata["tokenizer.ggml.tokens"]
+ for i, v in enumerate(vocab):
+ vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
+ hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
+ print(hparams)
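+ # Rebuild GGML-style hparams from the GGUF metadata so the model code stays unchanged.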
+ hparams = {
+ 'n_vocab': len(vocab),
+ 'n_embd': metadata['llama.embedding_length'],
+ 'n_mult': 256,
+ 'n_head': metadata['llama.attention.head_count'],
+ 'n_head_kv': metadata['llama.attention.head_count_kv'],
+ 'n_layer': metadata['llama.block_count'],
+ 'n_rot': metadata['llama.rope.dimension_count'],
+ 'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
+ 'ftype': metadata['general.file_type'],
+ }
+ all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
+
else:
- all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 79f86479..55b7a888 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,8 +1,7 @@
#![allow(clippy::redundant_closure_call)]
-// TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
-use pyo3::types::{IntoPyDict, PyTuple};
+use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::sync::Arc;
@@ -32,6 +31,7 @@ impl From<PyShape> for ::candle::Shape {
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
+/// A `candle` tensor.
struct PyTensor(Tensor);
impl std::ops::Deref for PyTensor {
@@ -44,6 +44,7 @@ impl std::ops::Deref for PyTensor {
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
+/// A `candle` dtype.
struct PyDType(DType);
#[pymethods]
@@ -198,38 +199,40 @@ trait MapDType {
#[pymethods]
impl PyTensor {
#[new]
+ #[pyo3(text_signature = "(self, data:_ArrayLike)")]
// TODO: Handle arbitrary input dtype and shape.
- fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+ /// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+ fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
use Device::Cpu;
- let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
+ let tensor = if let Ok(vs) = data.extract::<u32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<i64>(py) {
+ } else if let Ok(vs) = data.extract::<i64>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<f32>(py) {
+ } else if let Ok(vs) = data.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else {
- let ty = vs.as_ref(py).get_type();
+ let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
@@ -237,7 +240,8 @@ impl PyTensor {
Ok(Self(tensor))
}
- /// Gets the tensor data as a Python value/array/array of array/...
+ /// Gets the tensor's data as a Python scalar or array-like object.
+ /// &RETURNS&: _ArrayLike
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
struct M<'a>(Python<'a>);
impl<'a> MapDType for M<'a> {
@@ -281,26 +285,36 @@ impl PyTensor {
}
#[getter]
+ /// Gets the tensor's shape.
+ /// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.dims()).to_object(py)
}
#[getter]
+ /// Gets the tensor's strides.
+ /// &RETURNS&: Tuple[int]
fn stride(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.stride()).to_object(py)
}
#[getter]
+ /// Gets the tensor's dtype.
+ /// &RETURNS&: DType
fn dtype(&self) -> PyDType {
PyDType(self.0.dtype())
}
#[getter]
+ /// Gets the tensor's device.
+ /// &RETURNS&: Device
fn device(&self, py: Python<'_>) -> PyObject {
PyDevice::from_device(self.0.device()).to_object(py)
}
#[getter]
+ /// Gets the tensor's rank.
+ /// &RETURNS&: int
fn rank(&self) -> usize {
self.0.rank()
}
@@ -313,69 +327,117 @@ impl PyTensor {
self.__repr__()
}
+ /// Performs the `sin` operation on the tensor.
+ /// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
}
+ /// Performs the `cos` operation on the tensor.
+ /// &RETURNS&: Tensor
fn cos(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
}
+ /// Performs the `log` operation on the tensor.
+ /// &RETURNS&: Tensor
fn log(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.log().map_err(wrap_err)?))
}
+ /// Squares the tensor.
+ /// &RETURNS&: Tensor
fn sqr(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
}
+ /// Calculates the square root of the tensor.
+ /// &RETURNS&: Tensor
fn sqrt(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
}
+ /// Get the `recip` of the tensor.
+ /// &RETURNS&: Tensor
fn recip(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
}
+ /// Performs the `exp` operation on the tensor.
+ /// &RETURNS&: Tensor
fn exp(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, p:float)")]
+ /// Performs the `pow` operation on the tensor with the given exponent.
+ /// &RETURNS&: Tensor
fn powf(&self, p: f64) -> PyResult<Self> {
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")]
+ /// Select values for the input tensor at the target indexes across the specified dimension.
+ ///
+ /// The `indexes` argument is an int tensor with a single dimension.
+ /// The output has the same number of dimensions as the `self` input. The target dimension of
+ /// the output has the same length as `indexes`, and the values are taken from `self` using
+ /// the indices from `indexes`. Other dimensions have the same number of elements as the input
+ /// tensor.
+ /// &RETURNS&: Tensor
fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Performs a matrix multiplication between the two tensors.
+ /// &RETURNS&: Tensor
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")]
+ /// Returns a tensor with the same shape as the input tensor; the values are taken from
+ /// `on_true` where the input tensor value is not zero, and from `on_false` where it is
+ /// equal to zero.
+ /// &RETURNS&: Tensor
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
Ok(PyTensor(
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
))
}
+ /// Add two tensors.
+ /// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 + &rhs.0).map_err(wrap_err)?
@@ -391,6 +453,8 @@ impl PyTensor {
self.__add__(rhs)
}
+ /// Multiply two tensors.
+ /// &RETURNS&: Tensor
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 * &rhs.0).map_err(wrap_err)?
@@ -406,6 +470,8 @@ impl PyTensor {
self.__mul__(rhs)
}
+ /// Subtract two tensors.
+ /// &RETURNS&: Tensor
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 - &rhs.0).map_err(wrap_err)?
@@ -417,6 +483,8 @@ impl PyTensor {
Ok(Self(tensor))
}
+ /// Divide two tensors.
+ /// &RETURNS&: Tensor
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 / &rhs.0).map_err(wrap_err)?
@@ -428,62 +496,102 @@ impl PyTensor {
Ok(Self(tensor))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Reshapes the tensor to the given shape.
+ /// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Broadcasts the tensor to the given shape.
+ /// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
+ /// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Creates a new tensor with the specified dimension removed if its size was one.
+ /// &RETURNS&: Tensor
fn squeeze(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Creates a new tensor with a dimension of size one inserted at the specified position.
+ /// &RETURNS&: Tensor
fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, index:int)")]
+ /// Gets the value at the specified index.
+ /// &RETURNS&: Tensor
fn get(&self, index: i64) -> PyResult<Self> {
let index = actual_index(self, 0, index).map_err(wrap_err)?;
Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim1:int, dim2:int)")]
+ /// Returns a tensor that is a transposed version of the input; the given dimensions are swapped.
+ /// &RETURNS&: Tensor
fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")]
+ /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
+ /// ranges from `start` to `start + len`.
+ /// &RETURNS&: Tensor
fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
let start = actual_index(self, dim, start).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Returns the indices of the maximum value(s) across the selected dimension.
+ /// &RETURNS&: Tensor
fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Returns the indices of the minimum value(s) across the selected dimension.
+ /// &RETURNS&: Tensor
fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Gathers the maximum value across the selected dimension.
+ /// &RETURNS&: Tensor
fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Gathers the minimum value across the selected dimension.
+ /// &RETURNS&: Tensor
fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")]
+ /// Returns the sum of the tensor along the given dimension(s), keeping them as dimensions of size one.
+ /// &RETURNS&: Tensor
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
vec![dim]
@@ -495,10 +603,14 @@ impl PyTensor {
))
}
+ /// Returns the sum of the tensor.
+ /// &RETURNS&: Tensor
fn sum_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
}
+ /// Returns the mean of the tensor.
+ /// &RETURNS&: Tensor
fn mean_all(&self) -> PyResult<Self> {
let elements = self.0.elem_count();
let sum = self.0.sum_all().map_err(wrap_err)?;
@@ -506,54 +618,83 @@ impl PyTensor {
Ok(PyTensor(mean))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+ /// &RETURNS&: Tensor
fn flatten_from(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+ /// &RETURNS&: Tensor
fn flatten_to(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
}
+ /// Flattens the tensor into a 1D tensor.
+ /// &RETURNS&: Tensor
fn flatten_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
}
+ /// Transposes the tensor.
+ /// &RETURNS&: Tensor
fn t(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.t().map_err(wrap_err)?))
}
+ /// Makes the tensor contiguous in memory.
+ /// &RETURNS&: Tensor
fn contiguous(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
}
+ /// Returns true if the tensor is contiguous in C order.
+ /// &RETURNS&: bool
fn is_contiguous(&self) -> bool {
self.0.is_contiguous()
}
+ /// Returns true if the tensor is contiguous in Fortran order.
+ /// &RETURNS&: bool
fn is_fortran_contiguous(&self) -> bool {
self.0.is_fortran_contiguous()
}
+ /// Detach the tensor from the computation graph.
+ /// &RETURNS&: Tensor
fn detach(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
}
+ /// Returns a copy of the tensor.
+ /// &RETURNS&: Tensor
fn copy(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
+ /// Convert the tensor to a new dtype.
+ /// &RETURNS&: Tensor
fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, device:Union[str,Device])")]
+ /// Move the tensor to a new device.
+ /// &RETURNS&: Tensor
fn to_device(&self, device: PyDevice) -> PyResult<Self> {
let device = device.as_device()?;
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, quantized_dtype:str)")]
+ /// Quantize the tensor.
+ /// &RETURNS&: QTensor
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype {
@@ -581,8 +722,10 @@ impl PyTensor {
}
}
-/// Concatenate the tensors across one axis.
#[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
+/// Concatenate the tensors across one axis.
+/// &RETURNS&: Tensor
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
if tensors.is_empty() {
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@@ -594,6 +737,9 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
}
#[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
+/// Stack the tensors along a new axis.
+/// &RETURNS&: Tensor
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
@@ -601,12 +747,17 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
}
#[pyfunction]
-fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
- PyTensor::new(py, vs)
+#[pyo3(text_signature = "(data:_ArrayLike)")]
+/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+/// &RETURNS&: Tensor
+fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
+ PyTensor::new(py, data)
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+/// Creates a new tensor with random values.
+/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -614,7 +765,9 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+/// Creates a new tensor with random values from a normal distribution.
+/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -622,7 +775,9 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+/// Creates a new tensor filled with ones.
+/// &RETURNS&: Tensor
fn ones(
py: Python<'_>,
shape: PyShape,
@@ -639,7 +794,9 @@ fn ones(
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+/// Creates a new tensor filled with zeros.
+/// &RETURNS&: Tensor
fn zeros(
py: Python<'_>,
shape: PyShape,
@@ -655,8 +812,9 @@ fn zeros(
Ok(PyTensor(tensor))
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[pyclass(name = "QTensor")]
+/// A quantized tensor.
struct PyQTensor(Arc<QTensor>);
impl std::ops::Deref for PyQTensor {
@@ -670,16 +828,22 @@ impl std::ops::Deref for PyQTensor {
#[pymethods]
impl PyQTensor {
#[getter]
+ /// Gets the tensor's quantized dtype.
+ /// &RETURNS&: str
fn ggml_dtype(&self) -> String {
format!("{:?}", self.0.dtype())
}
#[getter]
+ /// Gets the rank of the tensor.
+ /// &RETURNS&: int
fn rank(&self) -> usize {
self.0.rank()
}
#[getter]
+ /// Gets the shape of the tensor.
+ /// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.shape().dims()).to_object(py)
}
@@ -692,11 +856,16 @@ impl PyQTensor {
self.__repr__()
}
+ /// Dequantizes the tensor.
+ /// &RETURNS&: Tensor
fn dequantize(&self) -> PyResult<PyTensor> {
let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
+ #[pyo3(text_signature = "(self, lhs:Tensor)")]
+ /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
+ /// &RETURNS&: Tensor
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
@@ -705,6 +874,9 @@ impl PyQTensor {
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
+/// &RETURNS&: Dict[str,Tensor]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
let res = res
@@ -715,6 +887,25 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
+/// Saves a dictionary of tensors to a safetensors file.
+/// &RETURNS&: None
+fn save_safetensors(
+ path: &str,
+ tensors: std::collections::HashMap<String, PyTensor>,
+) -> PyResult<()> {
+ let tensors = tensors
+ .into_iter()
+ .map(|(s, t)| (s, t.0))
+ .collect::<std::collections::HashMap<_, _>>();
+ ::candle::safetensors::save(&tensors, path).map_err(wrap_err)
+}
+
+#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
+/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
+/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@@ -746,10 +937,39 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
+/// and the second maps metadata keys to metadata values.
+/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
+fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+ use ::candle::quantized::gguf_file;
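+ // Recursively convert a GGUF metadata value (including nested arrays) into a Python object.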
+ fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
+ let v: PyObject = match v {
+ gguf_file::Value::U8(x) => x.into_py(py),
+ gguf_file::Value::I8(x) => x.into_py(py),
+ gguf_file::Value::U16(x) => x.into_py(py),
+ gguf_file::Value::I16(x) => x.into_py(py),
+ gguf_file::Value::U32(x) => x.into_py(py),
+ gguf_file::Value::I32(x) => x.into_py(py),
+ gguf_file::Value::U64(x) => x.into_py(py),
+ gguf_file::Value::I64(x) => x.into_py(py),
+ gguf_file::Value::F32(x) => x.into_py(py),
+ gguf_file::Value::F64(x) => x.into_py(py),
+ gguf_file::Value::Bool(x) => x.into_py(py),
+ gguf_file::Value::String(x) => x.into_py(py),
+ gguf_file::Value::Array(x) => {
+ let list = pyo3::types::PyList::empty(py);
+ for elem in x.iter() {
+ list.append(gguf_value_to_pyobject(elem, py)?)?;
+ }
+ list.into()
+ }
+ };
+ Ok(v)
+ }
let mut file = std::fs::File::open(path)?;
- let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?;
- let res = gguf
+ let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?;
+ let tensors = gguf
.tensor_infos
.keys()
.map(|key| {
@@ -758,25 +978,129 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
- Ok(res.into_py_dict(py).to_object(py))
+ let tensors = tensors.into_py_dict(py).to_object(py);
+ let metadata = gguf
+ .metadata
+ .iter()
+ .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
+ .collect::<PyResult<Vec<_>>>()?
+ .into_py_dict(py)
+ .to_object(py);
+ Ok((tensors, metadata))
}
#[pyfunction]
+#[pyo3(
+ text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
+)]
+/// Save quantized tensors and metadata to a GGUF file.
+fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
+ use ::candle::quantized::gguf_file;
+
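+ // Convert a Python value into a GGUF metadata value, trying integer widths from narrowest to widest.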
+ fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
+ let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
+ gguf_file::Value::U8(x)
+ } else if let Ok(x) = v.extract::<i8>() {
+ gguf_file::Value::I8(x)
+ } else if let Ok(x) = v.extract::<u16>() {
+ gguf_file::Value::U16(x)
+ } else if let Ok(x) = v.extract::<i16>() {
+ gguf_file::Value::I16(x)
+ } else if let Ok(x) = v.extract::<u32>() {
+ gguf_file::Value::U32(x)
+ } else if let Ok(x) = v.extract::<i32>() {
+ gguf_file::Value::I32(x)
+ } else if let Ok(x) = v.extract::<u64>() {
+ gguf_file::Value::U64(x)
+ } else if let Ok(x) = v.extract::<i64>() {
+ gguf_file::Value::I64(x)
+ } else if let Ok(x) = v.extract::<f32>() {
+ gguf_file::Value::F32(x)
+ } else if let Ok(x) = v.extract::<f64>() {
+ gguf_file::Value::F64(x)
+ } else if let Ok(x) = v.extract::<bool>() {
+ gguf_file::Value::Bool(x)
+ } else if let Ok(x) = v.extract::<String>() {
+ gguf_file::Value::String(x)
+ } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
+ let x = x
+ .into_iter()
+ .map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
+ .collect::<PyResult<Vec<_>>>()?;
+ gguf_file::Value::Array(x)
+ } else {
+ return Err(PyErr::new::<PyValueError, _>(format!(
+ "unsupported type {:?}",
+ v
+ )));
+ };
+ Ok(v)
+ }
+ let tensors = tensors
+ .extract::<&PyDict>(py)
+ .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
+ .iter()
+ .map(|(key, value)| {
+ Ok((
+ key.extract::<String>()
+ .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
+ value.extract::<PyQTensor>()?.0,
+ ))
+ })
+ .collect::<PyResult<Vec<_>>>()?;
+
+ let metadata = metadata
+ .extract::<&PyDict>(py)
+ .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
+ .iter()
+ .map(|(key, value)| {
+ Ok((
+ key.extract::<String>()
+ .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
+ pyobject_to_gguf_value(value, py)?,
+ ))
+ })
+ .collect::<PyResult<Vec<_>>>()?;
+
+ let converted_metadata: Vec<_> = metadata
+ .iter()
+ .map(|(name, value)| (name.as_str(), value))
+ .collect();
+
+ let converted_tensors: Vec<_> = tensors
+ .iter()
+ .map(|(name, tensor)| (name.as_str(), tensor.as_ref()))
+ .collect();
+
+ let mut file = std::fs::File::create(path)?;
+
+ gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err)
+}
+
+#[pyfunction]
+/// Returns true if the 'cuda' backend is available.
+/// &RETURNS&: bool
fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available()
}
#[pyfunction]
+/// Returns true if candle was compiled with 'accelerate' support.
+/// &RETURNS&: bool
fn has_accelerate() -> bool {
::candle::utils::has_accelerate()
}
#[pyfunction]
+/// Returns true if candle was compiled with MKL support.
+/// &RETURNS&: bool
fn has_mkl() -> bool {
::candle::utils::has_mkl()
}
#[pyfunction]
+/// Returns the number of threads used by candle.
+/// &RETURNS&: int
fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
@@ -786,19 +1110,30 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
+ m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
+ m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
+ m.add_function(wrap_pyfunction!(save_gguf, m)?)?;
+ m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
+ m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
Ok(())
}
#[pyfunction]
-fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
- let dim = actual_dim(&t, dim).map_err(wrap_err)?;
- let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
+/// Applies the Softmax function to a given tensor.
+/// &RETURNS&: Tensor
+fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
+ let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
+ let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
Ok(PyTensor(sm))
}
#[pyfunction]
-fn silu(t: PyTensor) -> PyResult<PyTensor> {
- let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+/// &RETURNS&: Tensor
+fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
+ let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
Ok(PyTensor(s))
}
@@ -827,9 +1162,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add("f32", PyDType(DType::F32))?;
m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
- m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
- m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
- m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?;
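
Taken together, the new `save_gguf`/`load_gguf` pair is a thin wrapper over candle's `gguf_file` API. A minimal Rust round-trip sketch of the same calls, assuming a `QTensor` built elsewhere (the path, tensor name, and metadata key are placeholders; the `write` signature matches its use above):

    use candle::quantized::{gguf_file, QTensor};

    fn gguf_roundtrip(qtensor: &QTensor, path: &str) -> candle::Result<()> {
        // Write one tensor plus one metadata entry, as save_gguf does above.
        let note = gguf_file::Value::String("demo".to_string());
        let mut file = std::fs::File::create(path)?;
        gguf_file::write(&mut file, &[("note", &note)], &[("weight", qtensor)])?;

        // Read it back: tensor names live in `tensor_infos`, metadata in the
        // `metadata` map, mirroring what load_gguf exposes to Python.
        let mut file = std::fs::File::open(path)?;
        let content = gguf_file::Content::read(&mut file)?;
        for name in content.tensor_infos.keys() {
            let _tensor = content.tensor(&mut file, name)?;
        }
        for key in content.metadata.keys() {
            println!("metadata key: {key}");
        }
        Ok(())
    }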
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
new file mode 100644
index 00000000..149715c2
--- /dev/null
+++ b/candle-pyo3/stub.py
@@ -0,0 +1,232 @@
+# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
+import argparse
+import inspect
+import os
+from typing import Optional
+import black
+from pathlib import Path
+
+
+INDENT = " " * 4
+GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
+TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+"""
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
+CANDLE_TENSOR_IMPORTS = "from candle import Tensor, DType, QTensor\n"
+RETURN_TYPE_MARKER = "&RETURNS&: "
+
+
+def do_indent(text: Optional[str], indent: str):
+ if text is None:
+ return ""
+ return text.replace("\n", f"\n{indent}")
+
+
+def function(obj, indent:str, text_signature:str=None):
+ if text_signature is None:
+ text_signature = obj.__text_signature__
+
+ text_signature = text_signature.replace("$self", "self").strip()
+ doc_string = obj.__doc__
+ if doc_string is None:
+ doc_string = ""
+
+ # Check if we have a return type annotation in the docstring
+ return_type = None
+ doc_lines = doc_string.split("\n")
+ if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
+ # Extract the return type and remove it from the docstring
+ return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
+ doc_string = "\n".join(doc_lines[:-1])
+
+ string = ""
+ if return_type:
+ string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n"
+ else:
+ string += f"{indent}def {obj.__name__}{text_signature}:\n"
+ indent += INDENT
+ string += f'{indent}"""\n'
+ string += f"{indent}{do_indent(doc_string, indent)}\n"
+ string += f'{indent}"""\n'
+ string += f"{indent}pass\n"
+ string += "\n"
+ string += "\n"
+ return string
+
+
+def member_sort(member):
+ if inspect.isclass(member):
+ value = 10 + len(inspect.getmro(member))
+ else:
+ value = 1
+ return value
+
+
+def fn_predicate(obj):
+ value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
+ if value:
+ return obj.__text_signature__ and not obj.__name__.startswith("_")
+ if inspect.isgetsetdescriptor(obj):
+ return not obj.__name__.startswith("_")
+ return False
+
+
+def get_module_members(module):
+ members = [
+ member
+ for name, member in inspect.getmembers(module)
+ if not name.startswith("_") and not inspect.ismodule(member)
+ ]
+ members.sort(key=member_sort)
+ return members
+
+
+def pyi_file(obj, indent=""):
+ string = ""
+ if inspect.ismodule(obj):
+ string += GENERATED_COMMENT
+ string += TYPING
+ string += CANDLE_SPECIFIC_TYPING
+ if obj.__name__ != "candle.candle":
+ string += CANDLE_TENSOR_IMPORTS
+ members = get_module_members(obj)
+ for member in members:
+ string += pyi_file(member, indent)
+
+ elif inspect.isclass(obj):
+ indent += INDENT
+ mro = inspect.getmro(obj)
+ if len(mro) > 2:
+ inherit = f"({mro[1].__name__})"
+ else:
+ inherit = ""
+ string += f"class {obj.__name__}{inherit}:\n"
+
+ body = ""
+ if obj.__doc__:
+ body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
+
+ fns = inspect.getmembers(obj, fn_predicate)
+
+ # Init
+ if obj.__text_signature__:
+ body += f"{indent}def __init__{obj.__text_signature__}:\n"
+ body += f"{indent+INDENT}pass\n"
+ body += "\n"
+
+ for (name, fn) in fns:
+ body += pyi_file(fn, indent=indent)
+
+ if not body:
+ body += f"{indent}pass\n"
+
+ string += body
+ string += "\n\n"
+
+ elif inspect.isbuiltin(obj):
+ string += f"{indent}@staticmethod\n"
+ string += function(obj, indent)
+
+ elif inspect.ismethoddescriptor(obj):
+ string += function(obj, indent)
+
+ elif inspect.isgetsetdescriptor(obj):
+ # TODO: it would be interesting to also add the setter, maybe?
+ string += f"{indent}@property\n"
+ string += function(obj, indent, text_signature="(self)")
+
+ elif obj.__class__.__name__ == "DType":
+ string += f"class {str(obj).lower()}(DType):\n"
+ string += f"{indent+INDENT}pass\n"
+ else:
+ raise Exception(f"Object {obj} is not supported")
+ return string
+
+
+def py_file(module, origin):
+ members = get_module_members(module)
+
+ string = GENERATED_COMMENT
+ string += f"from .. import {origin}\n"
+ string += "\n"
+ for member in members:
+ if hasattr(member, "__name__"):
+ name = member.__name__
+ else:
+ name = str(member)
+ string += f"{name} = {origin}.{name}\n"
+ return string
+
+
+def do_black(content, is_pyi):
+ mode = black.Mode(
+ target_versions={black.TargetVersion.PY35},
+ line_length=119,
+ is_pyi=is_pyi,
+ string_normalization=True,
+ experimental_string_processing=False,
+ )
+ try:
+ return black.format_file_contents(content, fast=True, mode=mode)
+ except black.NothingChanged:
+ return content
+
+
+def write(module, directory, origin, check=False):
+ submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
+
+ filename = os.path.join(directory, "__init__.pyi")
+ pyi_content = pyi_file(module)
+ pyi_content = do_black(pyi_content, is_pyi=True)
+ os.makedirs(directory, exist_ok=True)
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(pyi_content)
+
+ filename = os.path.join(directory, "__init__.py")
+ py_content = py_file(module, origin)
+ py_content = do_black(py_content, is_pyi=False)
+ os.makedirs(directory, exist_ok=True)
+
+ is_auto = False
+ if not os.path.exists(filename):
+ is_auto = True
+ else:
+ with open(filename, "r") as f:
+ line = f.readline()
+ if line == GENERATED_COMMENT:
+ is_auto = True
+
+ if is_auto:
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(py_content)
+
+ for name, submodule in submodules:
+ write(submodule, os.path.join(directory, name), f"{name}", check=check)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check", action="store_true")
+
+ args = parser.parse_args()
+
+ # Enable execution from both the candle and candle-pyo3 directories.
+ cwd = Path.cwd()
+ directory = "py_src/candle/"
+ if cwd.name != "candle-pyo3":
+ directory = f"candle-pyo3/{directory}"
+
+ import candle
+
+ write(candle.candle, directory, "candle", check=args.check)
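
The `&RETURNS&:` marker that stub.py strips from the last docstring line is what lets the Rust side declare Python return types. The pattern on the pyo3 side looks like the following sketch (a hypothetical `is_even` binding, not part of this diff):

    use pyo3::prelude::*;

    #[pyfunction]
    #[pyo3(text_signature = "(x:int)")]
    /// Returns true if `x` is even.
    /// &RETURNS&: bool
    fn is_even(x: i64) -> bool {
        x % 2 == 0
    }

From this, stub.py emits `def is_even(x: int) -> bool:` in the generated `__init__.pyi`, with the remaining docstring as the function body's docstring.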
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index a05b9bb7..a3115c2b 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -11,14 +11,21 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
+candle-nn = { path = "../candle-nn", version = "0.2.3" }
intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
rand = { workspace = true }
+rayon = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+tracing = { workspace = true }
wav = { workspace = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1d20168..b1a567c3 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,35 +1,82 @@
-use candle::{DType, Error, Result, Tensor, D};
+use candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
+ top_p: Option<f64>,
}
impl LogitsProcessor {
- pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
+ let temperature = if temperature.map_or(true, |v| v < 1e-7) {
+ None
+ } else {
+ temperature
+ };
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
+ top_p,
+ }
+ }
+
+ fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ let next_token = logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap();
+ Ok(next_token)
+ }
+
+ fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
+ let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+ let next_token = distr.sample(&mut self.rng) as u32;
+ Ok(next_token)
+ }
+
+ fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+ // Top-p sampling (or "nucleus sampling") samples from the smallest set of
+ // tokens whose cumulative probability exceeds top_p. This way we never
+ // sample tokens with very low probabilities, so generation is less likely
+ // to go "off the rails".
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+ // Sort by descending probability.
+ argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+ // Clamp smaller probabilities to zero.
+ let mut cumsum = 0.;
+ for index in &argsort_indices {
+ if cumsum >= top_p {
+ prs[*index] = 0.0;
+ } else {
+ cumsum += prs[*index];
+ }
}
+ // Sample with clamped probabilities.
+ self.sample_multinomial(prs)
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
- let temperature = self.temperature.unwrap_or(0.);
- let next_token = if temperature > 0. {
- let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
- let prs: Vec<f32> = prs.to_vec1()?;
- let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
- distr.sample(&mut self.rng) as u32
- } else {
- let logits_v: Vec<f32> = logits.to_vec1()?;
- logits_v
- .iter()
- .enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
+ let next_token = match self.temperature {
+ None => self.sample_argmax(logits)?,
+ Some(temperature) => {
+ let logits = &(&logits / temperature)?;
+ let prs = candle_nn::ops::softmax_last_dim(logits)?;
+ let mut prs: Vec<f32> = prs.to_vec1()?;
+ let top_p = self.top_p.unwrap_or(1.);
+ if top_p <= 0.0 || top_p >= 1.0 {
+ // simply sample from the predicted probability distribution
+ self.sample_multinomial(&prs)?
+ } else {
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
+ self.sample_topp(&mut prs, top_p as f32)?
+ }
+ }
};
Ok(next_token)
}
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index a8890dc8..b83e5056 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -1,4 +1,5 @@
pub mod generation;
pub mod models;
+pub mod object_detection;
pub mod pipelines;
pub mod utils;
diff --git a/candle-examples/examples/bert/model.rs b/candle-transformers/src/models/bert.rs
index 3f164a3a..3f164a3a 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-transformers/src/models/bert.rs
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-transformers/src/models/bigcode.rs
index 1e63956b..1e63956b 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-transformers/src/models/bigcode.rs
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
new file mode 100644
index 00000000..0edc8494
--- /dev/null
+++ b/candle-transformers/src/models/dinov2.rs
@@ -0,0 +1,279 @@
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+const IMG_SIZE: usize = 518;
+const PATCH_SIZE: usize = 14;
+const NUM_CLASSES: usize = 1000;
+
+fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ if bias {
+ candle_nn::linear(in_dim, out_dim, vb)
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ qkv: Linear,
+ proj: Linear,
+ num_heads: usize,
+ scale: f64,
+}
+
+impl Attention {
+ fn new(
+ vb: VarBuilder,
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ proj_bias: bool,
+ ) -> Result<Self> {
+ let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
+ let scale = 1. / ((dim / num_heads) as f64).sqrt();
+ Ok(Self {
+ qkv,
+ proj,
+ num_heads,
+ scale,
+ })
+ }
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b, n, c) = xs.dims3()?;
+ let qkv = self
+ .qkv
+ .forward(xs)?
+ .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
+ .transpose(1, 2)? // 02134
+ .transpose(0, 1)? // 20134
+ .transpose(2, 3)?; // 20314
+ let q = (qkv.i(0)? * self.scale)?;
+ let k = qkv.i(1)?;
+ let v = qkv.i(2)?;
+ let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
+ let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
+ self.proj.forward(&attn)
+ }
+}
+
+#[derive(Debug)]
+struct LayerScale {
+ gamma: Tensor,
+}
+
+impl LayerScale {
+ fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
+ let gamma = vb.get(dim, "gamma")?;
+ Ok(Self { gamma })
+ }
+}
+
+impl Module for LayerScale {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.broadcast_mul(&self.gamma)
+ }
+}
+
+#[derive(Debug)]
+struct Mlp {
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl Mlp {
+ fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
+ let out_features = in_features;
+ let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
+ let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?.gelu()?;
+ self.fc2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct Block {
+ norm1: LayerNorm,
+ attn: Attention,
+ ls1: LayerScale,
+ norm2: LayerNorm,
+ mlp: Mlp,
+ ls2: LayerScale,
+}
+
+impl Block {
+ fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
+ let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+ let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
+ let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
+ let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+ let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
+ let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
+ Ok(Self {
+ norm1,
+ attn,
+ ls1,
+ norm2,
+ mlp,
+ ls2,
+ })
+ }
+}
+
+impl Module for Block {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self
+ .ls1
+ .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = self
+ .ls2
+ .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ proj: candle_nn::Conv2d,
+ patch_size: (usize, usize),
+ num_patches: usize,
+}
+
+impl PatchEmbed {
+ fn new(
+ vb: VarBuilder,
+ img_size: usize,
+ patch_size: usize,
+ in_chans: usize,
+ embed_dim: usize,
+ ) -> Result<Self> {
+ let config = candle_nn::Conv2dConfig {
+ stride: patch_size,
+ ..Default::default()
+ };
+ let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
+ let num_patches = (img_size / patch_size) * (img_size / patch_size);
+ Ok(Self {
+ proj,
+ patch_size: (patch_size, patch_size),
+ num_patches,
+ })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _c, h, w) = xs.dims4()?;
+ let (patch_h, patch_w) = self.patch_size;
+ if (h % patch_h) != 0 {
+ candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
+ }
+ if (w % patch_w) != 0 {
+ candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
+ }
+ let xs = self.proj.forward(xs)?;
+ let (b, c, h, w) = xs.dims4()?;
+ // flatten embeddings.
+ xs.reshape((b, c, h * w))?.transpose(1, 2)
+ }
+}
+
+#[derive(Debug)]
+pub struct DinoVisionTransformer {
+ patch_embed: PatchEmbed,
+ cls_token: Tensor,
+ pos_embed: Tensor,
+ blocks: Vec<Block>,
+ norm: LayerNorm,
+ head: Linear,
+}
+
+impl DinoVisionTransformer {
+ pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
+ let patch_embed =
+ PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
+ let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
+ let num_tokens = 1;
+ let pos_embed = vb.get(
+ (1, patch_embed.num_patches + num_tokens, embed_dim),
+ "pos_embed",
+ )?;
+ let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
+ let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
+ let vb_b = vb.pp("blocks");
+ let blocks = (0..depth)
+ .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ patch_embed,
+ cls_token,
+ pos_embed,
+ blocks,
+ norm,
+ head,
+ })
+ }
+
+ fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
+ let npatch = xs.dim(1)? - 1;
+ let n = self.pos_embed.dim(1)? - 1;
+ let sqrt_n = (n as f64).sqrt();
+ if npatch == n && w == h {
+ return Ok(xs.clone());
+ }
+ let class_pos_embed = self.pos_embed.i((.., ..1))?;
+ let patch_pos_embed = self.pos_embed.i((.., 1..))?;
+ let dim = xs.dim(D::Minus1)?;
+ let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
+ let patch_pos_embed = patch_pos_embed
+ .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
+ .transpose(2, 3)?
+ .transpose(1, 2)?;
+ // This uses bicubic interpolation in the original implementation.
+ let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
+ let el_count = patch_pos_embed.shape().elem_count();
+ let patch_pos_embed =
+ patch_pos_embed
+ .transpose(1, 2)?
+ .transpose(2, 3)?
+ .reshape((1, el_count / dim, dim))?;
+ Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
+ }
+
+ fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _nc, w, h) = xs.dims4()?;
+ let xs = self.patch_embed.forward(xs)?;
+ let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
+ &xs + &self.interpolate_pos_encoding(&xs, w, h)?
+ }
+}
+
+impl Module for DinoVisionTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.prepare_tokens_with_mask(xs)?;
+ for blk in self.blocks.iter() {
+ xs = blk.forward(&xs)?
+ }
+ let xs = self.norm.forward(&xs)?;
+ let xs_norm_clstoken = xs.i((.., 0))?;
+ let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
+ let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
+ self.head.forward(&xs)
+ }
+}
+
+pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
+ DinoVisionTransformer::new(vb, 12, 384, 6)
+}
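
A hedged usage sketch for the new module; the safetensors path is a placeholder and `from_mmaped_safetensors` is assumed to be the loading route, as in the examples:

    use candle::{DType, Device, Tensor};
    use candle_nn::{Module, VarBuilder};
    use candle_transformers::models::dinov2;

    fn classify(img: &Tensor) -> candle::Result<Tensor> {
        // `img` is expected as (batch, 3, 518, 518) to match IMG_SIZE above.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["dinov2.safetensors"], DType::F32, &Device::Cpu)?
        };
        let model = dinov2::vit_small(vb)?;
        model.forward(img) // (batch, NUM_CLASSES) logits
    }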
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
new file mode 100644
index 00000000..ab51c76d
--- /dev/null
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -0,0 +1,331 @@
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use nn::{Module, VarBuilder};
+
+// Based on the Python version from torchvision.
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
+#[derive(Debug, Clone, Copy)]
+pub struct MBConvConfig {
+ expand_ratio: f64,
+ kernel: usize,
+ stride: usize,
+ input_channels: usize,
+ out_channels: usize,
+ num_layers: usize,
+}
+
+fn make_divisible(v: f64, divisor: usize) -> usize {
+ let min_value = divisor;
+ let new_v = usize::max(
+ min_value,
+ (v + divisor as f64 * 0.5) as usize / divisor * divisor,
+ );
+ if (new_v as f64) < 0.9 * v {
+ new_v + divisor
+ } else {
+ new_v
+ }
+}
+
+fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
+ let bneck_conf = |e, k, s, i, o, n| {
+ let input_channels = make_divisible(i as f64 * width_mult, 8);
+ let out_channels = make_divisible(o as f64 * width_mult, 8);
+ let num_layers = (n as f64 * depth_mult).ceil() as usize;
+ MBConvConfig {
+ expand_ratio: e,
+ kernel: k,
+ stride: s,
+ input_channels,
+ out_channels,
+ num_layers,
+ }
+ };
+ vec![
+ bneck_conf(1., 3, 1, 32, 16, 1),
+ bneck_conf(6., 3, 2, 16, 24, 2),
+ bneck_conf(6., 5, 2, 24, 40, 2),
+ bneck_conf(6., 3, 2, 40, 80, 3),
+ bneck_conf(6., 5, 1, 80, 112, 3),
+ bneck_conf(6., 5, 2, 112, 192, 4),
+ bneck_conf(6., 3, 1, 192, 320, 1),
+ ]
+}
+
+impl MBConvConfig {
+ pub fn b0() -> Vec<Self> {
+ bneck_confs(1.0, 1.0)
+ }
+ pub fn b1() -> Vec<Self> {
+ bneck_confs(1.0, 1.1)
+ }
+ pub fn b2() -> Vec<Self> {
+ bneck_confs(1.1, 1.2)
+ }
+ pub fn b3() -> Vec<Self> {
+ bneck_confs(1.2, 1.4)
+ }
+ pub fn b4() -> Vec<Self> {
+ bneck_confs(1.4, 1.8)
+ }
+ pub fn b5() -> Vec<Self> {
+ bneck_confs(1.6, 2.2)
+ }
+ pub fn b6() -> Vec<Self> {
+ bneck_confs(1.8, 2.6)
+ }
+ pub fn b7() -> Vec<Self> {
+ bneck_confs(2.0, 3.1)
+ }
+}
+
+/// Conv2D with same padding.
+#[derive(Debug)]
+struct Conv2DSame {
+ conv2d: nn::Conv2d,
+ s: usize,
+ k: usize,
+}
+
+impl Conv2DSame {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ bias: bool,
+ ) -> Result<Self> {
+ let conv_config = nn::Conv2dConfig {
+ stride,
+ groups,
+ ..Default::default()
+ };
+ let conv2d = if bias {
+ nn::conv2d(i, o, k, conv_config, vb)?
+ } else {
+ nn::conv2d_no_bias(i, o, k, conv_config, vb)?
+ };
+ Ok(Self {
+ conv2d,
+ s: stride,
+ k,
+ })
+ }
+}
+
+impl Module for Conv2DSame {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let s = self.s;
+ let k = self.k;
+ let (_, _, ih, iw) = xs.dims4()?;
+ let oh = (ih + s - 1) / s;
+ let ow = (iw + s - 1) / s;
+ let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+ let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+ if pad_h > 0 || pad_w > 0 {
+ let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
+ let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
+ self.conv2d.forward(&xs)
+ } else {
+ self.conv2d.forward(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct ConvNormActivation {
+ conv2d: Conv2DSame,
+ bn2d: nn::BatchNorm,
+ activation: bool,
+}
+
+impl ConvNormActivation {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
+ let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
+ Ok(Self {
+ conv2d,
+ bn2d,
+ activation: true,
+ })
+ }
+
+ fn no_activation(self) -> Self {
+ Self {
+ activation: false,
+ ..self
+ }
+ }
+}
+
+impl Module for ConvNormActivation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.conv2d.forward(xs)?;
+ let xs = self.bn2d.forward(&xs)?;
+ if self.activation {
+ swish(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct SqueezeExcitation {
+ fc1: Conv2DSame,
+ fc2: Conv2DSame,
+}
+
+impl SqueezeExcitation {
+ fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
+ let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
+ let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for SqueezeExcitation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ // equivalent to adaptive_avg_pool2d([1, 1])
+ let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
+ let xs = self.fc1.forward(&xs)?;
+ let xs = swish(&xs)?;
+ let xs = self.fc2.forward(&xs)?;
+ let xs = nn::ops::sigmoid(&xs)?;
+ residual.broadcast_mul(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ expand_cna: Option<ConvNormActivation>,
+ depthwise_cna: ConvNormActivation,
+ squeeze_excitation: SqueezeExcitation,
+ project_cna: ConvNormActivation,
+ config: MBConvConfig,
+}
+
+impl MBConv {
+ fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
+ let vb = vb.pp("block");
+ let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
+ let expand_cna = if exp != c.input_channels {
+ Some(ConvNormActivation::new(
+ vb.pp("0"),
+ c.input_channels,
+ exp,
+ 1,
+ 1,
+ 1,
+ )?)
+ } else {
+ None
+ };
+ let start_index = if expand_cna.is_some() { 1 } else { 0 };
+ let depthwise_cna =
+ ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
+ let squeeze_channels = usize::max(1, c.input_channels / 4);
+ let squeeze_excitation =
+ SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
+ let project_cna =
+ ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
+ .no_activation();
+ Ok(Self {
+ expand_cna,
+ depthwise_cna,
+ squeeze_excitation,
+ project_cna,
+ config: c,
+ })
+ }
+}
+
+impl Module for MBConv {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let use_res_connect =
+ self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
+ let ys = match &self.expand_cna {
+ Some(expand_cna) => expand_cna.forward(xs)?,
+ None => xs.clone(),
+ };
+ let ys = self.depthwise_cna.forward(&ys)?;
+ let ys = self.squeeze_excitation.forward(&ys)?;
+ let ys = self.project_cna.forward(&ys)?;
+ if use_res_connect {
+ ys + xs
+ } else {
+ Ok(ys)
+ }
+ }
+}
+
+fn swish(s: &Tensor) -> Result<Tensor> {
+ s * nn::ops::sigmoid(s)?
+}
+
+#[derive(Debug)]
+pub struct EfficientNet {
+ init_cna: ConvNormActivation,
+ blocks: Vec<MBConv>,
+ final_cna: ConvNormActivation,
+ classifier: nn::Linear,
+}
+
+impl EfficientNet {
+ pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
+ let f_p = p.pp("features");
+ let first_in_c = configs[0].input_channels;
+ let last_out_c = configs.last().unwrap().out_channels;
+ let final_out_c = 4 * last_out_c;
+ let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
+ let nconfigs = configs.len();
+ let mut blocks = vec![];
+ for (index, cnf) in configs.into_iter().enumerate() {
+ let f_p = f_p.pp(index + 1);
+ for r_index in 0..cnf.num_layers {
+ let cnf = if r_index == 0 {
+ cnf
+ } else {
+ MBConvConfig {
+ input_channels: cnf.out_channels,
+ stride: 1,
+ ..cnf
+ }
+ };
+ blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
+ }
+ }
+ let final_cna =
+ ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
+ let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
+ Ok(Self {
+ init_cna,
+ blocks,
+ final_cna,
+ classifier,
+ })
+ }
+}
+
+impl Module for EfficientNet {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.init_cna.forward(xs)?;
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ let xs = self.final_cna.forward(&xs)?;
+ // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
+ let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
+ self.classifier.forward(&xs)
+ }
+}
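
Construction follows the torchvision layout: pick a scaling config (b0 through b7) and hand it to `EfficientNet::new`. A minimal sketch, with the weights file name as a placeholder:

    use candle::{DType, Device, Tensor};
    use candle_nn::{Module, VarBuilder};
    use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};

    fn b0_logits(img: &Tensor) -> candle::Result<Tensor> {
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &["efficientnet-b0.safetensors"],
                DType::F32,
                &Device::Cpu,
            )?
        };
        let model = EfficientNet::new(vb, MBConvConfig::b0(), 1000)?;
        model.forward(img) // (batch, 1000) class logits
    }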
diff --git a/candle-examples/examples/falcon/model.rs b/candle-transformers/src/models/falcon.rs
index b638dd51..6ede136a 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,4 @@
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
@@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
- return Err(err.into());
+ return Err(err);
}
}
};
@@ -82,13 +81,13 @@ impl Default for Config {
impl Config {
pub fn validate(&self) -> Result<()> {
if self.alibi {
- anyhow::bail!("alibi is not supported");
+ candle::bail!("alibi is not supported");
}
if self.new_decoder_architecture {
- anyhow::bail!("new_decoder_architecture is not supported");
+ candle::bail!("new_decoder_architecture is not supported");
}
if self.n_head_kv.is_some() {
- anyhow::bail!("n_head_kv is not supported");
+ candle::bail!("n_head_kv is not supported");
}
Ok(())
}
diff --git a/candle-examples/examples/llama/model.rs b/candle-transformers/src/models/llama.rs
index 275856e0..eed4df5e 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -4,7 +4,7 @@ use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
-use super::MAX_SEQ_LEN;
+pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Deserialize)]
pub struct LlamaConfig {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 8b137891..d783a2c6 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1 +1,13 @@
-
+pub mod bert;
+pub mod bigcode;
+pub mod dinov2;
+pub mod efficientnet;
+pub mod falcon;
+pub mod llama;
+pub mod quantized_llama;
+pub mod quantized_t5;
+pub mod segment_anything;
+pub mod stable_diffusion;
+pub mod t5;
+pub mod whisper;
+pub mod wuerstchen;
diff --git a/candle-examples/examples/quantized/model.rs b/candle-transformers/src/models/quantized_llama.rs
index da0bd0b0..2988b0fb 100644
--- a/candle-examples/examples/quantized/model.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -144,7 +144,7 @@ impl LayerWeights {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+ let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
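
`softmax_last_dim` is a fused fast path for the common last-dimension case, so this substitution should be numerically equivalent to `softmax(.., D::Minus1)`. A quick sanity check:

    use candle::{Device, Tensor, D};

    fn softmax_check() -> candle::Result<()> {
        let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        let a = candle_nn::ops::softmax(&xs, D::Minus1)?;
        let b = candle_nn::ops::softmax_last_dim(&xs)?;
        // Both paths should agree up to floating point rounding.
        println!("{}", (a - b)?.abs()?.sum_all()?);
        Ok(())
    }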
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
new file mode 100644
index 00000000..a10c3b80
--- /dev/null
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -0,0 +1,884 @@
+// T5 Text Model, quantized version
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+use candle::quantized::QTensor;
+use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+use candle_nn::Activation;
+use serde::Deserialize;
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+ data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+ path: Vec<String>,
+ device: Device,
+}
+
+impl VarBuilder {
+ pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ let mut file = std::fs::File::open(p)?;
+ let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+ let mut data = std::collections::HashMap::new();
+ for tensor_name in content.tensor_infos.keys() {
+ let tensor = content.tensor(&mut file, tensor_name)?;
+ data.insert(tensor_name.to_string(), Arc::new(tensor));
+ }
+ Ok(Self {
+ data: Arc::new(data),
+ path: Vec::new(),
+ device: Device::Cpu,
+ })
+ }
+
+ fn pp<S: ToString>(&self, s: S) -> Self {
+ let mut path = self.path.clone();
+ path.push(s.to_string());
+ Self {
+ data: self.data.clone(),
+ path,
+ device: self.device.clone(),
+ }
+ }
+
+ fn path(&self, tensor_name: &str) -> String {
+ if self.path.is_empty() {
+ tensor_name.to_string()
+ } else {
+ [&self.path.join("."), tensor_name].join(".")
+ }
+ }
+
+ fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+ let path = self.path(name);
+ match self.data.get(&path) {
+ None => {
+ candle::bail!("cannot find tensor {name}")
+ }
+ Some(qtensor) => {
+ let shape = s.into();
+ if qtensor.shape() != &shape {
+ candle::bail!(
+ "shape mismatch for {name}, got {:?}, expected {shape:?}",
+ qtensor.shape()
+ )
+ }
+ Ok(qtensor.clone())
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+ let inner = candle_nn::Embedding::new(embeddings, d2);
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+ inner: candle::quantized::QMatMul,
+ span: tracing::Span,
+}
+
+impl QMatMul {
+ fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let ws = vb.get((in_dim, out_dim), "weight")?;
+ let inner = candle::quantized::QMatMul::from_arc(ws);
+ let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+impl std::fmt::Debug for QMatMul {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "QMatMul")
+ }
+}
+
+fn default_relative_attention_max_distance() -> usize {
+ 128
+}
+
+fn default_is_decoder() -> bool {
+ false
+}
+
+fn default_use_cache() -> bool {
+ true
+}
+
+fn default_tie_word_embeddings() -> bool {
+ true
+}
+
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ d_model: usize,
+ d_kv: usize,
+ d_ff: usize,
+ num_layers: usize,
+ num_decoder_layers: Option<usize>,
+ num_heads: usize,
+ relative_attention_num_buckets: usize,
+ #[serde(default = "default_relative_attention_max_distance")]
+ relative_attention_max_distance: usize,
+ dropout_rate: f64,
+ layer_norm_epsilon: f64,
+ initializer_factor: f64,
+ #[serde(default)]
+ feed_forward_proj: Activation,
+ #[serde(default = "default_tie_word_embeddings")]
+ tie_word_embeddings: bool,
+ #[serde(default = "default_is_decoder")]
+ is_decoder: bool,
+ is_encoder_decoder: bool,
+ #[serde(default = "default_use_cache")]
+ pub use_cache: bool,
+ pub pad_token_id: usize,
+ pub eos_token_id: usize,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 32128,
+ d_model: 512,
+ d_kv: 64,
+ d_ff: 2048,
+ num_layers: 6,
+ num_decoder_layers: None,
+ num_heads: 8,
+ relative_attention_num_buckets: 32,
+ relative_attention_max_distance: 128,
+ dropout_rate: 0.1,
+ layer_norm_epsilon: 1e-6,
+ initializer_factor: 1.0,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ use_cache: true,
+ pad_token_id: 0,
+ eos_token_id: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerNorm {
+ weight: Tensor,
+ variance_epsilon: f64,
+ span: tracing::Span,
+}
+
+impl T5LayerNorm {
+ fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+ Ok(Self {
+ weight,
+ variance_epsilon: eps,
+ span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+ })
+ }
+}
+
+impl Module for T5LayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let dtype = xs.dtype();
+ let xs_f32 = xs.to_dtype(DType::F32)?;
+ // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+ let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+ let xs = xs.to_dtype(dtype)?;
+ let xs = xs.broadcast_mul(&self.weight)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseActDense {
+ wi: QMatMul,
+ wo: QMatMul,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi,
+ wo,
+ act: Activation::Relu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.wi.forward(xs)?;
+ let xs = self.act.forward(&xs)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+ wi_0: QMatMul,
+ wi_1: QMatMul,
+ wo: QMatMul,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseGatedActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi_0,
+ wi_1,
+ wo,
+ act: Activation::NewGelu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseGatedActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+ let hidden_linear = self.wi_1.forward(xs)?;
+ let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerFF {
+ dense_act: Option<T5DenseActDense>,
+ gated_dense_act: Option<T5DenseGatedActDense>,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerFF {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+ (
+ None,
+ Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ )
+ } else {
+ (
+ Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ None,
+ )
+ };
+ Ok(Self {
+ dense_act,
+ gated_dense_act,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
+ })
+ }
+}
+
+impl Module for T5LayerFF {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let ys = self.layer_norm.forward(xs)?;
+ let ys = match &self.dense_act {
+ Some(dense_act) => dense_act.forward(&ys)?,
+ None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+ };
+ let xs = (xs + ys)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5Attention {
+ q: QMatMul,
+ k: QMatMul,
+ v: QMatMul,
+ o: QMatMul,
+ n_heads: usize,
+ d_kv: usize,
+ relative_attention_bias: Option<Embedding>,
+ relative_attention_num_buckets: usize,
+ relative_attention_max_distance: usize,
+ inner_dim: usize,
+ use_cache: bool,
+ kv_cache: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_cache: tracing::Span,
+ span_mm: tracing::Span,
+ span_sm: tracing::Span,
+}
+
+impl T5Attention {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let inner_dim = cfg.num_heads * cfg.d_kv;
+ let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+ let relative_attention_bias = if has_relative_attention_bias {
+ let emb = Embedding::new(
+ cfg.relative_attention_num_buckets,
+ cfg.num_heads,
+ vb.pp("relative_attention_bias"),
+ )?;
+ Some(emb)
+ } else {
+ None
+ };
+ Ok(Self {
+ q,
+ k,
+ v,
+ o,
+ n_heads: cfg.num_heads,
+ d_kv: cfg.d_kv,
+ relative_attention_bias,
+ relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+ relative_attention_max_distance: cfg.relative_attention_max_distance,
+ inner_dim,
+ use_cache: cfg.use_cache && decoder,
+ kv_cache: None,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+ span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
+ span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ // Performs self-attention (if key_value_states is None) or cross-attention
+ // over the source sentence (provided by key_value_states).
+ let _enter = self.span.enter();
+ let kv_input = match key_value_states {
+ None => xs,
+ Some(key_value_states) => key_value_states,
+ };
+ let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_len = kv_input.dim(1)?;
+ let q = self.q.forward(xs)?;
+ let k = self.k.forward(kv_input)?;
+ let v = self.v.forward(kv_input)?;
+ let q = q
+ .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut k = k
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut v = v
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ if self.use_cache {
+ let _enter = self.span_cache.enter();
+ if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+ k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+ };
+ // TODO: Use flash_attn.
+ let scores = {
+ let _enter = self.span_mm.enter();
+ q.matmul(&k.t()?)?
+ };
+ let scores = match mask {
+ None => scores,
+ Some(mask) => masked_fill(
+ &scores,
+ &mask
+ .unsqueeze(0)?
+ .unsqueeze(0)?
+ .repeat((b_sz, self.n_heads))?,
+ f32::NEG_INFINITY,
+ )?,
+ };
+
+ let (scores, position_bias) = match position_bias {
+ Some(position_bias) => (
+ scores.broadcast_add(position_bias)?,
+ Some(position_bias.clone()),
+ ),
+ None => match &self.relative_attention_bias {
+ None => (scores, None),
+ Some(relative_attention_bias) => {
+ // This only handles the bidirectional case.
+ let kv_len = k.dim(2)?;
+ let (q_start, q_end) = match self.use_cache {
+ true => ((kv_len - q_len) as u32, kv_len as u32),
+ false => (0_u32, kv_len as u32),
+ };
+ let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+ let max_exact = num_buckets / 2;
+ let relative_position = (q_start..q_end)
+ .map(|i| {
+ (0..kv_len as u32)
+ .map(|j| {
+ if i < j {
+ if j - i < max_exact {
+ j - i + num_buckets
+ } else {
+ let b = f32::log(
+ (j - i) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ u32::min(
+ max_exact + num_buckets + b as u32,
+ self.relative_attention_num_buckets as u32 - 1,
+ )
+ }
+ } else if i - j < max_exact {
+ i - j
+ } else {
+ let b = f32::log(
+ (i - j) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ max_exact + b as u32
+ }
+ })
+ .collect::<Vec<u32>>()
+ })
+ .collect::<Vec<Vec<_>>>();
+ let relative_buckets = Tensor::new(relative_position, q.device())?;
+ let position_bias = relative_attention_bias
+ .forward(&relative_buckets)?
+ .permute((2, 0, 1))?
+ .unsqueeze(0)?;
+ (scores.broadcast_add(&position_bias)?, Some(position_bias))
+ // TODO: position_bias_masked?
+ }
+ },
+ };
+
+ let attn_weights = {
+ let _enter = self.span_sm.enter();
+ candle_nn::ops::softmax(&scores, D::Minus1)?
+ };
+ let attn_output = attn_weights.matmul(&v)?;
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.inner_dim))?;
+ let attn_output = self.o.forward(&attn_output)?;
+ Ok((attn_output, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerSelfAttention {
+ self_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerSelfAttention {
+ fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ self_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_xs = self.layer_norm.forward(xs)?;
+ let (ys, position_bias) =
+ self.self_attention
+ .forward(&normed_xs, position_bias, None, mask)?;
+ let ys = (xs + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerCrossAttention {
+ cross_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerCrossAttention {
+ fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ cross_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ hidden_states: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: &Tensor,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+ let (ys, position_bias) = self.cross_attention.forward(
+ &normed_hidden_states,
+ position_bias,
+ Some(key_value_states),
+ None,
+ )?;
+ let ys = (hidden_states + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.cross_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5Block {
+ self_attn: T5LayerSelfAttention,
+ cross_attn: Option<T5LayerCrossAttention>,
+ ff: T5LayerFF,
+ span: tracing::Span,
+}
+
+impl T5Block {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let vb = vb.pp("layer");
+ let self_attn =
+ T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
+ let cross_attn = if cfg.is_decoder {
+ Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
+ } else {
+ None
+ };
+ let ff_i = if cross_attn.is_some() { 2 } else { 1 };
+ let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?;
+ Ok(Self {
+ self_attn,
+ cross_attn,
+ ff,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ // TODO: Cache masks
+ let mask = match self.cross_attn.is_some() {
+ true => {
+ let mask_len = xs.dim(1)?;
+ // If the input seq length is 1 there is no need for a mask; this also
+ // helps avoid shape issues when using the KV cache in the decoder.
+ if mask_len <= 1 {
+ None
+ } else {
+ Some(get_mask(mask_len, xs.device())?)
+ }
+ }
+ false => None,
+ };
+ let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
+ // TODO: clamp for f16?
+ if let Some(cross_attn) = &mut self.cross_attn {
+ (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
+ // TODO: clamp for f16?
+ }
+ let xs = self.ff.forward(&xs)?;
+ // TODO: clamp for f16?
+ Ok((xs, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache();
+ self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+ }
+}
+
+#[derive(Debug)]
+struct T5Stack {
+ block: Vec<T5Block>,
+ shared: Arc<Embedding>,
+ final_layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5Stack {
+ fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+ let block = (0..cfg.num_layers)
+ .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let final_layer_norm = T5LayerNorm::load(
+ cfg.d_model,
+ cfg.layer_norm_epsilon,
+ vb.pp("final_layer_norm"),
+ )?;
+ Ok(Self {
+ block,
+ shared: shared.clone(),
+ final_layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "stack"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let mut hidden_states = input_embeds;
+ let mut position_bias = None;
+ for block in self.block.iter_mut() {
+ (hidden_states, position_bias) = block.forward(
+ &hidden_states,
+ position_bias.as_ref(),
+ encoder_hidden_states,
+ )?
+ }
+ self.final_layer_norm.forward(&hidden_states)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
+}
+
+#[derive(Debug)]
+pub struct T5EncoderModel {
+ encoder: T5Stack,
+ device: Device,
+ span: tracing::Span,
+}
+
+impl T5EncoderModel {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
+ Ok(Self {
+ encoder,
+ device: vb.device.clone(),
+ span: tracing::span!(tracing::Level::TRACE, "encoder"),
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+ encoder: T5Stack,
+ decoder: T5Stack,
+ d_model: usize,
+ tie_word_embeddings: bool,
+ lm_head: Option<QMatMul>,
+ shared: Arc<Embedding>,
+ device: Device,
+ span_decode: tracing::Span,
+ span_decode_head: tracing::Span,
+}
+
+impl T5ForConditionalGeneration {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ assert!(cfg.is_encoder_decoder);
+ let d_model = cfg.d_model;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+
+ let mut encoder_cfg = cfg.clone();
+ encoder_cfg.is_decoder = false;
+ encoder_cfg.use_cache = false;
+ encoder_cfg.is_encoder_decoder = false;
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+ let mut decoder_cfg = cfg.clone();
+ decoder_cfg.is_decoder = true;
+ decoder_cfg.is_encoder_decoder = false;
+ decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+ let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+ let tie_word_embeddings = cfg.tie_word_embeddings;
+ let lm_head = if tie_word_embeddings {
+ None
+ } else {
+ Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+ };
+
+ Ok(Self {
+ encoder,
+ decoder,
+ d_model,
+ tie_word_embeddings,
+ lm_head,
+ shared,
+ device: vb.device.clone(),
+ span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+ span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
+ })
+ }
+
+ pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn decode(
+ &mut self,
+ decoder_input_ids: &Tensor,
+ encoder_output: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span_decode.enter();
+ let decoder_output = self
+ .decoder
+ .forward(decoder_input_ids, Some(encoder_output))?;
+
+ let scaling_factor = if self.tie_word_embeddings {
+ // Rescale output before projecting on vocab
+ // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ (self.d_model as f64).sqrt()
+ } else {
+ 1.0
+ };
+ let sequence_output = ((decoder_output
+ .narrow(1, decoder_output.dim(1)? - 1, 1)?
+ .squeeze(1)?)
+ * scaling_factor)?;
+ let output = {
+ let _enter = self.span_decode_head.enter();
+ match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ }
+ };
+
+ // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+ Ok(output)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+ let encoder_output = self.encode(input_ids)?;
+ self.decode(decoder_input_ids, &encoder_output)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache();
+ self.decoder.clear_kv_cache();
+ }
+}
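
A condensed driver sketch for the quantized model; the gguf path is a placeholder, and the config would normally be parsed from the matching config.json via serde_json:

    use candle::{Device, Tensor};
    use candle_transformers::models::quantized_t5 as t5;

    fn first_logits(cfg: &t5::Config, tokens: &[u32]) -> candle::Result<Tensor> {
        let vb = t5::VarBuilder::from_gguf("model-q4_0.gguf")?;
        let mut model = t5::T5ForConditionalGeneration::load(vb, cfg)?;
        let input = Tensor::new(tokens, &Device::Cpu)?.unsqueeze(0)?;
        let encoder_output = model.encode(&input)?;
        // T5 decoding conventionally starts from the pad token; subsequent
        // steps would append the sampled token ids and call decode again.
        let decoder_input =
            Tensor::new(&[cfg.pad_token_id as u32], &Device::Cpu)?.unsqueeze(0)?;
        model.decode(&decoder_input, &encoder_output)
    }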
diff --git a/candle-transformers/src/models/segment_anything/image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs
new file mode 100644
index 00000000..0b313830
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/image_encoder.rs
@@ -0,0 +1,483 @@
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PatchEmbed {
+ proj: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl PatchEmbed {
+ fn new(
+ in_chans: usize,
+ embed_dim: usize,
+ k_size: usize,
+ stride: usize,
+ padding: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ stride,
+ padding,
+ ..Default::default()
+ };
+ let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { proj, span })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.proj)?.permute((0, 2, 3, 1))
+ }
+}
+
+// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
+// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
+// (attn.reshape((b, q_h, q_w, k_h, k_w))?
+// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+// .reshape((b, q_h * q_w, k_h * k_w))
+// Ideally we would perform this operation in place but this is not supported in candle at the
+// moment. We should also investigate using f16 rather than f32.
+struct Add3(usize, usize, usize, usize, usize);
+impl candle::CustomOp3 for Add3 {
+ fn name(&self) -> &'static str {
+ "add3"
+ }
+
+ fn cpu_fwd(
+ &self,
+ s1: &candle::CpuStorage,
+ l1: &candle::Layout,
+ s2: &candle::CpuStorage,
+ l2: &candle::Layout,
+ s3: &candle::CpuStorage,
+ l3: &candle::Layout,
+ ) -> Result<(candle::CpuStorage, candle::Shape)> {
+ use rayon::prelude::*;
+
+ let Add3(b, q_h, q_w, k_h, k_w) = *self;
+ let s1 = s1.as_slice::<f32>()?;
+ let s1 = match l1.contiguous_offsets() {
+ None => candle::bail!("input1 has to be contiguous"),
+ Some((o1, o2)) => &s1[o1..o2],
+ };
+ let s2 = s2.as_slice::<f32>()?;
+ let s2 = match l2.contiguous_offsets() {
+ None => candle::bail!("input2 has to be contiguous"),
+ Some((o1, o2)) => &s2[o1..o2],
+ };
+ let s3 = s3.as_slice::<f32>()?;
+ let s3 = match l3.contiguous_offsets() {
+ None => candle::bail!("input3 has to be contiguous"),
+ Some((o1, o2)) => &s3[o1..o2],
+ };
+ let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
+ dst.par_chunks_exact_mut(k_h * k_w)
+ .enumerate()
+ .for_each(|(b_idx, dst)| {
+ let s1_idx = b_idx * k_h * k_w;
+ let s2_idx = b_idx * k_h;
+ let s3_idx = b_idx * k_w;
+ for h_idx in 0..k_h {
+ let s1_idx = s1_idx + h_idx * k_w;
+ let s2_idx = s2_idx + h_idx;
+ let dst_idx = h_idx * k_w;
+ for w_idx in 0..k_w {
+ let s1_idx = s1_idx + w_idx;
+ let s3_idx = s3_idx + w_idx;
+ let dst_idx = dst_idx + w_idx;
+ dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
+ }
+ }
+ });
+ let dst = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ qkv: super::Linear,
+ proj: super::Linear,
+ num_heads: usize,
+ scale: f64,
+ rel_pos_hw: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_matmul: tracing::Span,
+ span_rel_pos: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl Attention {
+ fn new(
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ input_size: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+ let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
+ let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
+ let head_dim = dim / num_heads;
+ let scale = 1. / (head_dim as f64).sqrt();
+ let rel_pos_hw = if use_rel_pos {
+ let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+ let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
+ Some((h, w))
+ } else {
+ None
+ };
+ Ok(Self {
+ qkv,
+ proj,
+ num_heads,
+ scale,
+ rel_pos_hw,
+ span,
+ span_matmul,
+ span_rel_pos,
+ span_softmax,
+ })
+ }
+
+ fn add_decomposed_rel_pos(
+ &self,
+ attn: Tensor,
+ q: &Tensor,
+ (q_h, q_w): (usize, usize),
+ (k_h, k_w): (usize, usize),
+ ) -> Result<Tensor> {
+ match &self.rel_pos_hw {
+ Some((rel_pos_h, rel_pos_w)) => {
+ let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
+ let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
+ let (b, _, dim) = q.dims3()?;
+ let r_q = q.reshape((b, q_h, q_w, dim))?;
+ // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+ // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ let rel_w = r_q
+ .transpose(1, 2)? // -> bwhc
+ .contiguous()?
+ .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
+ .transpose(1, 2)?
+ .contiguous()?;
+ if attn.device().is_cpu() {
+ let op = Add3(b, q_h, q_w, k_h, k_w);
+ attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
+ } else {
+ (attn.reshape((b, q_h, q_w, k_h, k_w))?
+ + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+ .reshape((b, q_h * q_w, k_h * k_w))
+ }
+ }
+ None => Ok(attn),
+ }
+ }
+}
+
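+// Relative position lookup for the decomposed attention, as in the reference SAM
+// implementation: for query index q and key index k, the table is indexed with
+// q * max(1, k_size / q_size) - k * max(1, q_size / k_size) + (k_size - 1) * max(1, q_size / k_size),
+// which lies in [0, 2 * max(q_size, k_size) - 2].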
+fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
+ let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
+ let dev = rel_pos.device();
+ let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
+ todo!("interpolation")
+ } else {
+ rel_pos
+ };
+ let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
+ .reshape((q_size, 1))?
+ .to_dtype(DType::F32)?;
+ let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
+ .reshape((1, k_size))?
+ .to_dtype(DType::F32)?;
+ let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
+ let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
+ let relative_coords = (q_coords.broadcast_sub(&k_coords)?
+ + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
+ let (d1, d2) = relative_coords.dims2()?;
+ let relative_coords = relative_coords.to_dtype(DType::U32)?;
+ rel_pos_resized
+ .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
+ .reshape((d1, d2, ()))
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b, h, w, c) = xs.dims4()?;
+ let qkv = self
+ .qkv
+ .forward(&xs.flatten_to(1)?)?
+ .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
+ .permute((2, 0, 3, 1, 4))?
+ .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
+ let q = qkv.i(0)?;
+ let k = qkv.i(1)?;
+ let v = qkv.i(2)?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ (&q * self.scale)?.matmul(&k.t()?)?
+ };
+ let attn = {
+ let _enter = self.span_rel_pos.enter();
+ self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+ };
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ attn.matmul(&v)?
+ };
+ let attn = attn
+ .reshape((b, self.num_heads, h, w, c / self.num_heads))?
+ .permute((0, 2, 3, 1, 4))?
+ .reshape((b, h * w, c))?;
+ self.proj.forward(&attn)?.reshape((b, h, w, c))
+ }
+}
+
+#[derive(Debug)]
+struct Block {
+ norm1: LayerNorm,
+ attn: Attention,
+ norm2: LayerNorm,
+ mlp: super::MlpBlock,
+ window_size: usize,
+ span: tracing::Span,
+}
+
+impl Block {
+ fn new(
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ window_size: usize,
+ input_size: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
+ let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
+ let input_size_attn = if window_size == 0 {
+ input_size
+ } else {
+ (window_size, window_size)
+ };
+ let attn = Attention::new(
+ dim,
+ num_heads,
+ qkv_bias,
+ use_rel_pos,
+ input_size_attn,
+ vb.pp("attn"),
+ )?;
+ let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ie-block");
+ Ok(Self {
+ norm1,
+ attn,
+ norm2,
+ mlp,
+ window_size,
+ span,
+ })
+ }
+}
+
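+// Splits a (b, h, w, c) tensor into non-overlapping windows of shape
+// (b * n_windows, window_size, window_size, c), zero-padding h and w up to a multiple
+// of window_size. The padded (h, w) is returned so that window_unpartition below can
+// undo the padding afterwards.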
+fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
+ let (b, h, w, c) = xs.dims4()?;
+ let pad_h = (window_size - h % window_size) % window_size;
+ let pad_w = (window_size - w % window_size) % window_size;
+ let xs = if pad_h > 0 {
+ xs.pad_with_zeros(1, 0, pad_h)?
+ } else {
+ xs
+ };
+ let xs = if pad_w > 0 {
+ xs.pad_with_zeros(2, 0, pad_w)?
+ } else {
+ xs
+ };
+ let (h_p, w_p) = (h + pad_h, w + pad_w);
+ let windows = xs
+ .reshape((
+ b,
+ h_p / window_size,
+ window_size,
+ w_p / window_size,
+ window_size,
+ c,
+ ))?
+ .transpose(2, 3)?
+ .contiguous()?
+ .flatten_to(2)?;
+ Ok((windows, (h_p, w_p)))
+}
+
+fn window_unpartition(
+ windows: Tensor,
+ window_size: usize,
+ (h_p, w_p): (usize, usize),
+ (h, w): (usize, usize),
+) -> Result<Tensor> {
+ let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
+ let xs = windows
+ .reshape((
+ b,
+ h_p / window_size,
+ w_p / window_size,
+ window_size,
+ window_size,
+ windows.elem_count() / b / h_p / w_p,
+ ))?
+ .transpose(2, 3)?
+ .contiguous()?
+ .reshape((b, h_p, w_p, ()))?;
+ let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
+ let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
+ Ok(xs)
+}
+
+impl Module for Block {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let shortcut = xs;
+ let xs = self.norm1.forward(xs)?;
+ let hw = (xs.dim(1)?, xs.dim(2)?);
+ let (xs, pad_hw) = if self.window_size > 0 {
+ window_partition(xs, self.window_size)?
+ } else {
+ (xs, (0, 0))
+ };
+ let xs = self.attn.forward(&xs)?;
+ let xs = if self.window_size > 0 {
+ window_unpartition(xs, self.window_size, pad_hw, hw)?
+ } else {
+ xs
+ };
+ let xs = (xs + shortcut)?;
+ &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
+ }
+}
+
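+// The SAM ViT image encoder: patch embedding, a stack of transformer blocks using
+// windowed attention except at the global_attn_indexes, and a convolutional neck.
+// The output is a (b, out_chans, img_size / patch_size, img_size / patch_size)
+// feature map.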
+#[derive(Debug)]
+pub struct ImageEncoderViT {
+ patch_embed: PatchEmbed,
+ blocks: Vec<Block>,
+ neck_conv1: candle_nn::Conv2d,
+ neck_ln1: super::LayerNorm2d,
+ neck_conv2: candle_nn::Conv2d,
+ neck_ln2: super::LayerNorm2d,
+ pos_embed: Option<Tensor>,
+ span: tracing::Span,
+}
+
+impl ImageEncoderViT {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ img_size: usize,
+ patch_size: usize,
+ in_chans: usize,
+ embed_dim: usize,
+ depth: usize,
+ num_heads: usize,
+ out_chans: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ use_abs_pos: bool,
+ window_size: usize,
+ global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let patch_embed = PatchEmbed::new(
+ in_chans,
+ embed_dim,
+ patch_size,
+ patch_size,
+ 0,
+ vb.pp("patch_embed"),
+ )?;
+ let mut blocks = Vec::with_capacity(depth);
+ let vb_b = vb.pp("blocks");
+ for i in 0..depth {
+ let window_size = if global_attn_indexes.contains(&i) {
+ 0
+ } else {
+ window_size
+ };
+ let block = Block::new(
+ embed_dim,
+ num_heads,
+ qkv_bias,
+ use_rel_pos,
+ window_size,
+ (img_size / patch_size, img_size / patch_size),
+ vb_b.pp(i),
+ )?;
+ blocks.push(block)
+ }
+ let neck_conv1 = candle_nn::conv2d_no_bias(
+ embed_dim,
+ out_chans,
+ 1,
+ Default::default(),
+ vb.pp("neck.0"),
+ )?;
+ let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
+ let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
+ let pos_embed = if use_abs_pos {
+ let p = vb.get(
+ (1, img_size / patch_size, img_size / patch_size, embed_dim),
+ "pos_embed",
+ )?;
+ Some(p)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
+ Ok(Self {
+ patch_embed,
+ blocks,
+ neck_conv1,
+ neck_ln1,
+ neck_conv2,
+ neck_ln2,
+ pos_embed,
+ span,
+ })
+ }
+}
+
+impl Module for ImageEncoderViT {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.patch_embed.forward(xs)?;
+ let mut xs = match &self.pos_embed {
+ Some(pos_embed) => (xs + pos_embed)?,
+ None => xs,
+ };
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ xs.permute((0, 3, 1, 2))?
+ .apply(&self.neck_conv1)?
+ .apply(&self.neck_ln1)?
+ .apply(&self.neck_conv2)?
+ .apply(&self.neck_ln2)
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs
new file mode 100644
index 00000000..2a91cd44
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs
@@ -0,0 +1,239 @@
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+use super::transformer::TwoWayTransformer;
+
+#[derive(Debug)]
+struct MlpMaskDecoder {
+ layers: Vec<super::Linear>,
+ sigmoid_output: bool,
+ span: tracing::Span,
+}
+
+impl MlpMaskDecoder {
+ fn new(
+ input_dim: usize,
+ hidden_dim: usize,
+ output_dim: usize,
+ num_layers: usize,
+ sigmoid_output: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut layers = Vec::with_capacity(num_layers);
+ let vb = vb.pp("layers");
+ for i in 0..num_layers {
+ let in_dim = if i == 0 { input_dim } else { hidden_dim };
+ let out_dim = if i + 1 == num_layers {
+ output_dim
+ } else {
+ hidden_dim
+ };
+ let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
+ layers.push(layer)
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
+ Ok(Self {
+ layers,
+ sigmoid_output,
+ span,
+ })
+ }
+}
+
+impl Module for MlpMaskDecoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for (i, layer) in self.layers.iter().enumerate() {
+ xs = layer.forward(&xs)?;
+ if i + 1 < self.layers.len() {
+ xs = xs.relu()?
+ }
+ }
+ if self.sigmoid_output {
+ candle_nn::ops::sigmoid(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
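+// The SAM mask decoder: a two-way transformer mixes the prompt tokens with the image
+// embedding, hypernetwork MLPs turn the mask tokens into per-mask filters, and a
+// transposed-convolution stack upscales the image embedding before the final dot
+// product that produces the mask logits.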
+#[derive(Debug)]
+pub struct MaskDecoder {
+ iou_token: candle_nn::Embedding,
+ mask_tokens: candle_nn::Embedding,
+ iou_prediction_head: MlpMaskDecoder,
+ output_upscaling_conv1: candle_nn::ConvTranspose2d,
+ output_upscaling_ln: super::LayerNorm2d,
+ output_upscaling_conv2: candle_nn::ConvTranspose2d,
+ num_mask_tokens: usize,
+ output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+ transformer: TwoWayTransformer,
+ span: tracing::Span,
+}
+
+impl MaskDecoder {
+ pub fn new(
+ transformer_dim: usize,
+ num_multimask_outputs: usize,
+ iou_head_depth: usize,
+ iou_head_hidden_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_mask_tokens = num_multimask_outputs + 1;
+ let iou_prediction_head = MlpMaskDecoder::new(
+ transformer_dim,
+ iou_head_hidden_dim,
+ num_mask_tokens,
+ iou_head_depth,
+ false,
+ vb.pp("iou_prediction_head"),
+ )?;
+ let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
+ let mask_tokens =
+ candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let output_upscaling_conv1 = candle_nn::conv_transpose2d(
+ transformer_dim,
+ transformer_dim / 4,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.0"),
+ )?;
+ let output_upscaling_ln =
+ super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+ let output_upscaling_conv2 = candle_nn::conv_transpose2d(
+ transformer_dim / 4,
+ transformer_dim / 8,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.3"),
+ )?;
+ let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
+ let vb_o = vb.pp("output_hypernetworks_mlps");
+ for i in 0..num_mask_tokens {
+ let mlp = MlpMaskDecoder::new(
+ transformer_dim,
+ transformer_dim,
+ transformer_dim / 8,
+ 3,
+ false,
+ vb_o.pp(i),
+ )?;
+ output_hypernetworks_mlps.push(mlp)
+ }
+ let transformer = TwoWayTransformer::new(
+ /* depth */ 2,
+ /* embedding_dim */ transformer_dim,
+ /* num_heads */ 8,
+ /* mlp_dim */ 2048,
+ vb.pp("transformer"),
+ )?;
+ let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
+ Ok(Self {
+ iou_token,
+ mask_tokens,
+ iou_prediction_head,
+ output_upscaling_conv1,
+ output_upscaling_ln,
+ output_upscaling_conv2,
+ num_mask_tokens,
+ output_hypernetworks_mlps,
+ transformer,
+ span,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let (masks, iou_pred) = self.predict_masks(
+ image_embeddings,
+ image_pe,
+ sparse_prompt_embeddings,
+ dense_prompt_embeddings,
+ )?;
+ let masks = if multimask_output {
+ masks.i((.., 1..))?
+ } else {
+ masks.i((.., 0..1))?
+ };
+ let iou_pred = if multimask_output {
+ iou_pred.i((.., 1..))?
+ } else {
+ iou_pred.i((.., 0..1))?
+ };
+ Ok((masks, iou_pred))
+ }
+
+ fn predict_masks(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Concatenate output tokens.
+ let output_tokens = Tensor::cat(
+ &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+ 0,
+ )?;
+ let (d1, d2) = output_tokens.dims2()?;
+ let output_tokens =
+ output_tokens
+ .unsqueeze(0)?
+ .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+ let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+ // Expand per-image data in batch direction to be per mask
+ let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+ let src = src.broadcast_add(dense_prompt_embeddings)?;
+ let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+ let (b, c, h, w) = src.dims4()?;
+
+ // Run the transformer
+ let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
+ let iou_token_out = hs.i((.., 0))?;
+ let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
+
+ // Upscale mask embeddings and predict masks using the mask tokens.
+ let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+ let upscaled_embedding = self
+ .output_upscaling_conv1
+ .forward(&src)?
+ .apply(&self.output_upscaling_ln)?
+ .gelu()?
+ .apply(&self.output_upscaling_conv2)?
+ .gelu()?;
+ let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+ for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+ let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+ hyper_in_list.push(h)
+ }
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
+ let (b, c, h, w) = upscaled_embedding.dims4()?;
+ let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
+ let masks = masks.reshape((b, (), h, w))?;
+
+ // Generate mask quality predictions.
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
+ Ok((masks, iou_pred))
+ }
+}
+
+// Equivalent to torch.repeat_interleave
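+// along a single dimension: unsqueeze at dim + 1, broadcast to the number of repeats,
+// then flatten back. E.g. with repeats = 3 on dim 0, a (2, c) tensor becomes a (6, c)
+// tensor in which each row appears three times consecutively.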
+fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+ let img = img.unsqueeze(dim + 1)?;
+ let mut dims = img.dims().to_vec();
+ dims[dim + 1] = repeats;
+ img.broadcast_as(dims)?.flatten(dim, dim + 1)
+}
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
new file mode 100644
index 00000000..c29db70a
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -0,0 +1,100 @@
+use candle::{Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+pub mod image_encoder;
+pub mod mask_decoder;
+pub mod prompt_encoder;
+pub mod sam;
+pub mod tiny_vit;
+pub mod transformer;
+
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
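+// Layer normalization over the channel dimension of a (b, c, h, w) tensor, as used in
+// the reference SAM code (candle_nn::LayerNorm normalizes over the last dimension, so
+// it cannot be used directly here).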
+#[derive(Debug)]
+pub struct LayerNorm2d {
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ eps: f64,
+}
+
+impl LayerNorm2d {
+ pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(num_channels, "weight")?;
+ let bias = vb.get(num_channels, "bias")?;
+ Ok(Self {
+ weight,
+ bias,
+ num_channels,
+ eps,
+ })
+ }
+}
+
+impl Module for LayerNorm2d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let u = xs.mean_keepdim(1)?;
+ let xs = xs.broadcast_sub(&u)?;
+ let s = xs.sqr()?.mean_keepdim(1)?;
+ let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+ xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+ .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+ }
+}
+
+#[derive(Debug)]
+pub struct MlpBlock {
+ lin1: Linear,
+ lin2: Linear,
+ activation: candle_nn::Activation,
+ span: tracing::Span,
+}
+
+impl MlpBlock {
+ pub fn new(
+ embedding_dim: usize,
+ mlp_dim: usize,
+ activation: candle_nn::Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
+ Ok(Self {
+ lin1,
+ lin2,
+ activation,
+ span,
+ })
+ }
+}
+
+impl Module for MlpBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
new file mode 100644
index 00000000..9d0074b1
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -0,0 +1,239 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct PositionEmbeddingRandom {
+ positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PositionEmbeddingRandom {
+ fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+ let positional_encoding_gaussian_matrix =
+ vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+ Ok(Self {
+ positional_encoding_gaussian_matrix,
+ })
+ }
+
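+ // Random Fourier feature positional encoding: coordinates in [0, 1] are mapped to
+ // [-1, 1], projected with a fixed Gaussian matrix, and encoded as (sin, cos) pairs.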
+ fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+ let coords = coords.affine(2., -1.)?;
+ let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = (coords * (2. * std::f64::consts::PI))?;
+ Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+ }
+
+ fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+ let device = self.positional_encoding_gaussian_matrix.device();
+ let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let x_embed = (x_embed / w as f64)?
+ .reshape((1, ()))?
+ .broadcast_as((h, w))?;
+ let y_embed = (y_embed / h as f64)?
+ .reshape(((), 1))?
+ .broadcast_as((h, w))?;
+ let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+ self.pe_encoding(&coords)?.permute((2, 0, 1))
+ }
+
+ fn forward_with_coords(
+ &self,
+ coords_input: &Tensor,
+ image_size: (usize, usize),
+ ) -> Result<Tensor> {
+ let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+ let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+ let c = coords_input.dim(D::Minus1)?;
+ let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+ let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+ self.pe_encoding(&coords)
+ }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+ pe_layer: PositionEmbeddingRandom,
+ point_embeddings: Vec<candle_nn::Embedding>,
+ not_a_point_embed: candle_nn::Embedding,
+ mask_downscaling_conv1: candle_nn::Conv2d,
+ mask_downscaling_ln1: super::LayerNorm2d,
+ mask_downscaling_conv2: candle_nn::Conv2d,
+ mask_downscaling_ln2: super::LayerNorm2d,
+ mask_downscaling_conv3: candle_nn::Conv2d,
+ no_mask_embed: candle_nn::Embedding,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ embed_dim: usize,
+ span: tracing::Span,
+}
+
+impl PromptEncoder {
+ pub fn new(
+ embed_dim: usize,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ mask_in_chans: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_points_embeddings = 4;
+ let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+ let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let mask_downscaling_conv1 =
+ candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+ let mask_downscaling_conv2 = candle_nn::conv2d(
+ mask_in_chans / 4,
+ mask_in_chans,
+ 2,
+ cfg,
+ vb.pp("mask_downscaling.3"),
+ )?;
+ let mask_downscaling_conv3 = candle_nn::conv2d(
+ mask_in_chans,
+ embed_dim,
+ 1,
+ Default::default(),
+ vb.pp("mask_downscaling.6"),
+ )?;
+ let mask_downscaling_ln1 =
+ super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 =
+ super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+ let vb_e = vb.pp("point_embeddings");
+ for i in 0..num_points_embeddings {
+ let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+ point_embeddings.push(emb)
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
+ Ok(Self {
+ pe_layer,
+ point_embeddings,
+ not_a_point_embed,
+ mask_downscaling_conv1,
+ mask_downscaling_ln1,
+ mask_downscaling_conv2,
+ mask_downscaling_ln2,
+ mask_downscaling_conv3,
+ no_mask_embed,
+ image_embedding_size,
+ input_image_size,
+ embed_dim,
+ span,
+ })
+ }
+
+ pub fn get_dense_pe(&self) -> Result<Tensor> {
+ self.pe_layer
+ .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+ .unsqueeze(0)
+ }
+
+ fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+ masks
+ .apply(&self.mask_downscaling_conv1)?
+ .apply(&self.mask_downscaling_ln1)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv2)?
+ .apply(&self.mask_downscaling_ln2)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv3)
+ }
+
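+ // Point labels follow the SAM convention: 1 marks a foreground point, 0 a background
+ // point, and -1 the padding points added when no box prompt is present. Box corners
+ // reuse the same mechanism through point_embeddings[2] and [3] in embed_boxes below.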
+ fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+ let points = (points + 0.5)?;
+ let dev = points.device();
+ let (points, labels) = if pad {
+ let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+ let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
+ let points = Tensor::cat(&[&points, &padding_point], 1)?;
+ let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+ (points, labels)
+ } else {
+ (points, labels.clone())
+ };
+ let point_embedding = self
+ .pe_layer
+ .forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
+ let zeros = point_embedding.zeros_like()?;
+ let point_embedding = labels.lt(0f32)?.where_cond(
+ &self
+ .not_a_point_embed
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &point_embedding,
+ )?;
+ let labels0 = labels.eq(0f32)?.where_cond(
+ &self.point_embeddings[0]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels0)?;
+ let labels1 = labels.eq(1f32)?.where_cond(
+ &self.point_embeddings[1]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels1)?;
+ Ok(point_embedding)
+ }
+
+ fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+ let boxes = (boxes + 0.5)?;
+ let coords = boxes.reshape(((), 2, 2))?;
+ let corner_embedding = self
+ .pe_layer
+ .forward_with_coords(&coords, self.input_image_size)?;
+ let ce1 = corner_embedding.i((.., 0))?;
+ let ce2 = corner_embedding.i((.., 1))?;
+ let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+ let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+ Tensor::cat(&[&ce1, &ce2], 1)
+ }
+
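+ // Returns (sparse_embeddings, dense_embeddings): one embedding per point and box
+ // corner for the sparse prompts, and a dense per-pixel map over the image embedding
+ // grid, derived from the mask prompt when given and from the learned no-mask
+ // embedding otherwise.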
+ pub fn forward(
+ &self,
+ points: Option<(&Tensor, &Tensor)>,
+ boxes: Option<&Tensor>,
+ masks: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let se_points = match points {
+ Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+ None => None,
+ };
+ let se_boxes = match boxes {
+ Some(boxes) => Some(self.embed_boxes(boxes)?),
+ None => None,
+ };
+ let sparse_embeddings = match (se_points, se_boxes) {
+ (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+ (Some(se_points), None) => se_points,
+ (None, Some(se_boxes)) => se_boxes,
+ (None, None) => {
+ Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+ }
+ };
+
+ let dense_embeddings = match masks {
+ None => {
+ let emb = self.no_mask_embed.embeddings();
+ emb.reshape((1, (), 1, 1))?.expand((
+ 1,
+ emb.elem_count(),
+ self.image_embedding_size.0,
+ self.image_embedding_size.1,
+ ))?
+ }
+ Some(masks) => self.embed_masks(masks)?,
+ };
+ Ok((sparse_embeddings, dense_embeddings))
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
new file mode 100644
index 00000000..07e9a759
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -0,0 +1,433 @@
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+use super::image_encoder::ImageEncoderViT;
+use super::mask_decoder::MaskDecoder;
+use super::prompt_encoder::PromptEncoder;
+use super::tiny_vit::{tiny_vit_5m, TinyViT};
+
+const PROMPT_EMBED_DIM: usize = 256;
+pub const IMAGE_SIZE: usize = 1024;
+const VIT_PATCH_SIZE: usize = 16;
+const PRED_IOU_THRESH: f32 = 0.88;
+const STABILITY_SCORE_OFFSET: f32 = 1.0;
+const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
+const MODEL_MASK_THRESHOLD: f32 = 0.0;
+const CROP_NMS_THRESH: f32 = 0.7;
+
+#[derive(Debug)]
+enum ImageEncoder {
+ Original(ImageEncoderViT),
+ TinyViT(TinyViT),
+}
+
+impl Module for ImageEncoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::Original(vit) => vit.forward(xs),
+ Self::TinyViT(vit) => vit.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Sam {
+ image_encoder: ImageEncoder,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: Tensor,
+ pixel_std: Tensor,
+}
+
+impl Sam {
+ pub fn new(
+ encoder_embed_dim: usize,
+ encoder_depth: usize,
+ encoder_num_heads: usize,
+ encoder_global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = ImageEncoderViT::new(
+ IMAGE_SIZE,
+ VIT_PATCH_SIZE,
+ 3,
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ PROMPT_EMBED_DIM,
+ /* qkv_bias */ true,
+ /* use_rel_pos */ true,
+ /* use_abs_pos */ true,
+ /* window_size */ 14,
+ /* global_attn_indexes */ encoder_global_attn_indexes,
+ vb.pp("image_encoder"),
+ )?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::Original(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+
+ pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::TinyViT(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+
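+ // A minimal usage sketch (the weight file name is a placeholder; `img` is assumed to
+ // be a (3, h, w) tensor with h, w <= IMAGE_SIZE):
+ //
+ // let vb = unsafe {
+ //     VarBuilder::from_mmaped_safetensors(&["mobile_sam.safetensors"], DType::F32, &dev)?
+ // };
+ // let sam = Sam::new_tiny(vb)?;
+ // // Prompt with a single point at the image center, in relative [0, 1] coordinates.
+ // let (mask, iou) = sam.forward(&img, Some((0.5, 0.5)), /* multimask_output */ false)?;
+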
+ pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ self.image_encoder.forward(&img)
+ }
+
+ pub fn forward(
+ &self,
+ img: &Tensor,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_c, original_h, original_w) = img.dims3()?;
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+ let (low_res_mask, iou) = self.forward_for_embeddings(
+ &img_embeddings,
+ original_h,
+ original_w,
+ point,
+ multimask_output,
+ )?;
+ let mask = low_res_mask
+ .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
+ .get(0)?
+ .i((.., ..original_h, ..original_w))?;
+ Ok((mask, iou))
+ }
+
+ pub fn forward_for_embeddings(
+ &self,
+ img_embeddings: &Tensor,
+ original_h: usize,
+ original_w: usize,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = match point {
+ None => None,
+ Some((x, y)) => {
+ let points = Tensor::new(
+ &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
+ img_embeddings.device(),
+ )?;
+ let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
+ Some((points, labels))
+ }
+ };
+ let points = points.as_ref().map(|(x, y)| (x, y));
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder.forward(points, None, None)?;
+ self.mask_decoder.forward(
+ img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ multimask_output,
+ )
+ }
+
+ pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let img = img
+ .broadcast_mul(&self.pixel_std)?
+ .broadcast_add(&self.pixel_mean)?;
+ img.maximum(&img.zeros_like()?)?
+ .minimum(&(img.ones_like()? * 255.)?)
+ }
+
+ pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let (_c, h, w) = img.dims3()?;
+ let img = img
+ .to_dtype(DType::F32)?
+ .broadcast_sub(&self.pixel_mean)?
+ .broadcast_div(&self.pixel_std)?;
+ if h > IMAGE_SIZE || w > IMAGE_SIZE {
+ candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
+ }
+ let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
+ img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
+ }
+
+ fn process_crop(
+ &self,
+ img: &Tensor,
+ cb: CropBox,
+ point_grids: &[(f64, f64)],
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
+ // Crop the image and calculate embeddings.
+ let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
+ let img = self.preprocess(&img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+
+ let crop_w = cb.x1 - cb.x0;
+ let crop_h = cb.y1 - cb.y0;
+
+ // Generate masks for this crop.
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = point_grids
+ .iter()
+ .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
+ .collect::<Vec<_>>();
+
+ let mut bboxes = Vec::new();
+ for points in points.chunks(64) {
+ // Run the model on this batch.
+ let points_len = points.len();
+ let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
+ let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder
+ .forward(Some((&in_points, &in_labels)), None, None)?;
+
+ let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
+ &img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ /* multimask_output */ true,
+ )?;
+ let low_res_mask = low_res_mask.flatten(0, 1)?;
+ let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
+ let dev = low_res_mask.device();
+
+ for (i, iou) in iou_predictions.iter().enumerate() {
+ // Filter by predicted IoU.
+ if *iou < PRED_IOU_THRESH {
+ continue;
+ }
+ let low_res_mask = low_res_mask.get(i)?;
+
+ // Calculate stability score.
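+ // (the IoU between the masks obtained by thresholding the logits at
+ // threshold +/- offset; a mask whose area barely changes under this
+ // perturbation is considered stable).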
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let intersections = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let unions = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let stability_score = intersections / unions;
+ if stability_score < STABILITY_SCORE_THRESHOLD {
+ continue;
+ }
+
+ // Threshold masks and calculate boxes.
+ let low_res_mask = low_res_mask
+ .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
+ .to_dtype(DType::U32)?;
+ let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
+ let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
+ let min_max_x = min_max_indexes(&low_res_mask_per_x);
+ let min_max_y = min_max_indexes(&low_res_mask_per_y);
+ if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
+ let bbox = crate::object_detection::Bbox {
+ xmin: x0 as f32,
+ ymin: y0 as f32,
+ xmax: x1 as f32,
+ ymax: y1 as f32,
+ confidence: *iou,
+ data: low_res_mask,
+ };
+ bboxes.push(bbox);
+ }
+ // TODO:
+ // Filter boxes that touch crop boundaries
+ // Compress to RLE.
+ }
+ }
+
+ let mut bboxes = vec![bboxes];
+ // Remove duplicates within this crop.
+ crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
+
+ // TODO: Return to the original image frame.
+ Ok(bboxes.remove(0))
+ }
+
+ pub fn generate_masks(
+ &self,
+ img: &Tensor,
+ points_per_side: usize,
+ crop_n_layer: usize,
+ crop_overlap_ratio: f64,
+ crop_n_points_downscale_factor: usize,
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
+ let (_c, h, w) = img.dims3()?;
+ let point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layer,
+ crop_n_points_downscale_factor,
+ );
+ let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+ let mut bboxes = Vec::new();
+ for crop_box in crop_boxes.into_iter() {
+ let layer_idx = crop_box.layer_idx;
+ let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
+ bboxes.extend(b)
+ }
+ // TODO: remove duplicates
+ Ok(bboxes)
+ }
+}
+
+// Returns the first and last index i for which values[i] > 0.
+fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
+ let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
+ for (i, &s) in values.iter().enumerate() {
+ if s == 0 {
+ continue;
+ }
+ min_i = usize::min(i, min_i);
+ max_i = usize::max(i, max_i);
+ }
+ if max_i < min_i {
+ None
+ } else {
+ Some((min_i, max_i))
+ }
+}
+
+#[derive(Debug)]
+struct CropBox {
+ x0: usize,
+ y0: usize,
+ x1: usize,
+ y1: usize,
+ layer_idx: usize,
+}
+
+impl CropBox {
+ fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
+ Self {
+ x0,
+ y0,
+ x1,
+ y1,
+ layer_idx,
+ }
+ }
+}
+
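+// Generates the crop boxes used by the automatic mask generator: layer 0 is the full
+// image, and layer i covers it with a 2^i x 2^i grid of overlapping crops.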
+fn generate_crop_boxes(
+ (im_h, im_w): (usize, usize),
+ n_layers: usize,
+ overlap_ratio: f64,
+) -> Vec<CropBox> {
+ fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
+ f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
+ }
+
+ let short_side = usize::min(im_h, im_w);
+
+ let mut crop_boxes = Vec::new();
+
+ // Original image.
+ crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
+
+ for layer_idx in 1..=n_layers {
+ let n_crops_per_side = 1 << layer_idx;
+ let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
+ let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+ let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+ for i_x in 0..n_crops_per_side {
+ let x0 = (crop_w - overlap) * i_x;
+ for i_y in 0..n_crops_per_side {
+ let y0 = (crop_h - overlap) * i_y;
+ let x1 = usize::min(im_w, x0 + crop_w);
+ let y1 = usize::min(im_h, y0 + crop_h);
+ crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+ }
+ }
+ }
+
+ crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
+ let offset = 1f64 / (2 * n_per_side) as f64;
+ let mut points = Vec::with_capacity(n_per_side * n_per_side);
+ for i_x in 0..n_per_side {
+ let x = offset + i_x as f64 / n_per_side as f64;
+ for i_y in 0..n_per_side {
+ let y = offset + i_y as f64 / n_per_side as f64;
+ points.push((x, y))
+ }
+ }
+ points
+}
+
+fn build_all_layer_point_grids(
+ n_per_side: usize,
+ n_layers: usize,
+ scale_per_layer: usize,
+) -> Vec<Vec<(f64, f64)>> {
+ let mut points_by_layer = Vec::with_capacity(n_layers + 1);
+ for i in 0..=n_layers {
+ let n_points = n_per_side / scale_per_layer.pow(i as u32);
+ points_by_layer.push(build_point_grid(n_points))
+ }
+ points_by_layer
+}
diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
new file mode 100644
index 00000000..cd2936ab
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -0,0 +1,633 @@
+// Adapted from:
+// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{Conv2dConfig, Module, VarBuilder};
+
+const MBCONV_EXPAND_RATIO: usize = 4;
+const MLP_RATIO: usize = 4;
+const LOCAL_CONV_SIZE: usize = 3;
+const IMG_SIZE: usize = 1024;
+const IN_CHANNELS: usize = 3;
+
+#[derive(Debug)]
+struct Conv2dBN {
+ c: candle_nn::Conv2d,
+ bn: candle_nn::BatchNorm,
+ span: tracing::Span,
+}
+
+impl Conv2dBN {
+ fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
+ let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
+ let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+ Ok(Self { c, bn, span })
+ }
+}
+
+impl Module for Conv2dBN {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.c)?.apply(&self.bn)
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl PatchEmbed {
+ fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ padding: 1,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
+ let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { conv1, conv2, span })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ conv3: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl MBConv {
+ fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
+ let hidden = in_ * expand_ratio;
+ let cfg2 = candle_nn::Conv2dConfig {
+ padding: 1,
+ groups: hidden,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
+ let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
+ let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
+ Ok(Self {
+ conv1,
+ conv2,
+ conv3,
+ span,
+ })
+ }
+}
+
+impl Module for MBConv {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let shortcut = xs;
+ let xs = xs
+ .apply(&self.conv1)?
+ .gelu()?
+ .apply(&self.conv2)?
+ .gelu()?
+ .apply(&self.conv3)?;
+ (xs + shortcut)?.gelu()
+ }
+}
+
+#[derive(Debug)]
+struct PatchMerging {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ conv3: Conv2dBN,
+ input_resolution: (usize, usize),
+ span: tracing::Span,
+}
+
+impl PatchMerging {
+ fn new(
+ input_resolution: (usize, usize),
+ dim: usize,
+ out: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
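+ // Matches the reference TinyViT code, where the stages with these output dims
+ // (across the 5m/11m/21m variants) keep stride 1 instead of downsampling.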
+ let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
+ let cfg2 = candle_nn::Conv2dConfig {
+ padding: 1,
+ stride,
+ groups: out,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
+ let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
+ let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
+ Ok(Self {
+ conv1,
+ conv2,
+ conv3,
+ input_resolution,
+ span,
+ })
+ }
+}
+
+impl Module for PatchMerging {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = if xs.rank() == 3 {
+ let (h, w) = self.input_resolution;
+ let b = xs.dim(0)?;
+ xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
+ } else {
+ xs.clone()
+ };
+ xs.apply(&self.conv1)?
+ .gelu()?
+ .apply(&self.conv2)?
+ .gelu()?
+ .apply(&self.conv3)?
+ .flatten_from(2)?
+ .transpose(1, 2)
+ }
+}
+
+#[derive(Debug)]
+struct ConvLayer {
+ blocks: Vec<MBConv>,
+ downsample: Option<PatchMerging>,
+ span: tracing::Span,
+}
+
+impl ConvLayer {
+ fn new(
+ dim: usize,
+ out: usize,
+ input_resolution: (usize, usize),
+ depth: usize,
+ downsample: bool,
+ conv_expand_ratio: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_b = vb.pp("blocks");
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
+ blocks.push(block)
+ }
+ let downsample = if downsample {
+ let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
+ Some(downsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+ Ok(Self {
+ blocks,
+ downsample,
+ span,
+ })
+ }
+}
+
+impl Module for ConvLayer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ match &self.downsample {
+ None => Ok(xs),
+ Some(downsample) => downsample.forward(&xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Mlp {
+ norm: candle_nn::LayerNorm,
+ fc1: super::Linear,
+ fc2: super::Linear,
+ span: tracing::Span,
+}
+
+impl Mlp {
+ fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
+ let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
+ let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
+ let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
+ Ok(Self {
+ norm,
+ fc1,
+ fc2,
+ span,
+ })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.norm)?
+ .apply(&self.fc1)?
+ .gelu()?
+ .apply(&self.fc2)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ norm: candle_nn::LayerNorm,
+ qkv: super::Linear,
+ proj: super::Linear,
+ ab: Tensor,
+ key_dim: usize,
+ num_heads: usize,
+ d: usize,
+ dh: usize,
+ scale: f64,
+ span: tracing::Span,
+ span_matmul: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl Attention {
+ fn new(
+ dim: usize,
+ key_dim: usize,
+ num_heads: usize,
+ attn_ratio: usize,
+ resolution: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let d = attn_ratio * key_dim;
+ let dh = d * num_heads;
+ let nh_kd = key_dim * num_heads;
+ let h = dh + nh_kd * 2;
+ let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
+ let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
+ let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
+
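+ // Learned attention biases, one per head and per absolute (dx, dy) offset between
+ // two positions in the window; they are expanded once here into a dense
+ // (num_heads, n, n) tensor where n = resolution.0 * resolution.1.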
+ let points = (0..resolution.0)
+ .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
+ .collect::<Vec<_>>();
+ let mut idxs = Vec::with_capacity(points.len() * points.len());
+ let mut attention_offsets = std::collections::HashMap::new();
+ for &(x1, y1) in points.iter() {
+ for &(x2, y2) in points.iter() {
+ let offset = ((x2 - x1).abs(), (y2 - y1).abs());
+ let l = attention_offsets.len();
+ let idx = attention_offsets.entry(offset).or_insert(l);
+ idxs.push(*idx as u32)
+ }
+ }
+ let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
+ let idxs = Tensor::new(idxs, attention_biases.device())?;
+ let ab =
+ attention_biases
+ .index_select(&idxs, 1)?
+ .reshape(((), points.len(), points.len()))?;
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
+ Ok(Self {
+ norm,
+ qkv,
+ proj,
+ ab,
+ key_dim,
+ num_heads,
+ d,
+ dh,
+ scale: 1f64 / (key_dim as f64).sqrt(),
+ span,
+ span_matmul,
+ span_softmax,
+ })
+ }
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b, n, _) = xs.dims3()?;
+ let xs = xs.apply(&self.norm)?;
+ let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
+ let q = qkv
+ .narrow(D::Minus1, 0, self.key_dim)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let k = qkv
+ .narrow(D::Minus1, self.key_dim, self.key_dim)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let v = qkv
+ .narrow(D::Minus1, 2 * self.key_dim, self.d)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ (q.matmul(&k.t()?)? * self.scale)?
+ };
+ let attn = attn.broadcast_add(&self.ab)?;
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ attn.matmul(&v)?
+ };
+ attn.transpose(1, 2)?
+ .reshape((b, n, self.dh))?
+ .apply(&self.proj)
+ }
+}
+
+#[derive(Debug)]
+struct TinyViTBlock {
+ attn: Attention,
+ local_conv: Conv2dBN,
+ mlp: Mlp,
+ window_size: usize,
+ input_resolution: (usize, usize),
+ span: tracing::Span,
+}
+
+impl TinyViTBlock {
+ fn new(
+ dim: usize,
+ input_resolution: (usize, usize),
+ num_heads: usize,
+ window_size: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let head_dim = dim / num_heads;
+ let attn = Attention::new(
+ dim,
+ head_dim,
+ num_heads,
+ 1,
+ (window_size, window_size),
+ vb.pp("attn"),
+ )?;
+ let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: LOCAL_CONV_SIZE / 2,
+ groups: dim,
+ ..Default::default()
+ };
+ let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ Ok(Self {
+ attn,
+ local_conv,
+ mlp,
+ window_size,
+ input_resolution,
+ span,
+ })
+ }
+}
+
+impl Module for TinyViTBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (h, w) = self.input_resolution;
+ let (b, l, c) = xs.dims3()?;
+ let res_x = xs;
+ let xs = if h == self.window_size && w == self.window_size {
+ self.attn.forward(xs)?
+ } else {
+ let xs = xs.reshape((b, h, w, c))?;
+ let pad_b = (self.window_size - h % self.window_size) % self.window_size;
+ let pad_r = (self.window_size - w % self.window_size) % self.window_size;
+
+ let xs = if pad_b > 0 {
+ xs.pad_with_zeros(1, 0, pad_b)?
+ } else {
+ xs
+ };
+ let xs = if pad_r > 0 {
+ xs.pad_with_zeros(2, 0, pad_r)?
+ } else {
+ xs
+ };
+ let (p_h, p_w) = (h + pad_b, w + pad_r);
+ let n_h = p_h / self.window_size;
+ let n_w = p_w / self.window_size;
+ let xs = xs
+ .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
+ .transpose(2, 3)?
+ .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
+ let xs = self.attn.forward(&xs)?;
+ let xs = xs
+ .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
+ .transpose(2, 3)?
+ .reshape((b, p_h, p_w, c))?;
+ let xs = if pad_r > 0 {
+ xs.i((.., .., ..w))?.contiguous()?
+ } else {
+ xs
+ };
+ let xs = if pad_b > 0 {
+ xs.i((.., ..h, ..))?.contiguous()?
+ } else {
+ xs
+ };
+ xs.reshape((b, l, c))?
+ };
+ let xs = (xs + res_x)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .reshape((b, c, h, w))?
+ .apply(&self.local_conv)?
+ .reshape((b, c, l))?
+ .transpose(1, 2)?;
+ &xs + self.mlp.forward(&xs)?
+ }
+}
+
+#[derive(Debug)]
+struct BasicLayer {
+ blocks: Vec<TinyViTBlock>,
+ downsample: Option<PatchMerging>,
+ span: tracing::Span,
+}
+
+impl BasicLayer {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ dim: usize,
+ input_resolution: (usize, usize),
+ depth: usize,
+ num_heads: usize,
+ window_size: usize,
+ downsample: bool,
+ out: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_b = vb.pp("blocks");
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let block = TinyViTBlock::new(
+ dim,
+ input_resolution,
+ num_heads,
+ window_size,
+ vb_b.pp(index),
+ )?;
+ blocks.push(block)
+ }
+ let downsample = if downsample {
+ let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
+ Some(downsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+ Ok(Self {
+ blocks,
+ downsample,
+ span,
+ })
+ }
+}
+
+impl Module for BasicLayer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ match &self.downsample {
+ None => Ok(xs),
+ Some(downsample) => downsample.forward(&xs),
+ }
+ }
+}
+
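+// TinyViT backbone as used by MobileSAM: a convolutional stem and MBConv stage
+// followed by windowed-attention stages. The classification head of the original
+// model is dropped and replaced by the same convolutional neck as ImageEncoderViT,
+// so the output is again a (b, 256, 64, 64) feature map.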
+#[derive(Debug)]
+pub struct TinyViT {
+ patch_embed: PatchEmbed,
+ layer0: ConvLayer,
+ layers: Vec<BasicLayer>,
+ // norm_head: candle_nn::LayerNorm,
+ // head: candle_nn::Linear,
+ neck_conv1: candle_nn::Conv2d,
+ neck_ln1: super::LayerNorm2d,
+ neck_conv2: candle_nn::Conv2d,
+ neck_ln2: super::LayerNorm2d,
+ span: tracing::Span,
+ span_neck: tracing::Span,
+}
+
+impl TinyViT {
+ pub fn new(
+ embed_dims: &[usize],
+ depths: &[usize],
+ num_heads: &[usize],
+ window_sizes: &[usize],
+ _num_classes: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
+ let patches_resolution = IMG_SIZE / 4;
+
+ let vb_l = vb.pp("layers");
+ let layer0 = ConvLayer::new(
+ /* dim */ embed_dims[0],
+ /* out */ embed_dims[1],
+ /* input_resolution */ (patches_resolution, patches_resolution),
+ /* depth */ depths[0],
+ /* downsample */ true,
+ /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
+ vb_l.pp(0),
+ )?;
+
+ let num_layers = embed_dims.len();
+ let mut layers = Vec::with_capacity(num_layers - 1);
+ for i_layer in 1..num_layers {
+ let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
+ let layer = BasicLayer::new(
+ /* dim */ embed_dims[i_layer],
+ /* input_resolution */ (patches_resolution, patches_resolution),
+ /* depth */ depths[i_layer],
+ /* num_heads */ num_heads[i_layer],
+ /* window_size */ window_sizes[i_layer],
+ /* downsample */ i_layer < num_layers - 1,
+ /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
+ vb_l.pp(i_layer),
+ )?;
+ layers.push(layer)
+ }
+
+ let last_embed_dim = embed_dims[embed_dims.len() - 1];
+ // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
+ // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
+ let neck_conv1 =
+ candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
+ let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
+ let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
+
+ let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+ let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
+ Ok(Self {
+ patch_embed,
+ layer0,
+ layers,
+ neck_conv1,
+ neck_ln1,
+ neck_conv2,
+ neck_ln2,
+ span,
+ span_neck,
+ })
+ }
+}
+
+impl Module for TinyViT {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.patch_embed.forward(xs)?;
+ let mut xs = self.layer0.forward(&xs)?;
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs)?
+ }
+ let (b, _, c) = xs.dims3()?;
+ let _enter = self.span_neck.enter();
+ xs.reshape((b, 64, 64, c))?
+ .permute((0, 3, 1, 2))?
+ .apply(&self.neck_conv1)?
+ .apply(&self.neck_ln1)?
+ .apply(&self.neck_conv2)?
+ .apply(&self.neck_ln2)
+ }
+}
+
+pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
+ TinyViT::new(
+ /* embed_dims */ &[64, 128, 160, 320],
+ /* depths */ &[2, 2, 6, 2],
+ /* num_heads */ &[2, 4, 5, 10],
+ /* window_sizes */ &[7, 7, 14, 7],
+ /* num_classes */ 1000,
+ vb,
+ )
+}
diff --git a/candle-transformers/src/models/segment_anything/transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs
new file mode 100644
index 00000000..80efb38c
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/transformer.rs
@@ -0,0 +1,221 @@
+use candle::{Result, Tensor};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
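+// Multi-head attention with separate q/k/v inputs; downsample_rate shrinks the
+// internal embedding dimension to keep the cross-attention between tokens and the
+// image embedding cheap.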
+#[derive(Debug)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ out_proj: Linear,
+ num_heads: usize,
+}
+
+impl Attention {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ downsample_rate: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let internal_dim = embedding_dim / downsample_rate;
+ let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
+ let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
+ let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ out_proj,
+ num_heads,
+ })
+ }
+
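+ // (B, N, C) -> (B, num_heads, N, C / num_heads) so attention runs independently per head.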
+ fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
+ let (b, n, c) = x.dims3()?;
+ x.reshape((b, n, self.num_heads, c / self.num_heads))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
+ let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
+ x.transpose(1, 2)?
+ .reshape((b, n_tokens, n_heads * c_per_head))
+ }
+
+ fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+ let q = self.q_proj.forward(&q.contiguous()?)?;
+ let k = self.k_proj.forward(&k.contiguous()?)?;
+ let v = self.v_proj.forward(&v.contiguous()?)?;
+
+ let q = self.separate_heads(&q)?;
+ let k = self.separate_heads(&k)?;
+ let v = self.separate_heads(&v)?;
+
+ let (_, _, _, c_per_head) = q.dims4()?;
+ let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
+ let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+
+ let out = attn.matmul(&v)?;
+ self.recombine_heads(&out)?.apply(&self.out_proj)
+ }
+}
+
+#[derive(Debug)]
+struct TwoWayAttentionBlock {
+ self_attn: Attention,
+ norm1: LayerNorm,
+ cross_attn_token_to_image: Attention,
+ norm2: LayerNorm,
+ mlp: super::MlpBlock,
+ norm3: LayerNorm,
+ norm4: LayerNorm,
+ cross_attn_image_to_token: Attention,
+ skip_first_layer_pe: bool,
+}
+
+impl TwoWayAttentionBlock {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ skip_first_layer_pe: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+ let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+ let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+ let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+ let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+ let cross_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_token_to_image"),
+ )?;
+ let cross_attn_image_to_token = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_image_to_token"),
+ )?;
+ let mlp = super::MlpBlock::new(
+ embedding_dim,
+ mlp_dim,
+ candle_nn::Activation::Relu,
+ vb.pp("mlp"),
+ )?;
+ Ok(Self {
+ self_attn,
+ norm1,
+ cross_attn_image_to_token,
+ norm2,
+ mlp,
+ norm3,
+ norm4,
+ cross_attn_token_to_image,
+ skip_first_layer_pe,
+ })
+ }
+
+ fn forward(
+ &self,
+ queries: &Tensor,
+ keys: &Tensor,
+ query_pe: &Tensor,
+ key_pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Self attention block
+ let queries = if self.skip_first_layer_pe {
+ self.self_attn.forward(queries, queries, queries)?
+ } else {
+ let q = (queries + query_pe)?;
+ let attn_out = self.self_attn.forward(&q, &q, queries)?;
+ (queries + attn_out)?
+ };
+ let queries = self.norm1.forward(&queries)?;
+
+ // Cross attention block, tokens attending to image embedding
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+ let queries = (&queries + attn_out)?;
+ let queries = self.norm2.forward(&queries)?;
+
+ // MLP block
+ let mlp_out = self.mlp.forward(&queries);
+ let queries = (queries + mlp_out)?;
+ let queries = self.norm3.forward(&queries)?;
+
+ // Cross attention block, image embedding attending to tokens
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+ let keys = (keys + attn_out)?;
+ let keys = self.norm4.forward(&keys)?;
+
+ Ok((queries, keys))
+ }
+}
+
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+ layers: Vec<TwoWayAttentionBlock>,
+ final_attn_token_to_image: Attention,
+ norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+ pub fn new(
+ depth: usize,
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_l = vb.pp("layers");
+ let mut layers = Vec::with_capacity(depth);
+ for i in 0..depth {
+ let layer =
+ TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+ layers.push(layer)
+ }
+ let final_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("final_attn_token_to_image"),
+ )?;
+ let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+ Ok(Self {
+ layers,
+ final_attn_token_to_image,
+ norm_final_attn,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embedding: &Tensor,
+ image_pe: &Tensor,
+ point_embedding: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+ let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+ let mut queries = point_embedding.clone();
+ let mut keys = image_embedding;
+
+ for layer in self.layers.iter() {
+ (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+ }
+
+ let q = (&queries + point_embedding)?;
+ let k = (&keys + image_pe)?;
+ let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+ let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+ Ok((queries, keys))
+ }
+}
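+
+// Usage sketch (hypothetical; dims follow the SAM mask decoder defaults): prompt tokens
+// attend to image features and back, returning updated queries and keys.
+//   let tf = TwoWayTransformer::new(2, 256, 8, 2048, vb.pp("transformer"))?;
+//   // image_embedding/image_pe: (B, 256, 64, 64), point_embedding: (B, n_tokens, 256)
+//   let (queries, keys) = tf.forward(&image_embedding, &image_pe, &point_embedding)?;
+//   // queries: (B, n_tokens, 256), keys: (B, 64*64, 256)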
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 1ae1bfc3..b3ea91f9 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -17,7 +17,7 @@ impl GeGlu {
}
}
-impl GeGlu {
+impl Module for GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
@@ -53,7 +53,7 @@ impl FeedForward {
}
}
-impl FeedForward {
+impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
@@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
}
#[derive(Debug)]
-struct CrossAttention {
+pub struct CrossAttention {
to_q: nn::Linear,
to_k: nn::Linear,
to_v: nn::Linear,
@@ -94,7 +94,7 @@ struct CrossAttention {
impl CrossAttention {
// Defaults should be heads = 8, dim_head = 64, context_dim = None
- fn new(
+ pub fn new(
vs: nn::VarBuilder,
query_dim: usize,
context_dim: Option<usize>,
@@ -198,14 +198,14 @@ impl CrossAttention {
let xs = query.matmul(&(key.t()? * self.scale)?)?;
let xs = {
let _enter = self.span_softmax.enter();
- nn::ops::softmax(&xs, D::Minus1)?
+ nn::ops::softmax_last_dim(&xs)?
};
xs.matmul(&value)?.to_dtype(in_dtype)?
};
self.reshape_batch_dim_to_heads(&xs)
}
- fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs).contiguous()?;
@@ -501,8 +501,10 @@ impl AttentionBlock {
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
+}
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for AttentionBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index d26c1c46..e7a20270 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -12,13 +12,15 @@ use candle_nn::Module;
pub enum Activation {
QuickGelu,
Gelu,
+ GeluErf,
}
-impl Activation {
+impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
+ Activation::GeluErf => xs.gelu_erf(),
}
}
}
@@ -99,6 +101,36 @@ impl Config {
activation: Activation::Gelu,
}
}
+
+ // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
+ pub fn wuerstchen() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ projection_dim: 1024,
+ activation: Activation::GeluErf,
+ }
+ }
+
+ // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+ pub fn wuerstchen_prior() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 512,
+ activation: Activation::GeluErf,
+ }
+ }
}
// CLIP Text Model
@@ -129,7 +161,7 @@ impl ClipTextEmbeddings {
}
}
-impl ClipTextEmbeddings {
+impl Module for ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
@@ -319,21 +351,39 @@ impl ClipTextTransformer {
}
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
- fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+ fn build_causal_attention_mask(
+ bsz: usize,
+ seq_len: usize,
+ mask_after: usize,
+ device: &Device,
+ ) -> Result<Tensor> {
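+ // On top of the causal constraint, keys with index j > mask_after are masked for
+ // every query position.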
let mask: Vec<_> = (0..seq_len)
- .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+ .flat_map(|i| {
+ (0..seq_len).map(move |j| {
+ if j > i || j > mask_after {
+ f32::MIN
+ } else {
+ 0.
+ }
+ })
+ })
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))
}
-}
-impl ClipTextTransformer {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
let (bsz, seq_len) = xs.dims2()?;
let xs = self.embeddings.forward(xs)?;
- let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
+ let causal_attention_mask =
+ Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
self.final_layer_norm.forward(&xs)
}
}
+
+impl Module for ClipTextTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward_with_mask(xs, usize::MAX)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index f2e021ce..916b7349 100644
--- a/candle-examples/examples/stable-diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,7 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
-use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@@ -67,14 +67,14 @@ impl DDIMScheduler {
.rev()
.collect();
let betas = match config.beta_schedule {
- BetaSchedule::ScaledLinear => crate::utils::linspace(
+ BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
- crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};
@@ -163,6 +163,17 @@ impl DDIMScheduler {
}
}
+ pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
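+ // Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.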
+ let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+ let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+ (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+ }
+
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs
new file mode 100644
index 00000000..d393f39a
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs
@@ -0,0 +1,205 @@
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum DDPMVarianceType {
+ FixedSmall,
+ FixedSmallLog,
+ FixedLarge,
+ FixedLargeLog,
+ Learned,
+}
+
+impl Default for DDPMVarianceType {
+ fn default() -> Self {
+ Self::FixedSmall
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DDPMSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// Option to clip the predicted sample between -1 and 1 for numerical stability.
+ pub clip_sample: bool,
+ /// Option to clip the variance used when adding noise to the denoised sample.
+ pub variance_type: DDPMVarianceType,
+ /// Prediction type of the scheduler function.
+ pub prediction_type: PredictionType,
+ /// Number of diffusion steps used to train the model.
+ pub train_timesteps: usize,
+}
+
+impl Default for DDPMSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085,
+ beta_end: 0.012,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ clip_sample: false,
+ variance_type: DDPMVarianceType::FixedSmall,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ }
+ }
+}
+
+pub struct DDPMScheduler {
+ alphas_cumprod: Vec<f64>,
+ init_noise_sigma: f64,
+ timesteps: Vec<usize>,
+ step_ratio: usize,
+ pub config: DDPMSchedulerConfig,
+}
+
+impl DDPMScheduler {
+ pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> {
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => super::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+
+ // min(train_timesteps, inference_steps)
+ // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187
+ let inference_steps = inference_steps.min(config.train_timesteps);
+ // Evenly space the scheduler's timesteps over the training schedule.
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();
+
+ Ok(Self {
+ alphas_cumprod,
+ init_noise_sigma: 1.0,
+ timesteps,
+ step_ratio,
+ config,
+ })
+ }
+
+ fn get_variance(&self, timestep: usize) -> f64 {
+ let prev_t = timestep as isize - self.step_ratio as isize;
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = if prev_t >= 0 {
+ self.alphas_cumprod[prev_t as usize]
+ } else {
+ 1.0
+ };
+ let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;
+
+ // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ // and sample from it to get previous sample
+ // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;
+
+ // retrieve variance
+ match self.config.variance_type {
+ DDPMVarianceType::FixedSmall => variance.max(1e-20),
+ // for rl-diffuser https://arxiv.org/abs/2205.09991
+ DDPMVarianceType::FixedSmallLog => {
+ let variance = variance.max(1e-20).ln();
+ (variance * 0.5).exp()
+ }
+ DDPMVarianceType::FixedLarge => current_beta_t,
+ DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
+ DDPMVarianceType::Learned => variance,
+ }
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
+ sample
+ }
+
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let prev_t = timestep as isize - self.step_ratio as isize;
+
+ // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272
+ // 1. compute alphas, betas
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = if prev_t >= 0 {
+ self.alphas_cumprod[prev_t as usize]
+ } else {
+ 1.0
+ };
+ let beta_prod_t = 1. - alpha_prod_t;
+ let beta_prod_t_prev = 1. - alpha_prod_t_prev;
+ let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
+ let current_beta_t = 1. - current_alpha_t;
+
+ // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15)
+ let mut pred_original_sample = match self.config.prediction_type {
+ PredictionType::Epsilon => {
+ ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())?
+ }
+ PredictionType::Sample => model_output.clone(),
+ PredictionType::VPrediction => {
+ ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())?
+ }
+ };
+
+ // 3. clip predicted x_0
+ if self.config.clip_sample {
+ pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?;
+ }
+
+ // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
+ let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;
+
+ // 5. Compute predicted previous sample µ_t
+ // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)?
+ + sample * current_sample_coeff)?;
+
+ // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305
+ // 6. Add noise
+ let mut variance = model_output.zeros_like()?;
+ if timestep > 0 {
+ let variance_noise = model_output.randn_like(0., 1.)?;
+ if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
+ variance = (variance_noise * self.get_variance(timestep))?;
+ } else {
+ variance = (variance_noise * self.get_variance(timestep).sqrt())?;
+ }
+ }
+ &pred_prev_sample + variance
+ }
+
+ pub fn add_noise(
+ &self,
+ original_samples: &Tensor,
+ noise: Tensor,
+ timestep: usize,
+ ) -> Result<Tensor> {
+ (original_samples * self.alphas_cumprod[timestep].sqrt())?
+ + noise * (1. - self.alphas_cumprod[timestep]).sqrt()
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
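+
+// Usage sketch (hypothetical `unet` predicting epsilon): start from scaled noise and run
+// the reverse process over the scheduler's timesteps.
+//   let scheduler = DDPMScheduler::new(60, DDPMSchedulerConfig::default())?;
+//   let mut latents = (noise * scheduler.init_noise_sigma())?;
+//   for &t in scheduler.timesteps() {
+//       let noise_pred = unet.forward(&latents, t)?;
+//       latents = scheduler.step(&noise_pred, t, &latents)?;
+//   }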
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-transformers/src/models/stable_diffusion/embeddings.rs
index 97bc61f1..0de5f9a7 100644
--- a/candle-examples/examples/stable-diffusion/embeddings.rs
+++ b/candle-transformers/src/models/stable_diffusion/embeddings.rs
@@ -17,8 +17,8 @@ impl TimestepEmbedding {
}
}
-impl TimestepEmbedding {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for TimestepEmbedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
self.linear_2.forward(&xs)
}
@@ -41,8 +41,8 @@ impl Timesteps {
}
}
-impl Timesteps {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for Timesteps {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
* -f64::ln(10000.))?;
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index cffc00d8..c6f1b904 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -1,5 +1,15 @@
-use crate::schedulers::PredictionType;
-use crate::{clip, ddim, unet_2d, vae};
+pub mod attention;
+pub mod clip;
+pub mod ddim;
+pub mod ddpm;
+pub mod embeddings;
+pub mod resnet;
+pub mod schedulers;
+pub mod unet_2d;
+pub mod unet_2d_blocks;
+pub mod utils;
+pub mod vae;
+
use candle::{DType, Device, Result};
use candle_nn as nn;
@@ -80,7 +90,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
- prediction_type: PredictionType,
+ prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
@@ -154,7 +164,7 @@ impl StableDiffusionConfig {
sliced_attention_size,
height,
width,
- PredictionType::VPrediction,
+ schedulers::PredictionType::VPrediction,
)
}
@@ -162,7 +172,7 @@ impl StableDiffusionConfig {
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
- prediction_type: PredictionType,
+ prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
@@ -235,7 +245,7 @@ impl StableDiffusionConfig {
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
- PredictionType::Epsilon,
+ schedulers::PredictionType::Epsilon,
)
}
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs
index 4cfd386d..0d818115 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-transformers/src/models/stable_diffusion/resnet.rs
@@ -4,7 +4,7 @@
//!
//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
//! https://arxiv.org/abs/1512.03385
-use crate::utils::{conv2d, Conv2d};
+use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index 3f6a1d72..3f6a1d72 100644
--- a/candle-examples/examples/stable-diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
index 81bd9547..a3ed136e 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -2,9 +2,9 @@
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
-use crate::embeddings::{TimestepEmbedding, Timesteps};
-use crate::unet_2d_blocks::*;
-use crate::utils::{conv2d, Conv2d};
+use super::embeddings::{TimestepEmbedding, Timesteps};
+use super::unet_2d_blocks::*;
+use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
index 26a1035b..29510cef 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -1,11 +1,11 @@
//! 2D UNet Building Blocks
//!
-use crate::attention::{
+use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
-use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
-use crate::utils::{conv2d, Conv2d};
-use candle::{Result, Tensor, D};
+use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use super::utils::{conv2d, Conv2d};
+use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@@ -43,7 +43,7 @@ impl Downsample2D {
}
}
-impl Downsample2D {
+impl Module for Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
@@ -172,8 +172,8 @@ impl DownEncoderBlock2D {
}
}
-impl DownEncoderBlock2D {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for DownEncoderBlock2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
@@ -256,8 +256,8 @@ impl UpDecoderBlock2D {
}
}
-impl UpDecoderBlock2D {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for UpDecoderBlock2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
@@ -754,6 +754,7 @@ impl UpBlock2D {
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
@@ -855,6 +856,7 @@ impl CrossAttnUpBlock2D {
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index c62f17af..c62f17af 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs
index aa8e13a0..21709afe 100644
--- a/candle-examples/examples/stable-diffusion/vae.rs
+++ b/candle-transformers/src/models/stable_diffusion/vae.rs
@@ -4,7 +4,7 @@
//! Auto-encoder models compress their input to a usually smaller latent space
//! before expanding it back to its original shape. This results in the latent values
//! compressing the original information.
-use crate::unet_2d_blocks::{
+use super::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};
@@ -132,14 +132,15 @@ impl Encoder {
impl Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.conv_in.forward(xs)?;
+ let mut xs = xs.apply(&self.conv_in)?;
for down_block in self.down_blocks.iter() {
- xs = down_block.forward(&xs)?
+ xs = xs.apply(down_block)?
}
- let xs = self.mid_block.forward(&xs, None)?;
- let xs = self.conv_norm_out.forward(&xs)?;
- let xs = nn::ops::silu(&xs)?;
- self.conv_out.forward(&xs)
+ let xs = self
+ .mid_block
+ .forward(&xs, None)?
+ .apply(&self.conv_norm_out)?;
+ nn::ops::silu(&xs)?.apply(&self.conv_out)
}
}
@@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution {
}
pub fn sample(&self) -> Result<Tensor> {
- let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+ let sample = self.mean.randn_like(0., 1.);
&self.mean + &self.std * sample
}
}
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
new file mode 100644
index 00000000..539ae89b
--- /dev/null
+++ b/candle-transformers/src/models/t5.rs
@@ -0,0 +1,841 @@
+// T5 Text Model
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+use serde::Deserialize;
+use std::sync::Arc;
+
+#[derive(Debug)]
+struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::embedding(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
+struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+fn default_relative_attention_max_distance() -> usize {
+ 128
+}
+
+fn default_is_decoder() -> bool {
+ false
+}
+
+fn default_use_cache() -> bool {
+ true
+}
+
+fn default_tie_word_embeddings() -> bool {
+ true
+}
+
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ Tensor::from_slice(&mask, (size, size), device)
+}
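+
+// e.g. get_mask(3, &device) yields (1 marks a masked position):
+//   [[0, 1, 1],
+//    [0, 0, 1],
+//    [0, 0, 0]]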
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ d_model: usize,
+ d_kv: usize,
+ d_ff: usize,
+ num_layers: usize,
+ num_decoder_layers: Option<usize>,
+ num_heads: usize,
+ relative_attention_num_buckets: usize,
+ #[serde(default = "default_relative_attention_max_distance")]
+ relative_attention_max_distance: usize,
+ dropout_rate: f64,
+ layer_norm_epsilon: f64,
+ initializer_factor: f64,
+ #[serde(default)]
+ feed_forward_proj: Activation,
+ #[serde(default = "default_tie_word_embeddings")]
+ tie_word_embeddings: bool,
+ #[serde(default = "default_is_decoder")]
+ is_decoder: bool,
+ is_encoder_decoder: bool,
+ #[serde(default = "default_use_cache")]
+ pub use_cache: bool,
+ pub pad_token_id: usize,
+ pub eos_token_id: usize,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 32128,
+ d_model: 512,
+ d_kv: 64,
+ d_ff: 2048,
+ num_layers: 6,
+ num_decoder_layers: None,
+ num_heads: 8,
+ relative_attention_num_buckets: 32,
+ relative_attention_max_distance: 128,
+ dropout_rate: 0.1,
+ layer_norm_epsilon: 1e-6,
+ initializer_factor: 1.0,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ use_cache: true,
+ pad_token_id: 0,
+ eos_token_id: 1,
+ }
+ }
+}
+
+impl Config {
+ // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
+ pub fn musicgen_small() -> Self {
+ Self {
+ d_ff: 3072,
+ d_kv: 64,
+ d_model: 768,
+ dropout_rate: 0.1,
+ eos_token_id: 1,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ initializer_factor: 1.0,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ layer_norm_epsilon: 1e-6,
+ num_decoder_layers: Some(12),
+ num_heads: 12,
+ num_layers: 12,
+ pad_token_id: 0,
+ relative_attention_max_distance: 128,
+ relative_attention_num_buckets: 32,
+ use_cache: true,
+ vocab_size: 32128,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerNorm {
+ weight: Tensor,
+ variance_epsilon: f64,
+ span: tracing::Span,
+}
+
+impl T5LayerNorm {
+ fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(h, "weight")?;
+ Ok(Self {
+ weight,
+ variance_epsilon: eps,
+ span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+ })
+ }
+}
+
+impl Module for T5LayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let dtype = xs.dtype();
+ let xs_f32 = xs.to_dtype(DType::F32)?;
+ // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
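+ // Note: this is RMS-style normalization (no mean subtraction, no bias), matching the
+ // T5 reference implementation.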
+ let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+ let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+ let xs = xs.to_dtype(dtype)?;
+ let xs = xs.broadcast_mul(&self.weight)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseActDense {
+ wi: Linear,
+ wo: Linear,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi,
+ wo,
+ act: Activation::Relu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.wi.forward(xs)?;
+ let xs = self.act.forward(&xs)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+ wi_0: Linear,
+ wi_1: Linear,
+ wo: Linear,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseGatedActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi_0,
+ wi_1,
+ wo,
+ act: Activation::NewGelu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseGatedActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+ let hidden_linear = self.wi_1.forward(xs)?;
+ let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerFF {
+ dense_act: Option<T5DenseActDense>,
+ gated_dense_act: Option<T5DenseGatedActDense>,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerFF {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+ (
+ None,
+ Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ )
+ } else {
+ (
+ Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ None,
+ )
+ };
+ Ok(Self {
+ dense_act,
+ gated_dense_act,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
+ })
+ }
+}
+
+impl Module for T5LayerFF {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let ys = self.layer_norm.forward(xs)?;
+ let ys = match &self.dense_act {
+ Some(dense_act) => dense_act.forward(&ys)?,
+ None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+ };
+ let xs = (xs + ys)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5Attention {
+ q: Linear,
+ k: Linear,
+ v: Linear,
+ o: Linear,
+ n_heads: usize,
+ d_kv: usize,
+ relative_attention_bias: Option<Embedding>,
+ relative_attention_num_buckets: usize,
+ relative_attention_max_distance: usize,
+ inner_dim: usize,
+ use_cache: bool,
+ kv_cache: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_cache: tracing::Span,
+ span_mm: tracing::Span,
+ span_sm: tracing::Span,
+}
+
+impl T5Attention {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let inner_dim = cfg.num_heads * cfg.d_kv;
+ let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+ let relative_attention_bias = if has_relative_attention_bias {
+ let emb = Embedding::new(
+ cfg.relative_attention_num_buckets,
+ cfg.num_heads,
+ vb.pp("relative_attention_bias"),
+ )?;
+ Some(emb)
+ } else {
+ None
+ };
+ Ok(Self {
+ q,
+ k,
+ v,
+ o,
+ n_heads: cfg.num_heads,
+ d_kv: cfg.d_kv,
+ relative_attention_bias,
+ relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+ relative_attention_max_distance: cfg.relative_attention_max_distance,
+ inner_dim,
+ use_cache: cfg.use_cache && decoder,
+ kv_cache: None,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+ span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
+ span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ // Performs self-attention (if key_value_states is None) or cross-attention
+ // over the source sentence (provided by key_value_states).
+ let _enter = self.span.enter();
+ let kv_input = match key_value_states {
+ None => xs,
+ Some(key_value_states) => key_value_states,
+ };
+ let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_len = kv_input.dim(1)?;
+ let q = self.q.forward(xs)?;
+ let k = self.k.forward(kv_input)?;
+ let v = self.v.forward(kv_input)?;
+ let q = q
+ .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut k = k
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut v = v
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ if self.use_cache {
+ let _enter = self.span_cache.enter();
+ if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+ k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+ };
+ // TODO: Use flash_attn.
+ let scores = {
+ let _enter = self.span_mm.enter();
+ q.matmul(&k.t()?)?
+ };
+ let scores = match mask {
+ None => scores,
+ Some(mask) => masked_fill(
+ &scores,
+ &mask
+ .unsqueeze(0)?
+ .unsqueeze(0)?
+ .repeat((b_sz, self.n_heads))?,
+ f32::NEG_INFINITY,
+ )?,
+ };
+
+ let (scores, position_bias) = match position_bias {
+ Some(position_bias) => (
+ scores.broadcast_add(position_bias)?,
+ Some(position_bias.clone()),
+ ),
+ None => match &self.relative_attention_bias {
+ None => (scores, None),
+ Some(relative_attention_bias) => {
+ // This only handles the bidirectional case.
+ let kv_len = k.dim(2)?;
+ let (q_start, q_end) = match self.use_cache {
+ true => ((kv_len - q_len) as u32, kv_len as u32),
+ false => (0_u32, kv_len as u32),
+ };
+ let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+ let max_exact = num_buckets / 2;
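+ // Offsets below max_exact each get their own bucket; larger offsets share
+ // log-spaced buckets up to relative_attention_max_distance (T5 position bucketing).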
+ let relative_position = (q_start..q_end)
+ .map(|i| {
+ (0..kv_len as u32)
+ .map(|j| {
+ if i < j {
+ if j - i < max_exact {
+ j - i + num_buckets
+ } else {
+ let b = f32::log(
+ (j - i) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ u32::min(
+ max_exact + num_buckets + b as u32,
+ self.relative_attention_num_buckets as u32 - 1,
+ )
+ }
+ } else if i - j < max_exact {
+ i - j
+ } else {
+ let b = f32::log(
+ (i - j) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ max_exact + b as u32
+ }
+ })
+ .collect::<Vec<u32>>()
+ })
+ .collect::<Vec<Vec<_>>>();
+ let relative_buckets = Tensor::new(relative_position, q.device())?;
+ let position_bias = relative_attention_bias
+ .forward(&relative_buckets)?
+ .permute((2, 0, 1))?
+ .unsqueeze(0)?;
+ (scores.broadcast_add(&position_bias)?, Some(position_bias))
+ // TODO: position_bias_masked?
+ }
+ },
+ };
+
+ let attn_weights = {
+ let _enter = self.span_sm.enter();
+ candle_nn::ops::softmax(&scores, D::Minus1)?
+ };
+ let attn_output = attn_weights.matmul(&v)?;
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.inner_dim))?;
+ let attn_output = self.o.forward(&attn_output)?;
+ Ok((attn_output, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerSelfAttention {
+ self_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerSelfAttention {
+ fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ self_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_xs = self.layer_norm.forward(xs)?;
+ let (ys, position_bias) =
+ self.self_attention
+ .forward(&normed_xs, position_bias, None, mask)?;
+ let ys = (xs + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerCrossAttention {
+ cross_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerCrossAttention {
+ fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ cross_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ hidden_states: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: &Tensor,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+ let (ys, position_bias) = self.cross_attention.forward(
+ &normed_hidden_states,
+ position_bias,
+ Some(key_value_states),
+ None,
+ )?;
+ let ys = (hidden_states + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.cross_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5Block {
+ self_attn: T5LayerSelfAttention,
+ cross_attn: Option<T5LayerCrossAttention>,
+ ff: T5LayerFF,
+ span: tracing::Span,
+}
+
+impl T5Block {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let vb = vb.pp("layer");
+ let self_attn =
+ T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
+ let cross_attn = if cfg.is_decoder {
+ Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
+ } else {
+ None
+ };
+ let ff_i = if cross_attn.is_some() { 2 } else { 1 };
+ let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+ Ok(Self {
+ self_attn,
+ cross_attn,
+ ff,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ // TODO: Cache masks
+ let mask = match self.cross_attn.is_some() {
+ true => {
+ let mask_len = xs.dim(1)?;
+ // If the input seq length is 1, no mask is needed; this also helps avoid shape
+ // issues when using the KV cache in the decoder.
+ if mask_len <= 1 {
+ None
+ } else {
+ Some(get_mask(mask_len, xs.device())?)
+ }
+ }
+ false => None,
+ };
+ let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
+ // TODO: clamp for f16?
+ if let Some(cross_attn) = &mut self.cross_attn {
+ (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
+ // TODO: clamp for f16?
+ }
+ let xs = self.ff.forward(&xs)?;
+ // TODO: clamp for f16?
+ Ok((xs, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache();
+ self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+ }
+}
+
+#[derive(Debug)]
+struct T5Stack {
+ block: Vec<T5Block>,
+ shared: Arc<Embedding>,
+ final_layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5Stack {
+ fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+ let block = (0..cfg.num_layers)
+ .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let final_layer_norm = T5LayerNorm::load(
+ cfg.d_model,
+ cfg.layer_norm_epsilon,
+ vb.pp("final_layer_norm"),
+ )?;
+ Ok(Self {
+ block,
+ shared: shared.clone(),
+ final_layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "stack"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let mut hidden_states = input_embeds;
+ let mut position_bias = None;
+ for block in self.block.iter_mut() {
+ (hidden_states, position_bias) = block.forward(
+ &hidden_states,
+ position_bias.as_ref(),
+ encoder_hidden_states,
+ )?
+ }
+ self.final_layer_norm.forward(&hidden_states)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
+}
+
+#[derive(Debug)]
+pub struct T5EncoderModel {
+ encoder: T5Stack,
+ device: Device,
+ span: tracing::Span,
+}
+
+impl T5EncoderModel {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
+ Ok(Self {
+ encoder,
+ device: vb.device().clone(),
+ span: tracing::span!(tracing::Level::TRACE, "encoder"),
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+ encoder: T5Stack,
+ decoder: T5Stack,
+ d_model: usize,
+ tie_word_embeddings: bool,
+ lm_head: Option<Linear>,
+ shared: Arc<Embedding>,
+ device: Device,
+ span_decode: tracing::Span,
+ span_decode_head: tracing::Span,
+}
+
+impl T5ForConditionalGeneration {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ assert!(cfg.is_encoder_decoder);
+ let d_model = cfg.d_model;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+
+ let mut encoder_cfg = cfg.clone();
+ encoder_cfg.is_decoder = false;
+ encoder_cfg.use_cache = false;
+ encoder_cfg.is_encoder_decoder = false;
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+ let mut decoder_cfg = cfg.clone();
+ decoder_cfg.is_decoder = true;
+ decoder_cfg.is_encoder_decoder = false;
+ decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+ let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+ let tie_word_embeddings = cfg.tie_word_embeddings;
+ let lm_head = if tie_word_embeddings {
+ None
+ } else {
+ Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+ };
+
+ Ok(Self {
+ encoder,
+ decoder,
+ d_model,
+ tie_word_embeddings,
+ lm_head,
+ shared,
+ device: vb.device().clone(),
+ span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+ span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
+ })
+ }
+
+ pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn decode(
+ &mut self,
+ decoder_input_ids: &Tensor,
+ encoder_output: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span_decode.enter();
+ let decoder_output = self
+ .decoder
+ .forward(decoder_input_ids, Some(encoder_output))?;
+
+ let scaling_factor = if self.tie_word_embeddings {
+ // Rescale output before projecting on vocab
+ // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ (self.d_model as f64).sqrt()
+ } else {
+ 1.0
+ };
+ let sequence_output = ((decoder_output
+ .narrow(1, decoder_output.dim(1)? - 1, 1)?
+ .squeeze(1)?)
+ * scaling_factor)?;
+ let output = {
+ let _enter = self.span_decode_head.enter();
+ match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ }
+ };
+
+ // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+ Ok(output)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+ let encoder_output = self.encode(input_ids)?;
+ self.decode(decoder_input_ids, &encoder_output)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache();
+ self.decoder.clear_kv_cache();
+ }
+}
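+
+// Usage sketch (hypothetical, greedy decoding; `argmax` is a stand-in helper): encode once,
+// then decode token by token, relying on the decoder KV cache.
+//   let mut model = T5ForConditionalGeneration::load(vb, &cfg)?;
+//   let encoder_output = model.encode(&input_ids)?;
+//   let mut tokens = vec![cfg.pad_token_id as u32];
+//   loop {
+//       let input = Tensor::new(&tokens[tokens.len() - 1..], model.device())?.unsqueeze(0)?;
+//       let logits = model.decode(&input, &encoder_output)?;
+//       let next = argmax(&logits)?;
+//       tokens.push(next);
+//       if next == cfg.eos_token_id as u32 { break }
+//   }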
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
index 2ceed065..4e01de32 100644
--- a/candle-examples/examples/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel
}
-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
- samples: &[T],
- filters: &[T],
-) -> anyhow::Result<Vec<T>> {
- let mel = log_mel_spectrogram_(
+pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
+ log_mel_spectrogram_(
samples,
filters,
super::N_FFT,
super::HOP_LENGTH,
super::N_MELS,
false,
- );
- Ok(mel)
+ )
}
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
new file mode 100644
index 00000000..7dc8107b
--- /dev/null
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -0,0 +1,26 @@
+pub mod audio;
+pub mod model;
+
+pub const DTYPE: candle::DType = candle::DType::F32;
+
+// Audio parameters.
+pub const SAMPLE_RATE: usize = 16000;
+pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
+pub const HOP_LENGTH: usize = 160;
+pub const CHUNK_LENGTH: usize = 30;
+pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
+pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
+
+pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
+pub const LOGPROB_THRESHOLD: f64 = -1.0;
+pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
+pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
+
+// Tokenizer dependent bits.
+pub const SOT_TOKEN: &str = "<|startoftranscript|>";
+pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+pub const TRANSLATE_TOKEN: &str = "<|translate|>";
+pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
+pub const EOT_TOKEN: &str = "<|endoftext|>";
+pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
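+
+// Usage sketch (hypothetical `pcm` samples at SAMPLE_RATE and a `mel_filters` bank of
+// N_MELS filters): build the (1, N_MELS, n_frames) input tensor.
+//   let mel = audio::pcm_to_mel(&pcm, &mel_filters);
+//   let n_frames = mel.len() / N_MELS;
+//   let mel = Tensor::from_vec(mel, (1, N_MELS, n_frames), &device)?;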
diff --git a/candle-examples/examples/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index e58ab2ca..d2eda796 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,5 +1,5 @@
use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -166,7 +166,7 @@ impl MultiHeadAttention {
}
let w = {
let _enter = self.softmax_span.enter();
- softmax(&qk, D::Minus1)?
+ candle_nn::ops::softmax_last_dim(&qk)?
};
let wv = {
let _enter = self.matmul_span.enter();
diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs
new file mode 100644
index 00000000..0b90cb9d
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs
@@ -0,0 +1,118 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+// A simplified version of:
+// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38
+#[derive(Debug)]
+pub struct Attention {
+ to_q: Linear,
+ to_k: Linear,
+ to_v: Linear,
+ to_out: Linear,
+ heads: usize,
+ scale: f64,
+ use_flash_attn: bool,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+impl Attention {
+ pub fn new(
+ query_dim: usize,
+ heads: usize,
+ dim_head: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
+ let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
+ let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
+ let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ scale,
+ heads,
+ use_flash_attn,
+ })
+ }
+
+ fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
+ .permute((0, 2, 1, 3))?
+ .reshape((b_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
+ .permute((0, 2, 1, 3))?
+ .reshape((b_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
+ let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
+ candle_nn::ops::softmax_last_dim(&attn_probs)
+ }
+
+ pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ let (b_size, channel, h, w) = xs.dims4()?;
+ let xs = xs.reshape((b_size, channel, h * w))?.t()?;
+
+ let query = self.to_q.forward(&xs)?;
+ let key = self.to_k.forward(encoder_hidden_states)?;
+ let value = self.to_v.forward(encoder_hidden_states)?;
+
+ let query = self.head_to_batch_dim(&query)?;
+ let key = self.head_to_batch_dim(&key)?;
+ let value = self.head_to_batch_dim(&value)?;
+
+ let xs = if self.use_flash_attn {
+ let init_dtype = query.dtype();
+ let q = query
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let k = key
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let v = value
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ flash_attn(&q, &k, &v, self.scale as f32, false)?
+ .transpose(1, 2)?
+ .squeeze(0)?
+ .to_dtype(init_dtype)?
+ } else {
+ let attn_prs = self.get_attention_scores(&query, &key)?;
+ attn_prs.matmul(&value)?
+ };
+ let xs = self.batch_to_head_dim(&xs)?;
+
+ self.to_out
+ .forward(&xs)?
+ .t()?
+ .reshape((b_size, channel, h, w))
+ }
+}
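+
+// Usage sketch (hypothetical dims): spatial features attend over text-encoder states;
+// the encoder hidden size must match `query_dim`.
+//   let attn = Attention::new(1024, 16, 64, /* use_flash_attn */ false, vb.pp("attention"))?;
+//   let xs = Tensor::zeros((1, 1024, 12, 12), DType::F32, &device)?;
+//   let ctx = Tensor::zeros((1, 77, 1024), DType::F32, &device)?;
+//   let ys = attn.forward(&xs, &ctx)?; // same shape as xs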
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
new file mode 100644
index 00000000..c89ec919
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -0,0 +1,203 @@
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
+#[derive(Debug)]
+pub struct WLayerNorm {
+ eps: f64,
+}
+
+impl WLayerNorm {
+ pub fn new(_size: usize) -> Result<Self> {
+ Ok(Self { eps: 1e-6 })
+ }
+}
+
+impl Module for WLayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.permute((0, 2, 3, 1))?;
+
+ let x_dtype = xs.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+
+ let hidden_size = xs.dim(D::Minus1)?;
+ let xs = xs.to_dtype(internal_dtype)?;
+ let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let xs = xs.broadcast_sub(&mean_x)?;
+ let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)?
+ .permute((0, 3, 1, 2))
+ }
+}
+
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+ eps: f64,
+}
+
+impl LayerNormNoWeights {
+ pub fn new(_size: usize) -> Result<Self> {
+ Ok(Self { eps: 1e-6 })
+ }
+}
+
+impl Module for LayerNormNoWeights {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let x_dtype = xs.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let hidden_size = xs.dim(D::Minus1)?;
+ let xs = xs.to_dtype(internal_dtype)?;
+ let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let xs = xs.broadcast_sub(&mean_x)?;
+ let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)
+ }
+}
+
+#[derive(Debug)]
+pub struct TimestepBlock {
+ mapper: candle_nn::Linear,
+}
+
+impl TimestepBlock {
+ pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
+ let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
+ Ok(Self { mapper })
+ }
+
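+ /// FiLM-style conditioning: the timestep embedding is mapped to a per-channel
+ /// `(a, b)` pair and applied as `xs * (1 + a) + b`.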
+ pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
+ let ab = self
+ .mapper
+ .forward(t)?
+ .unsqueeze(2)?
+ .unsqueeze(3)?
+ .chunk(2, 1)?;
+ xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
+ }
+}
+
+#[derive(Debug)]
+pub struct GlobalResponseNorm {
+ gamma: Tensor,
+ beta: Tensor,
+}
+
+impl GlobalResponseNorm {
+ pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+ let beta = vb.get((1, 1, 1, dim), "beta")?;
+ Ok(Self { gamma, beta })
+ }
+}
+
+impl Module for GlobalResponseNorm {
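+ // Global Response Norm (as in ConvNeXt V2): rescale each channel by its spatial
+ // L2 norm relative to the mean norm across channels, with learned gamma/beta
+ // and a residual connection.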
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ let stand_div_norm =
+ agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
+ xs.broadcast_mul(&stand_div_norm)?
+ .broadcast_mul(&self.gamma)?
+ .broadcast_add(&self.beta)?
+ + xs
+ }
+}
+
+#[derive(Debug)]
+pub struct ResBlock {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlock {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ padding: ksize / 2,
+ groups: c,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c)?;
+ let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .apply(&self.depthwise)?
+ .apply(&self.norm)?
+ .permute((0, 2, 3, 1))?;
+ let xs = xs
+ .apply(&self.channelwise_lin1)?
+ .gelu_erf()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
+
+#[derive(Debug)]
+pub struct AttnBlock {
+ self_attn: bool,
+ norm: WLayerNorm,
+ attention: Attention,
+ kv_mapper_lin: candle_nn::Linear,
+}
+
+impl AttnBlock {
+ pub fn new(
+ c: usize,
+ c_cond: usize,
+ nhead: usize,
+ self_attn: bool,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm = WLayerNorm::new(c)?;
+ let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
+ let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
+ Ok(Self {
+ self_attn,
+ norm,
+ attention,
+ kv_mapper_lin,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
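+ // Project the conditioning into the attention width; for self-attention the
+ // flattened, normalized input is prepended to the key/value sequence.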
+ let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
+ let norm_xs = self.norm.forward(xs)?;
+ let kv = if self.self_attn {
+ let (b_size, channel, _, _) = xs.dims4()?;
+ let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
+ Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
+ } else {
+ kv
+ };
+ xs + self.attention.forward(&norm_xs, &kv)
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
new file mode 100644
index 00000000..9e69b868
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -0,0 +1,103 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone)]
+pub struct DDPMWSchedulerConfig {
+ scaler: f64,
+ s: f64,
+}
+
+impl Default for DDPMWSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ scaler: 1f64,
+ s: 0.008f64,
+ }
+ }
+}
+
+pub struct DDPMWScheduler {
+ init_alpha_cumprod: f64,
+ init_noise_sigma: f64,
+ timesteps: Vec<f64>,
+ pub config: DDPMWSchedulerConfig,
+}
+
+impl DDPMWScheduler {
+ pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
+ // Alpha-bar at t=0 for the squared-cosine schedule; note the 0.5 factor,
+ // matching `alpha_cumprod` below, which normalizes by this value.
+ let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI * 0.5)
+ .cos()
+ .powi(2);
+ let timesteps = (0..=inference_steps)
+ .map(|i| 1. - i as f64 / inference_steps as f64)
+ .collect::<Vec<_>>();
+ Ok(Self {
+ init_alpha_cumprod,
+ init_noise_sigma: 1.0,
+ timesteps,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[f64] {
+ &self.timesteps
+ }
+
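+ /// Squared-cosine alpha-bar schedule with offset `s`, optionally warped by
+ /// `scaler`, normalized by its value at t=0 and clamped away from 0 and 1.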
+ fn alpha_cumprod(&self, t: f64) -> f64 {
+ let scaler = self.config.scaler;
+ let s = self.config.s;
+ let t = if scaler > 1. {
+ 1. - (1. - t).powf(scaler)
+ } else if scaler < 1. {
+ t.powf(scaler)
+ } else {
+ t
+ };
+ let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
+ .cos()
+ .powi(2)
+ / self.init_alpha_cumprod;
+ alpha_cumprod.clamp(0.0001, 0.9999)
+ }
+
+ fn previous_timestep(&self, ts: f64) -> f64 {
+ let index = self
+ .timesteps
+ .iter()
+ .enumerate()
+ .map(|(idx, v)| (idx, (v - ts).abs()))
+ .min_by(|x, y| x.1.total_cmp(&y.1))
+ .unwrap()
+ .0;
+ self.timesteps[index + 1]
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
+ sample
+ }
+
+ pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
+ let prev_t = self.previous_timestep(ts);
+
+ let alpha_cumprod = self.alpha_cumprod(ts);
+ let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
+ let alpha = alpha_cumprod / alpha_cumprod_prev;
+
+ let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
+ let mu = (mu * (1. / alpha).sqrt())?;
+
+ let std_noise = mu.randn_like(0., 1.)?;
+ let std =
+ std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
+ if prev_t == 0. {
+ Ok(mu)
+ } else {
+ mu + std
+ }
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..64a48c8a
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,396 @@
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct ResBlockStageB {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ groups: c,
+ padding: ksize / 2,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c)?;
+ let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .permute((0, 2, 3, 1))?
+ .contiguous()?
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
+
+#[derive(Debug)]
+struct SubBlock {
+ res_block: ResBlockStageB,
+ ts_block: TimestepBlock,
+ attn_block: Option<AttnBlock>,
+}
+
+#[derive(Debug)]
+struct DownBlock {
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::Conv2d>,
+ sub_blocks: Vec<SubBlock>,
+}
+
+#[derive(Debug)]
+struct UpBlock {
+ sub_blocks: Vec<SubBlock>,
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::ConvTranspose2d>,
+}
+
+#[derive(Debug)]
+pub struct WDiffNeXt {
+ clip_mapper: candle_nn::Linear,
+ effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
+ seq_norm: LayerNormNoWeights,
+ embedding_conv: candle_nn::Conv2d,
+ embedding_ln: WLayerNorm,
+ down_blocks: Vec<DownBlock>,
+ up_blocks: Vec<UpBlock>,
+ clf_ln: WLayerNorm,
+ clf_conv: candle_nn::Conv2d,
+ c_r: usize,
+ patch_size: usize,
+}
+
+impl WDiffNeXt {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ c_in: usize,
+ c_out: usize,
+ c_r: usize,
+ c_cond: usize,
+ clip_embd: usize,
+ patch_size: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+ const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+ const NHEAD: [usize; 4] = [1, 10, 20, 20];
+ const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
+ const EFFNET_EMBD: usize = 16;
+
+ let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+ let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
+ let vb_e = vb.pp("effnet_mappers");
+ for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
+ let c = if inject {
+ Some(candle_nn::conv2d(
+ EFFNET_EMBD,
+ c_cond,
+ 1,
+ Default::default(),
+ vb_e.pp(i),
+ )?)
+ } else {
+ None
+ };
+ effnet_mappers.push(c)
+ }
+ for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
+ let c = if inject {
+ Some(candle_nn::conv2d(
+ EFFNET_EMBD,
+ c_cond,
+ 1,
+ Default::default(),
+ vb_e.pp(i + INJECT_EFFNET.len()),
+ )?)
+ } else {
+ None
+ };
+ effnet_mappers.push(c)
+ }
+ let seq_norm = LayerNormNoWeights::new(c_cond)?;
+ let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
+ let embedding_conv = candle_nn::conv2d(
+ c_in * patch_size * patch_size,
+ C_HIDDEN[0],
+ 1,
+ Default::default(),
+ vb.pp("embedding.1"),
+ )?;
+
+ let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+ let vb = vb.pp("down_blocks").pp(i);
+ let (layer_norm, conv, start_layer_i) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+ (Some(layer_norm), Some(conv), 1)
+ } else {
+ (None, None, 0)
+ };
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = start_layer_i;
+ for _j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+ let attn_block = if i == 0 {
+ None
+ } else {
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let down_block = DownBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ down_blocks.push(down_block)
+ }
+
+ let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+ let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = 0;
+ for j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+ c_hidden + c_skip
+ } else {
+ c_skip
+ };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+ let attn_block = if i == 0 {
+ None
+ } else {
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let (layer_norm, conv) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv_transpose2d(
+ c_hidden,
+ C_HIDDEN[i - 1],
+ 2,
+ cfg,
+ vb.pp(layer_i).pp(1),
+ )?;
+ (Some(layer_norm), Some(conv))
+ } else {
+ (None, None)
+ };
+ let up_block = UpBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ up_blocks.push(up_block)
+ }
+
+ let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
+ let clf_conv = candle_nn::conv2d(
+ C_HIDDEN[0],
+ 2 * c_out * patch_size * patch_size,
+ 1,
+ Default::default(),
+ vb.pp("clf.1"),
+ )?;
+ Ok(Self {
+ clip_mapper,
+ effnet_mappers,
+ seq_norm,
+ embedding_conv,
+ embedding_ln,
+ down_blocks,
+ up_blocks,
+ clf_ln,
+ clf_conv,
+ c_r,
+ patch_size,
+ })
+ }
+
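+ /// Sinusoidal embedding of the noise level `r`: sin/cos at geometrically
+ /// spaced frequencies, zero-padded when `c_r` is odd.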
+ fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
+ fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+ clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ r: &Tensor,
+ effnet: &Tensor,
+ clip: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ const EPS: f64 = 1e-3;
+
+ let r_embed = self.gen_r_embedding(r)?;
+ let clip = match clip {
+ None => None,
+ Some(clip) => Some(self.gen_c_embeddings(clip)?),
+ };
+ let x_in = xs;
+
+ let mut xs = xs
+ .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+ .apply(&self.embedding_conv)?
+ .apply(&self.embedding_ln)?;
+
+ let mut level_outputs = Vec::new();
+ for (i, down_block) in self.down_blocks.iter().enumerate() {
+ if let Some(ln) = &down_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &down_block.conv {
+ xs = xs.apply(conv)?
+ }
+ let skip = match &self.effnet_mappers[i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for block in down_block.sub_blocks.iter() {
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ level_outputs.push(xs.clone())
+ }
+ level_outputs.reverse();
+ let mut xs = level_outputs[0].clone();
+
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for (j, block) in up_block.sub_blocks.iter().enumerate() {
+ let skip = if j == 0 && i > 0 {
+ Some(&level_outputs[i])
+ } else {
+ None
+ };
+ let skip = match (skip, effnet_c.as_ref()) {
+ (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
+ (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
+ (None, None) => None,
+ };
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ if let Some(ln) = &up_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &up_block.conv {
+ xs = xs.apply(conv)?
+ }
+ }
+
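+ // The classifier head yields an offset and a gate; the gate is squashed into
+ // [EPS, 1 - EPS] with a sigmoid and used to rescale the input residual.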
+ let ab = xs
+ .apply(&self.clf_ln)?
+ .apply(&self.clf_conv)?
+ .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+ .chunk(2, 1)?;
+ let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+ (x_in - &ab[0])? / b
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
new file mode 100644
index 00000000..7b076f06
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -0,0 +1,6 @@
+pub mod attention_processor;
+pub mod common;
+pub mod ddpm;
+pub mod diffnext;
+pub mod paella_vq;
+pub mod prior;
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
new file mode 100644
index 00000000..4a69cca0
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -0,0 +1,211 @@
+use super::common::LayerNormNoWeights;
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct MixingResidualBlock {
+ norm1: LayerNormNoWeights,
+ depthwise_conv: candle_nn::Conv2d,
+ norm2: LayerNormNoWeights,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_lin2: candle_nn::Linear,
+ gammas: Vec<f32>,
+}
+
+impl MixingResidualBlock {
+ pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let norm1 = LayerNormNoWeights::new(inp)?;
+ let norm2 = LayerNormNoWeights::new(inp)?;
+ let cfg = candle_nn::Conv2dConfig {
+ groups: inp,
+ ..Default::default()
+ };
+ let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
+ let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
+ let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
+ let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
+ Ok(Self {
+ norm1,
+ depthwise_conv,
+ norm2,
+ channelwise_lin1,
+ channelwise_lin2,
+ gammas,
+ })
+ }
+}
+
+impl Module for MixingResidualBlock {
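+ // Six learned scalars gate the two sub-paths: (1 + g0, g1) modulate the input
+ // of the depthwise conv whose output is scaled by g2, and (1 + g3, g4) modulate
+ // the input of the channelwise MLP whose output is scaled by g5.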
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mods = &self.gammas;
+ let x_temp = xs
+ .permute((0, 2, 3, 1))?
+ .apply(&self.norm1)?
+ .permute((0, 3, 1, 2))?
+ .affine(1. + mods[0] as f64, mods[1] as f64)?;
+ let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
+ let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
+ let x_temp = xs
+ .permute((0, 2, 3, 1))?
+ .apply(&self.norm2)?
+ .permute((0, 3, 1, 2))?
+ .affine(1. + mods[3] as f64, mods[4] as f64)?;
+ let x_temp = x_temp
+ .permute((0, 2, 3, 1))?
+ .contiguous()?
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_temp * mods[5] as f64
+ }
+}
+
+#[derive(Debug)]
+pub struct PaellaVQ {
+ in_block_conv: candle_nn::Conv2d,
+ out_block_conv: candle_nn::Conv2d,
+ down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
+ down_blocks_conv: candle_nn::Conv2d,
+ down_blocks_bn: candle_nn::BatchNorm,
+ up_blocks_conv: candle_nn::Conv2d,
+ up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
+}
+
+impl PaellaVQ {
+ pub fn new(vb: VarBuilder) -> Result<Self> {
+ const IN_CHANNELS: usize = 3;
+ const OUT_CHANNELS: usize = 3;
+ const LATENT_CHANNELS: usize = 4;
+ const EMBED_DIM: usize = 384;
+ const BOTTLENECK_BLOCKS: usize = 12;
+ const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
+
+ let in_block_conv = candle_nn::conv2d(
+ IN_CHANNELS * 4,
+ C_LEVELS[0],
+ 1,
+ Default::default(),
+ vb.pp("in_block.1"),
+ )?;
+ let out_block_conv = candle_nn::conv2d(
+ C_LEVELS[0],
+ OUT_CHANNELS * 4,
+ 1,
+ Default::default(),
+ vb.pp("out_block.0"),
+ )?;
+
+ let mut down_blocks = Vec::new();
+ let vb_d = vb.pp("down_blocks");
+ let mut d_idx = 0;
+ for (i, &c_level) in C_LEVELS.iter().enumerate() {
+ let conv_block = if i > 0 {
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ stride: 2,
+ ..Default::default()
+ };
+ let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
+ d_idx += 1;
+ Some(block)
+ } else {
+ None
+ };
+ let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
+ d_idx += 1;
+ down_blocks.push((conv_block, res_block))
+ }
+ let vb_d = vb_d.pp(d_idx);
+ let down_blocks_conv = candle_nn::conv2d_no_bias(
+ C_LEVELS[1],
+ LATENT_CHANNELS,
+ 1,
+ Default::default(),
+ vb_d.pp(0),
+ )?;
+ let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
+
+ let mut up_blocks = Vec::new();
+ let vb_u = vb.pp("up_blocks");
+ let mut u_idx = 0;
+ let up_blocks_conv = candle_nn::conv2d(
+ LATENT_CHANNELS,
+ C_LEVELS[1],
+ 1,
+ Default::default(),
+ vb_u.pp(u_idx).pp(0),
+ )?;
+ u_idx += 1;
+ for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
+ let mut res_blocks = Vec::new();
+ let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
+ for _j in 0..n_bottleneck_blocks {
+ let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
+ u_idx += 1;
+ res_blocks.push(res_block)
+ }
+ let conv_block = if i < C_LEVELS.len() - 1 {
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ padding: 1,
+ stride: 2,
+ ..Default::default()
+ };
+ let block = candle_nn::conv_transpose2d(
+ c_level,
+ C_LEVELS[C_LEVELS.len() - i - 2],
+ 4,
+ cfg,
+ vb_u.pp(u_idx),
+ )?;
+ u_idx += 1;
+ Some(block)
+ } else {
+ None
+ };
+ up_blocks.push((res_blocks, conv_block))
+ }
+ Ok(Self {
+ in_block_conv,
+ down_blocks,
+ down_blocks_conv,
+ down_blocks_bn,
+ up_blocks,
+ up_blocks_conv,
+ out_block_conv,
+ })
+ }
+
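+ /// Encode pixels to latents: pixel-unshuffle (space-to-depth) by 2, then the
+ /// down blocks, finishing with a 1x1 conv and batch norm.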
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
+ for down_block in self.down_blocks.iter() {
+ if let Some(conv) = &down_block.0 {
+ xs = xs.apply(conv)?
+ }
+ xs = xs.apply(&down_block.1)?
+ }
+ xs.apply(&self.down_blocks_conv)?
+ .apply(&self.down_blocks_bn)
+ }
+
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: quantizer if we want to support `force_not_quantize=False`.
+ let mut xs = xs.apply(&self.up_blocks_conv)?;
+ for up_block in self.up_blocks.iter() {
+ for b in up_block.0.iter() {
+ xs = xs.apply(b)?;
+ }
+ if let Some(conv) = &up_block.1 {
+ xs = xs.apply(conv)?
+ }
+ }
+ xs.apply(&self.out_block_conv)?
+ .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
+ }
+}
+
+impl Module for PaellaVQ {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.decode(&self.encode(xs)?)
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
new file mode 100644
index 00000000..97ccf0e2
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -0,0 +1,103 @@
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
+use candle::{DType, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct Block {
+ res_block: ResBlock,
+ ts_block: TimestepBlock,
+ attn_block: AttnBlock,
+}
+
+#[derive(Debug)]
+pub struct WPrior {
+ projection: candle_nn::Conv2d,
+ cond_mapper_lin1: candle_nn::Linear,
+ cond_mapper_lin2: candle_nn::Linear,
+ blocks: Vec<Block>,
+ out_ln: super::common::WLayerNorm,
+ out_conv: candle_nn::Conv2d,
+ c_r: usize,
+}
+
+impl WPrior {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ c_in: usize,
+ c: usize,
+ c_cond: usize,
+ c_r: usize,
+ depth: usize,
+ nhead: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
+ let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
+ let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
+ let out_ln = super::common::WLayerNorm::new(c)?;
+ let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
+ let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+ let attn_block = AttnBlock::new(
+ c,
+ c,
+ nhead,
+ true,
+ use_flash_attn,
+ vb.pp(format!("blocks.{}", 3 * index + 2)),
+ )?;
+ blocks.push(Block {
+ res_block,
+ ts_block,
+ attn_block,
+ })
+ }
+ Ok(Self {
+ projection,
+ cond_mapper_lin1,
+ cond_mapper_lin2,
+ blocks,
+ out_ln,
+ out_conv,
+ c_r,
+ })
+ }
+
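+ /// Sinusoidal embedding of the noise level `r` (same recipe as in `diffnext`).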
+ pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
+ pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let x_in = xs;
+ let mut xs = xs.apply(&self.projection)?;
+ let c_embed = c
+ .apply(&self.cond_mapper_lin1)?
+ .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
+ .apply(&self.cond_mapper_lin2)?;
+ let r_embed = self.gen_r_embedding(r)?;
+ for block in self.blocks.iter() {
+ xs = block.res_block.forward(&xs, None)?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ xs = block.attn_block.forward(&xs, &c_embed)?;
+ }
+ let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
+ (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
+ }
+}
diff --git a/candle-examples/src/object_detection.rs b/candle-transformers/src/object_detection.rs
index c7c60136..ce579316 100644
--- a/candle-examples/src/object_detection.rs
+++ b/candle-transformers/src/object_detection.rs
@@ -1,12 +1,12 @@
/// A bounding box around an object.
#[derive(Debug, Clone)]
-pub struct Bbox {
+pub struct Bbox<D> {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
- pub keypoints: Vec<KeyPoint>,
+ pub data: D,
}
#[derive(Debug, Clone, Copy, PartialEq)]
@@ -17,7 +17,7 @@ pub struct KeyPoint {
}
/// Intersection over union of two bounding boxes.
-pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
+pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
@@ -28,7 +28,7 @@ pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
i_area / (b1_area + b2_area - i_area)
}
-pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
+pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
new file mode 100644
index 00000000..76f994d0
--- /dev/null
+++ b/candle-transformers/tests/generation_tests.rs
@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
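+ // Without a temperature the processor is deterministic and picks the argmax.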
+ let mut logits_process = LogitsProcessor::new(1337, None, None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 3);
+ Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 0);
+ Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 2);
+ Ok(())
+}
diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml
new file mode 100644
index 00000000..81a043de
--- /dev/null
+++ b/candle-wasm-examples/bert/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "candle-wasm-example-bert"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.2" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/bert/README.md b/candle-wasm-examples/bert/README.md
new file mode 100644
index 00000000..c34d33cc
--- /dev/null
+++ b/candle-wasm-examples/bert/README.md
@@ -0,0 +1,26 @@
+## Running BERT with Candle and WASM
+
+Here, we provide an example of how to run BERT using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
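+
+You can also drive the worker directly from your own page. Below is a minimal
+sketch using the message fields handled in `bertWorker.js`; the URLs are the
+`intfloat/e5-small-v2` files from `utils.js`, and the example sentence is just
+a placeholder:
+
+```js
+const worker = new Worker("./bertWorker.js", { type: "module" });
+const base = "https://huggingface.co/intfloat/e5-small-v2/resolve/main/";
+worker.addEventListener("message", (e) => {
+ // Progress updates stream in; the embeddings arrive with status "complete".
+ if (e.data.status === "complete") console.log(e.data.output);
+});
+worker.postMessage({
+ weightsURL: base + "model.safetensors",
+ tokenizerURL: base + "tokenizer.json",
+ configURL: base + "config.json",
+ modelID: "intfloat_e5_small_v2",
+ sentences: ["query: What is Candle?"],
+});
+```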
diff --git a/candle-wasm-examples/bert/bertWorker.js b/candle-wasm-examples/bert/bertWorker.js
new file mode 100644
index 00000000..fd796c2b
--- /dev/null
+++ b/candle-wasm-examples/bert/bertWorker.js
@@ -0,0 +1,77 @@
+// Load the Candle BERT WASM module.
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+ const cacheName = "bert-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class Bert {
+ static instance = {};
+
+ static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+ const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+ await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ fetchArrayBuffer(configURL),
+ ]);
+
+ this.instance[modelID] = new Model(
+ weightsArrayU8,
+ tokenizerArrayU8,
+ configArrayU8
+ );
+ } else {
+ self.postMessage({ status: "ready", message: "Model Already Loaded" });
+ }
+ return this.instance[modelID];
+ }
+}
+
+self.addEventListener("message", async (event) => {
+ const {
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ normalize = true,
+ } = event.data;
+ try {
+ self.postMessage({ status: "ready", message: "Starting Bert Model" });
+ const model = await Bert.getInstance(
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID
+ );
+ self.postMessage({
+ status: "embedding",
+ message: "Calculating Embeddings",
+ });
+ const output = model.get_embeddings({
+ sentences: sentences,
+ normalize_embeddings: normalize,
+ });
+
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: output.data,
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});
diff --git a/candle-wasm-examples/bert/build-lib.sh b/candle-wasm-examples/bert/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/bert/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/bert/lib-example.html b/candle-wasm-examples/bert/lib-example.html
new file mode 100644
index 00000000..d10ea1db
--- /dev/null
+++ b/candle-wasm-examples/bert/lib-example.html
@@ -0,0 +1,368 @@
+<!DOCTYPE html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <title>Candle Bert</title>
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module" src="./code.js"></script>
+ <script type="module">
+ import { hcl } from "https://cdn.skypack.dev/d3-color@3";
+ import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3";
+ import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4";
+ import {
+ getModelInfo,
+ getEmbeddings,
+ getWikiText,
+ cosineSimilarity,
+ } from "./utils.js";
+
+ const bertWorker = new Worker("./bertWorker.js", {
+ type: "module",
+ });
+
+ const inputContainerEL = document.querySelector("#input-container");
+ const textAreaEl = document.querySelector("#input-area");
+ const outputAreaEl = document.querySelector("#output-area");
+ const formEl = document.querySelector("#form");
+ const searchInputEl = document.querySelector("#search-input");
+ const formWikiEl = document.querySelector("#form-wiki");
+ const searchWikiEl = document.querySelector("#search-wiki");
+ const outputStatusEl = document.querySelector("#output-status");
+ const modelSelectEl = document.querySelector("#model");
+
+ const sentencesRegex =
+ /(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm;
+
+ let sentenceEmbeddings = [];
+ let currInputText = "";
+ let isCalculating = false;
+
+ function toggleTextArea(state) {
+ if (state) {
+ textAreaEl.hidden = false;
+ textAreaEl.focus();
+ } else {
+ textAreaEl.hidden = true;
+ }
+ }
+ inputContainerEL.addEventListener("focus", (e) => {
+ toggleTextArea(true);
+ });
+ textAreaEl.addEventListener("blur", (e) => {
+ toggleTextArea(false);
+ });
+ textAreaEl.addEventListener("focusout", (e) => {
+ toggleTextArea(false);
+ if (currInputText === textAreaEl.value || isCalculating) return;
+ populateOutputArea(textAreaEl.value);
+ calculateEmbeddings(textAreaEl.value);
+ });
+
+ modelSelectEl.addEventListener("change", (e) => {
+ if (currInputText === "" || isCalculating) return;
+ populateOutputArea(textAreaEl.value);
+ calculateEmbeddings(textAreaEl.value);
+ });
+
+ function populateOutputArea(text) {
+ currInputText = text;
+ const sentences = text.split(sentencesRegex);
+
+ outputAreaEl.innerHTML = "";
+ for (const [id, sentence] of sentences.entries()) {
+ const sentenceEl = document.createElement("span");
+ sentenceEl.id = `sentence-${id}`;
+ sentenceEl.innerText = sentence + " ";
+ outputAreaEl.appendChild(sentenceEl);
+ }
+ }
+ formEl.addEventListener("submit", async (e) => {
+ e.preventDefault();
+ if (isCalculating || currInputText === "") return;
+ toggleInputs(true);
+ const modelID = modelSelectEl.value;
+ const { modelURL, tokenizerURL, configURL, search_prefix } =
+ getModelInfo(modelID);
+
+ const query = search_prefix + searchInputEl.value;
+ outputStatusEl.classList.remove("invisible");
+ outputStatusEl.innerText = "Calculating embeddings for query...";
+ isCalculating = true;
+ const out = await getEmbeddings(
+ bertWorker,
+ modelURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ [query]
+ );
+ outputStatusEl.classList.add("invisible");
+ const queryEmbeddings = out.output[0];
+ // calculate cosine similarity with all sentences given the query
+ const distances = sentenceEmbeddings
+ .map((embedding, id) => ({
+ id,
+ similarity: cosineSimilarity(queryEmbeddings, embedding),
+ }))
+ .sort((a, b) => b.similarity - a.similarity)
+ // getting top 10 most similar sentences
+ .slice(0, 10);
+
+ const colorScale = scaleLinear()
+ .domain([
+ distances[distances.length - 1].similarity,
+ distances[0].similarity,
+ ])
+ .range([0, 1])
+ .interpolate(() => interpolateReds);
+ outputAreaEl.querySelectorAll("span").forEach((el) => {
+ el.style.color = "unset";
+ el.style.backgroundColor = "unset";
+ });
+ distances.forEach((d) => {
+ const el = outputAreaEl.querySelector(`#sentence-${d.id}`);
+ const color = colorScale(d.similarity);
+ const fontColor = hcl(color).l < 70 ? "white" : "black";
+ el.style.color = fontColor;
+ el.style.backgroundColor = color;
+ });
+
+ outputAreaEl
+ .querySelector(`#sentence-${distances[0].id}`)
+ .scrollIntoView({
+ behavior: "smooth",
+ block: "center",
+ inline: "nearest",
+ });
+
+ isCalculating = false;
+ toggleInputs(false);
+ });
+ async function calculateEmbeddings(text) {
+ isCalculating = true;
+ toggleInputs(true);
+ const modelID = modelSelectEl.value;
+ const { modelURL, tokenizerURL, configURL, document_prefix } =
+ getModelInfo(modelID);
+
+ const sentences = text.split(sentencesRegex);
+ const allEmbeddings = [];
+ outputStatusEl.classList.remove("invisible");
+ for (const [id, sentence] of sentences.entries()) {
+ const query = document_prefix + sentence;
+ outputStatusEl.innerText = `Calculating embeddings: sentence ${
+ id + 1
+ } of ${sentences.length}`;
+ const embeddings = await getEmbeddings(
+ bertWorker,
+ modelURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ [query],
+ updateStatus
+ );
+ allEmbeddings.push(embeddings);
+ }
+ outputStatusEl.classList.add("invisible");
+ sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]);
+ isCalculating = false;
+ toggleInputs(false);
+ }
+
+ function updateStatus(data) {
+ if ("status" in data) {
+ if (data.status === "loading") {
+ outputStatusEl.innerText = data.message;
+ outputStatusEl.classList.remove("invisible");
+ }
+ }
+ }
+ function toggleInputs(state) {
+ const interactive = document.querySelectorAll(".interactive");
+ interactive.forEach((el) => {
+ if (state) {
+ el.disabled = true;
+ } else {
+ el.disabled = false;
+ }
+ });
+ }
+
+ searchWikiEl.addEventListener("input", () => {
+ searchWikiEl.setCustomValidity("");
+ });
+
+ formWikiEl.addEventListener("submit", async (e) => {
+ e.preventDefault();
+ if ("example" in e.submitter.dataset) {
+ searchWikiEl.value = e.submitter.innerText;
+ }
+ const text = searchWikiEl.value;
+
+ if (isCalculating || text === "") return;
+ try {
+ const wikiText = await getWikiText(text);
+ searchWikiEl.setCustomValidity("");
+ textAreaEl.value = wikiText;
+ populateOutputArea(wikiText);
+ calculateEmbeddings(wikiText);
+ searchWikiEl.value = "";
+ } catch {
+ searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
+ searchWikiEl.reportValidity();
+ }
+ });
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-5 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle BERT</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ Running sentence embeddings and similarity search in the browser using
+ the Bert Model written with
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ and compiled to Wasm. The embedding models are from
+ <a
+ href="https://huggingface.co/sentence-transformers/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >
+ Sentence Transformers
+ </a>
+ and
+ <a
+ href="https://huggingface.co/intfloat/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >
+ Liang Wang - e5 Models
+ </a>
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium block">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
+ >
+ <option value="intfloat_e5_small_v2" selected>
+ intfloat/e5-small-v2 (133 MB)
+ </option>
+ <option value="intfloat_e5_base_v2">
+ intfloat/e5-base-v2 (438 MB)
+ </option>
+ <option value="intfloat_multilingual_e5_small">
+ intfloat/multilingual-e5-small (471 MB)
+ </option>
+ <option value="sentence_transformers_all_MiniLM_L6_v2">
+ sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
+ </option>
+ <option value="sentence_transformers_all_MiniLM_L12_v2">
+ sentence-transformers/all-MiniLM-L12-v2 (133 MB)
+ </option>
+ </select>
+ </div>
+ <div>
+ <h3 class="font-medium">Examples:</h3>
+ <form
+ id="form-wiki"
+ class="flex text-xs rounded-md justify-between w-min gap-3"
+ >
+ <input type="submit" hidden />
+
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Pizza
+ </button>
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Paris
+ </button>
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Physics
+ </button>
+ <input
+ type="text"
+ id="search-wiki"
+ title="Search Wikipedia article by title"
+ class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
+ placeholder="Load Wikipedia article..."
+ />
+ <button
+ title="Search Wikipedia article and load into input"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
+ >
+ Load
+ </button>
+ </form>
+ </div>
+ <form
+ id="form"
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+ >
+ <input type="submit" hidden />
+ <input
+ type="text"
+ id="search-input"
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
+ placeholder="Search query here..."
+ />
+ <button
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
+ >
+ Search
+ </button>
+ </form>
+ <div>
+ <h3 class="font-medium">Input text:</h3>
+ <div class="flex justify-between items-center">
+ <div class="rounded-md inline text-xs">
+ <span id="output-status" class="m-auto font-light invisible"
+ >C</span
+ >
+ </div>
+ </div>
+ <div
+ id="input-container"
+ tabindex="0"
+ class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative"
+ >
+ <textarea
+ id="input-area"
+ hidden
+ value=""
+ placeholder="Input text to perform semantic similarity search..."
+ class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible"
+ ></textarea>
+ <p id="output-area" class="grid-rows-2">
+ Input text to perform semantic similarity search...
+ </p>
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
new file mode 100644
index 00000000..f5521abd
--- /dev/null
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -0,0 +1,92 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::bert::{BertModel, Config};
+use candle_wasm_example_bert::console_log;
+use tokenizers::{PaddingParams, Tokenizer};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Model {
+ bert: BertModel,
+ tokenizer: Tokenizer,
+}
+
+#[wasm_bindgen]
+impl Model {
+ #[wasm_bindgen(constructor)]
+ pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
+ console_error_panic_hook::set_once();
+ console_log!("loading model");
+ let device = &Device::Cpu;
+ let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
+ let config: Config = serde_json::from_slice(&config)?;
+ let tokenizer =
+ Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+ let bert = BertModel::load(vb, &config)?;
+
+ Ok(Self { bert, tokenizer })
+ }
+
+ pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+ let input: Params =
+ serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+ let sentences = input.sentences;
+ let normalize_embeddings = input.normalize_embeddings;
+
+ let device = &Device::Cpu;
+ if let Some(pp) = self.tokenizer.get_padding_mut() {
+ pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+ } else {
+ let pp = PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ ..Default::default()
+ };
+ self.tokenizer.with_padding(Some(pp));
+ }
+ let tokens = self
+ .tokenizer
+ .encode_batch(sentences.to_vec(), true)
+ .map_err(|m| JsError::new(&m.to_string()))?;
+
+ let token_ids: Vec<Tensor> = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let token_ids = Tensor::stack(&token_ids, 0)?;
+ let token_type_ids = token_ids.zeros_like()?;
+ console_log!("running inference on batch {:?}", token_ids.shape());
+ let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+ console_log!("generated embeddings {:?}", embeddings.shape());
+ // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+ let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+ let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+ let embeddings = if normalize_embeddings {
+ embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+ } else {
+ embeddings
+ };
+ let embeddings_data = embeddings.to_vec2()?;
+ Ok(serde_wasm_bindgen::to_value(&Embeddings {
+ data: embeddings_data,
+ })?)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Embeddings {
+ data: Vec<Vec<f64>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct Params {
+ sentences: Vec<String>,
+ normalize_embeddings: bool,
+}
+fn main() {
+ console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/bert/src/lib.rs b/candle-wasm-examples/bert/src/lib.rs
new file mode 100644
index 00000000..1e3657be
--- /dev/null
+++ b/candle-wasm-examples/bert/src/lib.rs
@@ -0,0 +1,20 @@
+use candle_transformers::models::bert;
+use wasm_bindgen::prelude::*;
+
+pub use bert::{BertModel, Config, DTYPE};
+pub use tokenizers::{PaddingParams, Tokenizer};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above.
+ ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/bert/utils.js b/candle-wasm-examples/bert/utils.js
new file mode 100644
index 00000000..9d8bd7bd
--- /dev/null
+++ b/candle-wasm-examples/bert/utils.js
@@ -0,0 +1,99 @@
+export async function getEmbeddings(
+ worker,
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ updateStatus = null
+) {
+ return new Promise((resolve, reject) => {
+ worker.postMessage({
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ });
+ function messageHandler(event) {
+ if ("error" in event.data) {
+ worker.removeEventListener("message", messageHandler);
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete") {
+ worker.removeEventListener("message", messageHandler);
+ resolve(event.data);
+ }
+ if (updateStatus) updateStatus(event.data);
+ }
+ worker.addEventListener("message", messageHandler);
+ });
+}
+
+const MODELS = {
+ intfloat_e5_small_v2: {
+ base_url: "https://huggingface.co/intfloat/e5-small-v2/resolve/main/",
+ search_prefix: "query: ",
+ document_prefix: "passage: ",
+ },
+ intfloat_e5_base_v2: {
+ base_url: "https://huggingface.co/intfloat/e5-base-v2/resolve/main/",
+ search_prefix: "query: ",
+ document_prefix: "passage:",
+ },
+ intfloat_multilingual_e5_small: {
+ base_url:
+ "https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/",
+ search_prefix: "query: ",
+ document_prefix: "passage: ",
+ },
+ sentence_transformers_all_MiniLM_L6_v2: {
+ base_url:
+ "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/refs%2Fpr%2F21/",
+ search_prefix: "",
+ document_prefix: "",
+ },
+ sentence_transformers_all_MiniLM_L12_v2: {
+ base_url:
+ "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/refs%2Fpr%2F4/",
+ search_prefix: "",
+ document_prefix: "",
+ },
+};
+export function getModelInfo(id) {
+ return {
+ modelURL: MODELS[id].base_url + "model.safetensors",
+ configURL: MODELS[id].base_url + "config.json",
+ tokenizerURL: MODELS[id].base_url + "tokenizer.json",
+ search_prefix: MODELS[id].search_prefix,
+ document_prefix: MODELS[id].document_prefix,
+ };
+}
+
+export function cosineSimilarity(vec1, vec2) {
+ const dot = vec1.reduce((acc, val, i) => acc + val * vec2[i], 0);
+ const a = Math.sqrt(vec1.reduce((acc, val) => acc + val * val, 0));
+ const b = Math.sqrt(vec2.reduce((acc, val) => acc + val * val, 0));
+ return dot / (a * b);
+}
+export async function getWikiText(article) {
+ // thanks to wikipedia for the API
+ const URL = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${article}&explaintext=1&exsectionformat=plain&format=json&origin=*`;
+ return fetch(URL, {
+ method: "GET",
+ headers: {
+ Accept: "application/json",
+ },
+ })
+ .then((r) => r.json())
+ .then((data) => {
+ const pages = data.query.pages;
+ const pageId = Object.keys(pages)[0];
+ const extract = pages[pageId].extract;
+ if (extract === undefined || extract === "") {
+ throw new Error("No article found");
+ }
+ return extract;
+ })
+ .catch((error) => console.error("Error:", error));
+}
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
index 51eac694..601f5e34 100644
--- a/candle-wasm-examples/llama2-c/Cargo.toml
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.3" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
diff --git a/candle-wasm-examples/llama2-c/README.md b/candle-wasm-examples/llama2-c/README.md
new file mode 100644
index 00000000..0b41e064
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/README.md
@@ -0,0 +1,47 @@
+## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
+
+Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c) models in Rust, using a Candle-compiled WASM binary and two runtimes.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
+From the `candle-wasm-examples/llama2-c` directory, run:
+
+Download assets:
+
+```bash
+# Model and tokenizer
+
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
+
+```
+
+Run hot reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
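+
+You can also drive the worker directly. Below is a minimal sketch using the
+message fields that `lib-example.html` sends; the prompt and sampling
+parameters are placeholder values:
+
+```js
+const worker = new Worker("./llama2cWorker.js", { type: "module" });
+worker.addEventListener("message", (e) => {
+ if (e.data.status === "generating") console.log(e.data.sentence);
+});
+worker.postMessage({
+ weightsURL: "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin",
+ modelID: "stories15M",
+ tokenizerURL: "tokenizer.json",
+ prompt: "Once upon a time",
+ temp: 0.4,
+ top_p: 1.0,
+ repeatPenalty: 1.1,
+ seed: BigInt(299792458),
+ maxSeqLen: 128,
+ command: "start",
+});
+```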
diff --git a/candle-wasm-examples/llama2-c/build-lib.sh b/candle-wasm-examples/llama2-c/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
new file mode 100644
index 00000000..86fe9811
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -0,0 +1,359 @@
+<!DOCTYPE html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <title>Candle Llama.c Rust/WASM</title>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ code,
+ output,
+ select,
+ pre {
+ font-family: "Source Code Pro", monospace;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ // base url for the model checkpoints
+ const MODELS_BASE_URL =
+ "https://huggingface.co/karpathy/tinyllamas/resolve/main";
+
+ // available models and their max sequence lengths
+ const MODELS = {
+ stories15M: {
+ url: "stories15M.bin",
+ seq_len: 256,
+ },
+ stories42M: {
+ url: "stories42M.bin",
+ seq_len: 1024,
+ },
+ stories110M: {
+ url: "stories110M.bin",
+ seq_len: 1024,
+ },
+ };
+
+ const llamaWorker = new Worker("./llama2cWorker.js", {
+ type: "module",
+ });
+ async function generateSequence(controller) {
+ const getValue = (id) => document.querySelector(`#${id}`).value;
+ const modelID = getValue("model");
+ const model = MODELS[modelID];
+ const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
+ const prompt = getValue("prompt");
+ const temperature = getValue("temperature");
+ const topP = getValue("top-p");
+ const repeatPenalty = getValue("repeat_penalty");
+ const seed = getValue("seed");
+ const maxSeqLen = getValue("max-seq");
+
+ function updateStatus(data) {
+ const outStatus = document.querySelector("#output-status");
+ const outGen = document.querySelector("#output-generation");
+ const outCounter = document.querySelector("#output-counter");
+
+ switch (data.status) {
+ case "loading":
+ outStatus.hidden = false;
+ outStatus.textContent = data.message;
+ outGen.hidden = true;
+ outCounter.hidden = true;
+ break;
+ case "generating":
+ const { message, prompt, sentence, tokensSec, totalTime } = data;
+ outStatus.hidden = true;
+ outCounter.hidden = false;
+ outGen.hidden = false;
+ outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace(
+ /\<s\>|\<\/s\>/g,
+ ""
+ )}`;
+ outCounter.innerHTML = `${(totalTime / 1000).toFixed(
+ 2
+ )}s (${tokensSec.toFixed(2)} tok/s)`;
+ break;
+ case "complete":
+ outStatus.hidden = true;
+ outGen.hidden = false;
+ break;
+ }
+ }
+
+ return new Promise((resolve, reject) => {
+ llamaWorker.postMessage({
+ weightsURL,
+ modelID,
+ tokenizerURL: "tokenizer.json",
+ prompt,
+ temp: temperature,
+ top_p: topP,
+ repeatPenalty,
+ seed: BigInt(seed),
+ maxSeqLen,
+ command: "start",
+ });
+
+ const handleAbort = () => {
+ llamaWorker.postMessage({ command: "abort" });
+ };
+ const handleMessage = (event) => {
+ const { status, error, message, prompt, sentence } = event.data;
+ if (status) updateStatus(event.data);
+ if (error) {
+ llamaWorker.removeEventListener("message", handleMessage);
+ reject(new Error(error));
+ }
+ if (status === "complete") {
+ llamaWorker.removeEventListener("message", handleMessage);
+ resolve(event.data);
+ }
+ };
+
+ controller.signal.addEventListener("abort", handleAbort);
+ llamaWorker.addEventListener("message", handleMessage);
+ });
+ }
+
+ const form = document.querySelector("#form");
+ const prompt = document.querySelector("#prompt");
+ const clearBtn = document.querySelector("#clear-btn");
+ const runBtn = document.querySelector("#run");
+ const modelSelect = document.querySelector("#model");
+ let runController = new AbortController();
+ let isRunning = false;
+
+ modelSelect.addEventListener("change", (e) => {
+ const model = MODELS[e.target.value];
+ document.querySelector("#max-seq").max = model.seq_len;
+ document.querySelector("#max-seq").nextElementSibling.value =
+ model.seq_len;
+ });
+
+ form.addEventListener("submit", async (e) => {
+ e.preventDefault();
+ if (isRunning) {
+ stopRunning();
+ } else {
+ startRunning();
+ await generateSequence(runController);
+ stopRunning();
+ }
+ });
+
+ function startRunning() {
+ isRunning = true;
+ runBtn.textContent = "Stop";
+ }
+
+ function stopRunning() {
+ runController.abort();
+ runController = new AbortController();
+ runBtn.textContent = "Run";
+ isRunning = false;
+ }
+ clearBtn.addEventListener("click", (e) => {
+ e.preventDefault();
+ prompt.value = "";
+ clearBtn.classList.add("invisible");
+ runBtn.disabled = true;
+ stopRunning();
+ });
+ prompt.addEventListener("input", (e) => {
+ runBtn.disabled = false;
+ if (e.target.value.length > 0) {
+ clearBtn.classList.remove("invisible");
+ } else {
+ clearBtn.classList.add("invisible");
+ }
+ });
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4 text-gray-800">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle Llama2.c</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ <a
+ href="https://github.com/karpathy/llama2.c"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ target="_blank"
+ >Llama2.c</a
+ >
+ is Andrey Karpathy's C implementation of the Llama 2 LLM model in C.
+ This demo uses
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ to run Llama2.c in the browser using Rust/WASM.
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"
+ >
+ <option value="stories15M" selected>stories 15M (60.8 MB)</option>
+ <option value="stories42M">stories 42M (167 MB)</option>
+ <option value="stories110M">stories 110M (438 MB)</option>
+ </select>
+ </div>
+ <form
+ id="form"
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+ >
+ <input type="submit" hidden />
+ <input
+ type="text"
+ id="prompt"
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+ placeholder="Add your prompt here..."
+ value="Once upon a time"
+ />
+ <button id="clear-btn">
+ <svg
+ fill="none"
+ xmlns="http://www.w3.org/2000/svg"
+ width="40"
+ viewBox="0 0 70 40"
+ >
+ <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
+ <path
+ d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
+ opacity=".5"
+ stroke="#1F2937"
+ stroke-width="2"
+ />
+ </svg>
+ </button>
+ <button
+ id="run"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
+ >
+ Run
+ </button>
+ </form>
+ <div class="grid grid-cols-3 max-w-md items-center gap-3">
+ <label class="text-sm font-medium" for="max-seq">Maximum length </label>
+ <input
+ type="range"
+ id="max-seq"
+ name="max-seq"
+ min="1"
+ max="256"
+ step="1"
+ value="200"
+ oninput="this.nextElementSibling.value = Number(this.value)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 200</output
+ >
+ <label class="text-sm font-medium" for="temperature">Temperature</label>
+ <input
+ type="range"
+ id="temperature"
+ name="temperature"
+ min="0"
+ max="2"
+ step="0.01"
+ value="0.50"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 0.50</output
+ >
+ <label class="text-sm font-medium" for="top-p">Top-p</label>
+ <input
+ type="range"
+ id="top-p"
+ name="top-p"
+ min="0"
+ max="1"
+ step="0.01"
+ value="1.00"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 1.00</output
+ >
+
+ <label class="text-sm font-medium" for="repeat_penalty"
+ >Repeat Penalty</label
+ >
+
+ <input
+ type="range"
+ id="repeat_penalty"
+ name="repeat_penalty"
+ min="-2"
+ max="2"
+ step="0.01"
+ value="1.10"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >1.10</output
+ >
+ <label class="text-sm font-medium" for="seed">Seed</label>
+ <input
+ type="number"
+ id="seed"
+ name="seed"
+ value="299792458"
+ class="font-light border border-gray-700 text-right rounded-md p-2"
+ />
+ <button
+ id="run"
+ onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"
+ >
+ Rand
+ </button>
+ </div>
+ <div>
+ <h3 class="font-medium">Generation:</h3>
+ <div
+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
+ >
+ <div
+ id="output-counter"
+ hidden
+ class="ml-auto font-semibold grid-rows-1 text-sm"
+ ></div>
+ <p hidden id="output-generation" class="grid-rows-2"></p>
+ <span id="output-status" class="m-auto font-light"
+ >No output yet</span
+ >
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js
new file mode 100644
index 00000000..abaf3401
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/llama2cWorker.js
@@ -0,0 +1,106 @@
+import init, { Model } from "./build/m.js";
+
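+// Fetch a URL as a Uint8Array, caching the bytes with the Cache API so model weights are downloaded only once.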
+async function fetchArrayBuffer(url) {
+ const cacheName = "llama2c-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class Llama2C {
+ static instance = {};
+
+ static async getInstance(weightsURL, modelID, tokenizerURL) {
+ // load individual modelID only once
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+
+ const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ ]);
+
+ this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
+ }
+ return this.instance[modelID];
+ }
+}
+
+let controller = null;
+self.addEventListener("message", (event) => {
+ if (event.data.command === "start") {
+ controller = new AbortController();
+ generate(event.data);
+ } else if (event.data.command === "abort") {
+ controller.abort();
+ }
+});
+
+async function generate(data) {
+ const {
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ prompt,
+ temp,
+ top_p,
+ repeatPenalty,
+ seed,
+ maxSeqLen,
+ } = data;
+ try {
+ self.postMessage({ status: "loading", message: "Starting llama2.c" });
+ const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
+
+ self.postMessage({ status: "loading", message: "Initializing model" });
+ model.init_with_prompt(prompt, temp, top_p, repeatPenalty, seed);
+
+ const seq_len = model.get_seq_len();
+
+ let sentence = "";
+ let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
+ let startTime = performance.now();
+ let tokensCount = 0;
+ while (tokensCount < maxTokens) {
+ await new Promise(async (resolve) => {
+ if (controller && controller.signal.aborted) {
+ self.postMessage({
+ status: "aborted",
+ message: "Aborted",
+ output: prompt + sentence,
+ });
+ return;
+ }
+ const token = await model.next_token();
+ const tokensSec =
+ ((tokensCount + 1) / (performance.now() - startTime)) * 1000;
+
+ sentence += token;
+ self.postMessage({
+ status: "generating",
+ message: "Generating token",
+ token: token,
+ sentence: sentence,
+ totalTime: performance.now() - startTime,
+ tokensSec,
+ prompt: prompt,
+ });
+ setTimeout(resolve, 0);
+ });
+ tokensCount++;
+ }
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: prompt + sentence,
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+}
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index 782026a4..ea04a810 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -46,6 +46,7 @@ pub struct App {
status: String,
loaded: bool,
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+ top_p: std::rc::Rc<std::cell::RefCell<f64>>,
prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String,
n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
status,
n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+ top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(),
current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
self.n_tokens = 0;
self.generated.clear();
let temp = *self.temperature.borrow();
+ let top_p = *self.top_p.borrow();
let prompt = self.prompt.borrow().clone();
- console_log!("temp: {}, prompt: {}", temp, prompt);
+ console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link()
- .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+ .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
}
true
}
@@ -177,13 +180,21 @@ impl Component for App {
fn view(&self, ctx: &Context<Self>) -> Html {
use yew::TargetCast;
let temperature = self.temperature.clone();
- let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+ let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(temp) = f64::from_str(&input.value()) {
*temperature.borrow_mut() = temp
}
Msg::Refresh
});
+ let top_p = self.top_p.clone();
+ let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+ let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+ if let Ok(top_p_input) = f64::from_str(&input.value()) {
+ *top_p.borrow_mut() = top_p_input
+ }
+ Msg::Refresh
+ });
let prompt = self.prompt.clone();
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
</p>
</div>
{"temperature \u{00a0} "}
- <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
+ <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
{format!(" \u{00a0} {}", self.temperature.borrow())}
<br/ >
+ {"top_p \u{00a0} "}
+ <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
+ {format!(" \u{00a0} {}", self.top_p.borrow())}
+ <br/ >
{"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
<br/ >
{
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index d014e38a..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
tokenizer,
model: weights,
});
- let logits_processor = LogitsProcessor::new(299792458, None);
+ let logits_processor = LogitsProcessor::new(299792458, None, None);
match model {
Ok(inner) => Ok(Self {
inner,
@@ -60,11 +60,18 @@ impl Model {
}
#[wasm_bindgen]
+ pub fn get_seq_len(&mut self) -> usize {
+ self.inner.config.seq_len
+ }
+
+ #[wasm_bindgen]
pub fn init_with_prompt(
&mut self,
prompt: String,
temp: f64,
+ top_p: f64,
repeat_penalty: f32,
+ seed: u64,
) -> Result<String, JsError> {
// First reset the cache.
{
@@ -74,13 +81,18 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
- self.logits_processor = LogitsProcessor::new(299792458, temp);
+ let top_p = if top_p <= 0. || top_p >= 1. {
+ None
+ } else {
+ Some(top_p)
+ };
+ self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self
.inner
.tokenizer
- .encode(prompt.to_string(), true)
+ .encode(prompt, true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3d187fcc..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
pub struct Model {
pub cache: Cache,
- config: Config,
+ pub config: Config,
pub llama: Llama,
pub tokenizer: Tokenizer,
}
@@ -62,12 +62,18 @@ impl Model {
link: &WorkerLink<Worker>,
id: HandlerId,
temp: f64,
+ top_p: f64,
prompt: String,
) -> Result<()> {
let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) };
- console_log!("{temp:?} {prompt}");
- let mut logits_processor = LogitsProcessor::new(299792458, temp);
+ let top_p = if top_p <= 0. || top_p >= 1.0 {
+ None
+ } else {
+ Some(top_p)
+ };
+ console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+ let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
let mut index_pos = 0;
let mut tokens = self
.tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run(f64, String),
+ Run(f64, f64, String),
}
#[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run(temp, prompt) => match &mut self.model {
+ WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
{
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
}
}
let result = model
- .run(&self.link, id, temp, prompt)
+ .run(&self.link, id, temp, top_p, prompt)
.map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
new file mode 100644
index 00000000..46b85615
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+name = "candle-wasm-example-sam"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.3" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.2", features = ["js"] }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
new file mode 100644
index 00000000..04ff2033
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -0,0 +1,26 @@
+## Running Segment Anything Example
+
+Here, we provide an example of how to run the Segment Anything Model (SAM) in the browser, using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can then import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
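+
+For reference, a minimal sketch of the worker protocol; the message fields mirror the ones used by `lib-example.html` and `samWorker.js` in this example:
+
+```js
+const samWorker = new Worker("./samWorker.js", { type: "module" });
+
+// A message without `points` only computes and caches the image embeddings;
+// a message with points (x and y normalized to [0, 1]) returns a mask.
+samWorker.postMessage({
+  modelURL: "https://huggingface.co/lmz/candle-sam/resolve/main/mobile_sam-tiny-vitt.safetensors",
+  modelID: "sam_mobile_tiny",
+  imageURL: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg",
+  points: { x: 0.5, y: 0.5 },
+});
+
+// The worker replies with { status: "complete", output: { maskURL } }.
+samWorker.addEventListener("message", (event) => {
+  if (event.data.status === "complete") console.log(event.data.output.maskURL);
+});
+```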
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
new file mode 100644
index 00000000..5060f073
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -0,0 +1,407 @@
+<html>
+ <head>
+ <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+ <title>Candle Segment Anything Model (SAM) Rust/WASM</title>
+ </head>
+ <body></body>
+</html>
+
+<!DOCTYPE html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ // base url for the model weights
+ const MODEL_BASEURL =
+ "https://huggingface.co/lmz/candle-sam/resolve/main/";
+
+ // available SAM checkpoints
+ const MODELS = {
+ sam_mobile_tiny: {
+ url: "mobile_sam-tiny-vitt.safetensors",
+ },
+ sam_base: {
+ url: "sam_vit_b_01ec64.safetensors",
+ },
+ };
+ const samWorker = new Worker("./samWorker.js", { type: "module" });
+
+ async function segmentPoints(
+ modelURL, // URL to the weights file
+ modelID, // model ID
+ imageURL, // URL to the image file
+ points // {x, y} points to prompt image
+ ) {
+ return new Promise((resolve, reject) => {
+ function messageHandler(event) {
+ console.log(event.data);
+ if ("status" in event.data) {
+ updateStatus(event.data);
+ }
+ if ("error" in event.data) {
+ samWorker.removeEventListener("message", messageHandler);
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete-embedding") {
+ samWorker.removeEventListener("message", messageHandler);
+ resolve();
+ }
+ if (event.data.status === "complete") {
+ samWorker.removeEventListener("message", messageHandler);
+ resolve(event.data.output);
+ }
+ }
+ samWorker.addEventListener("message", messageHandler);
+ samWorker.postMessage({
+ modelURL,
+ modelID,
+ imageURL,
+ points,
+ });
+ });
+ }
+ function updateStatus(statusMessage) {
+ statusOutput.innerText = statusMessage.message;
+ }
+
+ const clearBtn = document.querySelector("#clear-btn");
+ const canvas = document.querySelector("#canvas");
+ const mask = document.querySelector("#mask");
+ const ctxCanvas = canvas.getContext("2d");
+ const ctxMask = mask.getContext("2d");
+ const fileUpload = document.querySelector("#file-upload");
+ const dropArea = document.querySelector("#drop-area");
+ const dropButtons = document.querySelector("#drop-buttons");
+ const imagesExamples = document.querySelector("#image-select");
+ const modelSelection = document.querySelector("#model");
+ const statusOutput = document.querySelector("#output-status");
+
+ //add event listener to file input
+ fileUpload.addEventListener("change", (e) => {
+ const target = e.target;
+ if (target.files.length > 0) {
+ const href = URL.createObjectURL(target.files[0]);
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ }
+ });
+ // add event listener to drop-area
+ dropArea.addEventListener("dragenter", (e) => {
+ e.preventDefault();
+ dropArea.classList.add("border-blue-700");
+ });
+ dropArea.addEventListener("dragleave", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ });
+ dropArea.addEventListener("dragover", (e) => {
+ e.preventDefault();
+ });
+ dropArea.addEventListener("drop", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ const url = e.dataTransfer.getData("text/uri-list");
+ const files = e.dataTransfer.files;
+
+ if (files.length > 0) {
+ const href = URL.createObjectURL(files[0]);
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ } else if (url) {
+ cleanImageCanvas();
+ drawImageCanvas(url);
+ setImageEmbeddings(url);
+ }
+ });
+
+ let hasImage = false;
+ let isSegmenting = false;
+ let isEmbedding = false;
+ let currentImageURL = "";
+ //add event listener to image examples
+ imagesExamples.addEventListener("click", (e) => {
+ if (isEmbedding || isSegmenting) {
+ return;
+ }
+ const target = e.target;
+ if (target.nodeName === "IMG") {
+ const href = target.src;
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ }
+ });
+ //add event listener to clear button
+ clearBtn.addEventListener("click", () => {
+ cleanImageCanvas();
+ });
+ //add click event to canvas
+ canvas.addEventListener("click", async (event) => {
+ if (!hasImage || isEmbedding || isSegmenting) {
+ return;
+ }
+ const targetBox = event.target.getBoundingClientRect();
+ const x = (event.clientX - targetBox.left) / targetBox.width;
+ const y = (event.clientY - targetBox.top) / targetBox.height;
+ isSegmenting = true;
+ const { maskURL } = await getSegmentationMask({ x, y });
+ isSegmenting = false;
+ drawMask(maskURL);
+ });
+
+ async function getSegmentationMask(points) {
+ const modelID = modelSelection.value;
+ const modelURL = MODEL_BASEURL + MODELS[modelID].url;
+ const imageURL = currentImageURL;
+ const { maskURL } = await segmentPoints(
+ modelURL,
+ modelID,
+ imageURL,
+ points
+ );
+ return { maskURL };
+ }
+ async function setImageEmbeddings(imageURL) {
+ if (isEmbedding) {
+ return;
+ }
+ canvas.classList.remove("cursor-pointer");
+ canvas.classList.add("cursor-wait");
+ clearBtn.disabled = true;
+ const modelID = modelSelection.value;
+ const modelURL = MODEL_BASEURL + MODELS[modelID].url;
+ isEmbedding = true;
+ await segmentPoints(modelURL, modelID, imageURL);
+ canvas.classList.remove("cursor-wait");
+ canvas.classList.add("cursor-pointer");
+ clearBtn.disabled = false;
+ isEmbedding = false;
+ currentImageURL = imageURL;
+ }
+
+ function cleanImageCanvas() {
+ ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
+ ctxMask.clearRect(0, 0, canvas.width, canvas.height);
+ hasImage = false;
+ isEmbedding = false;
+ isSegmenting = false;
+ currentImageURL = "";
+ clearBtn.classList.add("invisible");
+ canvas.parentElement.style.height = "auto";
+ dropButtons.classList.remove("invisible");
+ }
+ function drawMask(maskURL) {
+ if (!maskURL) {
+ throw new Error("No mask URL provided");
+ }
+
+ const img = new Image();
+ img.crossOrigin = "anonymous";
+
+ img.onload = () => {
+ mask.width = canvas.width;
+ mask.height = canvas.height;
+ ctxMask.drawImage(canvas, 0, 0);
+ ctxMask.globalCompositeOperation = "source-atop";
+ ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
+ ctxMask.fillRect(0, 0, canvas.width, canvas.height);
+ ctxMask.globalCompositeOperation = "destination-in";
+ ctxMask.drawImage(img, 0, 0);
+ };
+ img.src = maskURL;
+ }
+ function drawImageCanvas(imgURL) {
+ if (!imgURL) {
+ throw new Error("No image URL provided");
+ }
+
+ ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
+
+ const img = new Image();
+ img.crossOrigin = "anonymous";
+
+ img.onload = () => {
+ canvas.width = img.width;
+ canvas.height = img.height;
+ ctxCanvas.drawImage(img, 0, 0);
+ canvas.parentElement.style.height = canvas.offsetHeight + "px";
+ hasImage = true;
+ clearBtn.classList.remove("invisible");
+ dropButtons.classList.add("invisible");
+ };
+ img.src = imgURL;
+ }
+
+ const observer = new ResizeObserver((entries) => {
+ for (let entry of entries) {
+ if (entry.target === canvas) {
+ canvas.parentElement.style.height = canvas.offsetHeight + "px";
+ }
+ }
+ });
+ observer.observe(canvas);
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]">🕯️</span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle Segment Anything</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ Zero-shot image segmentation with
+ <a
+ href="https://segment-anything.com"
+ class="underline hover:text-blue-500 hover:no-underline"
+ target="_blank"
+ >Segment Anything Model (SAM)</a
+ >
+ and
+ <a
+ href="https://github.com/ChaoningZhang/MobileSAM"
+ class="underline hover:text-blue-500 hover:no-underline"
+ target="_blank"
+ >MobileSAM </a
+ >. It runs in the browser with a WASM runtime built with
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ </p>
+ </div>
+ <div>
+ <label for="model" class="font-medium">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"
+ >
+ <option value="sam_mobile_tiny" selected>
+ Mobile SAM Tiny (40.6 MB)
+ </option>
+ <option value="sam_base">SAM Base (375 MB)</option>
+ </select>
+ </div>
+ <div>
+ <p class="text-xs italic max-w-lg">
+ <b>Note:</b>
+ The model's first run may take a few seconds as it loads and caches
+ the model in the browser, and then creates the image embeddings. Any
+ subsequent clicks on points will be significantly faster.
+ </p>
+ </div>
+ <div class="relative max-w-lg">
+ <div class="flex justify-between items-center">
+ <div class="px-2 rounded-md inline text-xs">
+ <span id="output-status" class="m-auto font-light"></span>
+ </div>
+ <button
+ id="clear-btn"
+ class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible"
+ >
+ <svg
+ class=""
+ xmlns="http://www.w3.org/2000/svg"
+ viewBox="0 0 13 12"
+ height="1em"
+ >
+ <path
+ d="M1.6.7 12 11.1M12 .7 1.6 11.1"
+ stroke="#2E3036"
+ stroke-width="2"
+ />
+ </svg>
+ Clear image
+ </button>
+ </div>
+ <div
+ id="drop-area"
+ class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden"
+ >
+ <div
+ id="drop-buttons"
+ class="flex flex-col items-center justify-center space-y-1 text-center relative z-10"
+ >
+ <svg
+ width="25"
+ height="25"
+ viewBox="0 0 25 25"
+ fill="none"
+ xmlns="http://www.w3.org/2000/svg"
+ >
+ <path
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
+ fill="#000"
+ />
+ </svg>
+ <div class="flex text-sm text-gray-600">
+ <label
+ for="file-upload"
+ class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
+ >
+ <span>Drag and drop your image here</span>
+ <span class="block text-xs">or</span>
+ <span class="block text-xs">Click to upload</span>
+ </label>
+ </div>
+ <input
+ id="file-upload"
+ name="file-upload"
+ type="file"
+ class="sr-only"
+ />
+ </div>
+ <canvas id="canvas" class="absolute w-full"></canvas>
+ <canvas
+ id="mask"
+ class="pointer-events-none absolute w-full"
+ ></canvas>
+ </div>
+ <div class="text-right py-2">
+ <button
+ id="share-btn"
+ class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible"
+ >
+ <img
+ src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
+ />
+ </button>
+ </div>
+ </div>
+ <div>
+ <div
+ class="flex gap-3 items-center overflow-x-scroll"
+ id="image-select"
+ >
+ <h3 class="font-medium">Examples:</h3>
+
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js
new file mode 100644
index 00000000..c1a152ef
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/samWorker.js
@@ -0,0 +1,155 @@
+//load the candle SAM Model wasm module
+import init, { Model } from "./build/m.js";
+
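+// Fetch a URL as a Uint8Array; model weights are cached with the Cache API, while images (cacheModel = false) bypass it.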
+async function fetchArrayBuffer(url, cacheModel = true) {
+ if (!cacheModel)
+ return new Uint8Array(await (await fetch(url)).arrayBuffer());
+ const cacheName = "sam-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class SAMModel {
+ static instance = {};
+ // keep current image embeddings state
+ static imageArrayHash = {};
+ // Add a new property to hold the current modelID
+ static currentModelID = null;
+
+ static async getInstance(modelURL, modelID) {
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({
+ status: "loading",
+ message: `Loading Model ${modelID}`,
+ });
+ const weightsArrayU8 = await fetchArrayBuffer(modelURL);
+ this.instance[modelID] = new Model(
+ weightsArrayU8,
+ /tiny|mobile/.test(modelID)
+ );
+ } else {
+ self.postMessage({ status: "loading", message: "Model Already Loaded" });
+ }
+ // Set the current modelID to the modelID that was passed in
+ this.currentModelID = modelID;
+ return this.instance[modelID];
+ }
+
+ // Remove the modelID parameter from setImageEmbeddings
+ static setImageEmbeddings(imageArrayU8) {
+ // check if image embeddings are already set for this image and model
+ const imageArrayHash = this.getSimpleHash(imageArrayU8);
+ if (
+ this.imageArrayHash[this.currentModelID] === imageArrayHash &&
+ this.instance[this.currentModelID]
+ ) {
+ self.postMessage({
+ status: "embedding",
+ message: "Embeddings Already Set",
+ });
+ return;
+ }
+ this.imageArrayHash[this.currentModelID] = imageArrayHash;
+ this.instance[this.currentModelID].set_image_embeddings(imageArrayU8);
+ self.postMessage({ status: "embedding", message: "Embeddings Set" });
+ }
+
+ static getSimpleHash(imageArrayU8) {
+ // get simple hash of imageArrayU8
+ let imageArrayHash = 0;
+ for (let i = 0; i < imageArrayU8.length; i += 100) {
+ imageArrayHash ^= imageArrayU8[i];
+ }
+ return imageArrayHash.toString(16);
+ }
+}
+
+async function createImageCanvas(
+ { mask_shape, mask_data }, // mask
+ { original_width, original_height, width, height } // original image
+) {
+ const [_, __, shape_width, shape_height] = mask_shape;
+ const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask
+ const maskCtx = maskCanvas.getContext("2d");
+ const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size
+ const ctx = canvas.getContext("2d");
+
+ const imageData = maskCtx.createImageData(
+ maskCanvas.width,
+ maskCanvas.height
+ );
+ const data = imageData.data;
+
+ for (let p = 0; p < data.length; p += 4) {
+ data[p] = 0;
+ data[p + 1] = 0;
+ data[p + 2] = 0;
+ data[p + 3] = mask_data[p / 4] * 255;
+ }
+ maskCtx.putImageData(imageData, 0, 0);
+
+ let sx, sy;
+ if (original_height < original_width) {
+ sy = original_height / original_width;
+ sx = 1;
+ } else {
+ sy = 1;
+ sx = original_width / original_height;
+ }
+ ctx.drawImage(
+ maskCanvas,
+ 0,
+ 0,
+ maskCanvas.width * sx,
+ maskCanvas.height * sy,
+ 0,
+ 0,
+ original_width,
+ original_height
+ );
+
+ const blob = await canvas.convertToBlob();
+ return URL.createObjectURL(blob);
+}
+
+self.addEventListener("message", async (event) => {
+ const { modelURL, modelID, imageURL, points } = event.data;
+ try {
+ self.postMessage({ status: "loading", message: "Starting SAM" });
+ const sam = await SAMModel.getInstance(modelURL, modelID);
+
+ self.postMessage({ status: "loading", message: "Loading Image" });
+ const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
+
+ self.postMessage({ status: "embedding", message: "Creating Embeddings" });
+ SAMModel.setImageEmbeddings(imageArrayU8);
+ if (!points) {
+ // no points only do the embeddings
+ self.postMessage({
+ status: "complete-embedding",
+ message: "Embeddings Complete",
+ });
+ return;
+ }
+
+ self.postMessage({ status: "segmenting", message: "Segmenting" });
+ const { mask, image } = sam.mask_for_point(points.x, points.y);
+ const maskDataURL = await createImageCanvas(mask, image);
+ // Send the segment back to the main thread as JSON
+ self.postMessage({
+ status: "complete",
+ message: "Segmentation Complete",
+ output: { maskURL: maskDataURL },
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
new file mode 100644
index 00000000..5140b979
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -0,0 +1,140 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_wasm_example_sam as sam;
+use wasm_bindgen::prelude::*;
+
+#[allow(unused)]
+struct Embeddings {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+ data: Tensor,
+}
+
+#[wasm_bindgen]
+pub struct Model {
+ sam: sam::Sam,
+ embeddings: Option<Embeddings>,
+}
+
+#[wasm_bindgen]
+impl Model {
+ #[wasm_bindgen(constructor)]
+ pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
+ console_error_panic_hook::set_once();
+ let dev = &Device::Cpu;
+ let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+ let sam = if use_tiny {
+ sam::Sam::new_tiny(vb)? // tiny vit_t
+ } else {
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ };
+ Ok(Self {
+ sam,
+ embeddings: None,
+ })
+ }
+
+ pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
+ sam::console_log!("image data: {}", image_data.len());
+ let image_data = std::io::Cursor::new(image_data);
+ let image = image::io::Reader::new(image_data)
+ .with_guessed_format()?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let (original_height, original_width) = (image.height(), image.width());
+ let (height, width) = (original_height, original_width);
+ let resize_longest = sam::IMAGE_SIZE as u32;
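+ // Resize so the longest side equals IMAGE_SIZE, preserving the aspect ratio.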
+ let (height, width) = if height < width {
+ let h = (resize_longest * height) / width;
+ (h, resize_longest)
+ } else {
+ let w = (resize_longest * width) / height;
+ (resize_longest, w)
+ };
+ let image_t = {
+ let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
+ let data = img.to_rgb8().into_raw();
+ Tensor::from_vec(
+ data,
+ (img.height() as usize, img.width() as usize, 3),
+ &Device::Cpu,
+ )?
+ .permute((2, 0, 1))?
+ };
+ let data = self.sam.embeddings(&image_t)?;
+ self.embeddings = Some(Embeddings {
+ original_width,
+ original_height,
+ width,
+ height,
+ data,
+ });
+ Ok(())
+ }
+
+ // x and y have to be between 0 and 1
+ pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> {
+ if !(0. ..=1.).contains(&x) {
+ Err(JsError::new(&format!(
+ "x has to be between 0 and 1, got {x}"
+ )))?
+ }
+ if !(0. ..=1.).contains(&y) {
+ Err(JsError::new(&format!(
+ "y has to be between 0 and 1, got {y}"
+ )))?
+ }
+ let embeddings = match &self.embeddings {
+ None => Err(JsError::new("image embeddings have not been set"))?,
+ Some(embeddings) => embeddings,
+ };
+ let (mask, iou_predictions) = self.sam.forward_for_embeddings(
+ &embeddings.data,
+ embeddings.height as usize,
+ embeddings.width as usize,
+ Some((x, y)),
+ false,
+ )?;
+ let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
+ let mask_shape = mask.dims().to_vec();
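+ // Binarize the mask by thresholding the logits at 0.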
+ let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
+ let mask = Mask {
+ iou,
+ mask_shape,
+ mask_data,
+ };
+ let image = Image {
+ original_width: embeddings.original_width,
+ original_height: embeddings.original_height,
+ width: embeddings.width,
+ height: embeddings.height,
+ };
+ Ok(serde_wasm_bindgen::to_value(&MaskImage { mask, image })?)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Mask {
+ iou: f32,
+ mask_shape: Vec<usize>,
+ mask_data: Vec<u8>,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Image {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct MaskImage {
+ mask: Mask,
+ image: Image,
+}
+
+fn main() {
+ console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs
new file mode 100644
index 00000000..0f4f96fd
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/lib.rs
@@ -0,0 +1,19 @@
+use candle_transformers::models::segment_anything::sam;
+use wasm_bindgen::prelude::*;
+
+pub use sam::{Sam, IMAGE_SIZE};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above during
+ // `bare_bones`
+ ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 47e7e094..8f1df531 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
diff --git a/candle-wasm-examples/whisper/lib-example.html b/candle-wasm-examples/whisper/lib-example.html
index a8c49785..3cfd87a7 100644
--- a/candle-wasm-examples/whisper/lib-example.html
+++ b/candle-wasm-examples/whisper/lib-example.html
@@ -6,7 +6,7 @@
<body></body>
</html>
-<!doctype html>
+<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
@@ -51,18 +51,21 @@
mel_filtersURL,
audioURL,
});
- whisperWorker.addEventListener("message", (event) => {
+ function messageHandler(event) {
console.log(event.data);
if ("status" in event.data) {
updateStatus(event.data);
}
if ("error" in event.data) {
+ whisperWorker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
+ whisperWorker.removeEventListener("message", messageHandler);
resolve(event.data);
}
- });
+ }
+ whisperWorker.addEventListener("message", messageHandler);
});
}
@@ -141,7 +144,9 @@
const { output } = result;
const text = output.map((segment) => segment.dr.text).join(" ");
console.log(text);
- document.getElementById("output").textContent = text;
+ document.querySelector("#output-status").hidden = true;
+ document.querySelector("#output-generation").hidden = false;
+ document.querySelector("#output-generation").textContent = text;
})
.catch((error) => {
console.error(error);
@@ -295,18 +300,21 @@
<button
id="detect"
disabled
- class="bg-orange-900 hover:bg-orange-800 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:cursor-not-allowed"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
>
Transcribe Audio
</button>
</div>
<div>
<h3 class="font-medium">Transcription:</h3>
-
<div
- id="output"
- class="min-h-[100px] bg-slate-500 text-white p-4 rounded-md"
- ></div>
+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
+ >
+ <p hidden id="output-generation" class="grid-rows-2"></p>
+ <span id="output-status" class="m-auto font-light"
+ >No transcription results yet</span
+ >
+ </div>
</div>
</main>
</body>
diff --git a/candle-wasm-examples/whisper/whisperWorker.js b/candle-wasm-examples/whisper/whisperWorker.js
index 2598adde..d2ad8e0b 100644
--- a/candle-wasm-examples/whisper/whisperWorker.js
+++ b/candle-wasm-examples/whisper/whisperWorker.js
@@ -2,16 +2,17 @@
import init, { Decoder } from "./build/m.js";
async function fetchArrayBuffer(url) {
- const res = await fetch(url, {
- cache: "force-cache",
- headers: {
- "Cache-Control": "public, max-age=31536000",
- },
- });
- const data = await res.arrayBuffer();
- return new Uint8Array(data);
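+ // Serve from the Cache API when possible so model files are downloaded only once.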
+ const cacheName = "whisper-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
}
-
class Whisper {
static instance = {};
// Retrieve the Whisper model. When called for the first time,
diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml
index b4daf6e6..71ef8049 100644
--- a/candle-wasm-examples/yolo/Cargo.toml
+++ b/candle-wasm-examples/yolo/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle = { path = "../../candle-core", version = "0.2.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.3" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-wasm-examples/yolo/lib-example.html b/candle-wasm-examples/yolo/lib-example.html
index bab2ec13..d9f18975 100644
--- a/candle-wasm-examples/yolo/lib-example.html
+++ b/candle-wasm-examples/yolo/lib-example.html
@@ -6,7 +6,7 @@
<body></body>
</html>
-<!doctype html>
+<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
@@ -145,6 +145,10 @@
}
});
+ document.querySelector("#clear-btn").addEventListener("click", () => {
+ drawImageCanvas();
+ });
+
function drawImageCanvas(imgURL) {
const canvas = document.querySelector("#canvas");
const canvasResult = document.querySelector("#canvas-result");
@@ -153,21 +157,28 @@
.clearRect(0, 0, canvas.width, canvas.height);
const ctx = canvas.getContext("2d");
ctx.clearRect(0, 0, canvas.width, canvas.height);
- document.querySelector("#share-btn").hidden = true;
+ document.querySelector("#share-btn").classList.add("invisible");
+ document.querySelector("#clear-btn").classList.add("invisible");
+ document.querySelector("#detect").disabled = true;
+ hasImage = false;
+ canvas.parentElement.style.height = "auto";
- const img = new Image();
- img.crossOrigin = "anonymous";
+ if (imgURL && imgURL !== "") {
+ const img = new Image();
+ img.crossOrigin = "anonymous";
- img.onload = () => {
- canvas.width = img.width;
- canvas.height = img.height;
- ctx.drawImage(img, 0, 0);
+ img.onload = () => {
+ canvas.width = img.width;
+ canvas.height = img.height;
+ ctx.drawImage(img, 0, 0);
- canvas.parentElement.style.height = canvas.offsetHeight + "px";
- hasImage = true;
- document.querySelector("#detect").disabled = false;
- };
- img.src = imgURL;
+ canvas.parentElement.style.height = canvas.offsetHeight + "px";
+ hasImage = true;
+ document.querySelector("#detect").disabled = false;
+ document.querySelector("#clear-btn").classList.remove("invisible");
+ };
+ img.src = imgURL;
+ }
}
async function classifyImage(
@@ -188,17 +199,21 @@
confidence,
iou_threshold,
});
- yoloWorker.addEventListener("message", (event) => {
+ function handleMessage(event) {
+ console.log("message", event.data);
if ("status" in event.data) {
updateStatus(event.data.status);
}
if ("error" in event.data) {
+ yoloWorker.removeEventListener("message", handleMessage);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
+ yoloWorker.removeEventListener("message", handleMessage);
resolve(event.data);
}
- });
+ }
+ yoloWorker.addEventListener("message", handleMessage);
});
}
// add event listener to detect button
@@ -310,7 +325,7 @@
button.classList.add("bg-blue-950");
button.classList.remove("bg-blue-700");
button.textContent = "Predict";
- document.querySelector("#share-btn").hidden = false;
+ document.querySelector("#share-btn").classList.remove("invisible");
}
}
document.querySelector("#share-btn").addEventListener("click", () => {
@@ -372,8 +387,37 @@
<option value="yolov8x_pose">yolov8x_pose (139 MB)</option>
</select>
</div>
+ <div>
+ <button
+ id="detect"
+ disabled
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
+ >
+ Predict
+ </button>
+ </div>
<!-- drag and drop area -->
- <div class="relative">
+ <div class="relative max-w-lg">
+ <div class="py-1">
+ <button
+ id="clear-btn"
+ class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center ml-auto invisible"
+ >
+ <svg
+ class=""
+ xmlns="http://www.w3.org/2000/svg"
+ viewBox="0 0 13 12"
+ height="1em"
+ >
+ <path
+ d="M1.6.7 12 11.1M12 .7 1.6 11.1"
+ stroke="#2E3036"
+ stroke-width="2"
+ />
+ </svg>
+ Clear image
+ </button>
+ </div>
<div
id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative aspect-video w-full overflow-hidden"
@@ -422,8 +466,7 @@
<div class="text-right py-2">
<button
id="share-btn"
- hidden
- class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50"
+ class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible"
>
<img
src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
@@ -432,7 +475,10 @@
</div>
</div>
<div>
- <div class="flex gap-3 items-center" id="image-select">
+ <div
+ class="flex gap-3 items-center overflow-x-scroll"
+ id="image-select"
+ >
<h3 class="font-medium">Examples:</h3>
<img
@@ -489,15 +535,6 @@
>
</div>
</div>
- <div>
- <button
- id="detect"
- disabled
- class="bg-blue-950 hover:bg-blue-700 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:hover:bg-blue-950"
- >
- Predict
- </button>
- </div>
</main>
</body>
</html>
diff --git a/candle-wasm-examples/yolo/yoloWorker.js b/candle-wasm-examples/yolo/yoloWorker.js
index 93097372..8b5ef8b9 100644
--- a/candle-wasm-examples/yolo/yoloWorker.js
+++ b/candle-wasm-examples/yolo/yoloWorker.js
@@ -1,6 +1,19 @@
//load the candle yolo wasm module
import init, { Model, ModelPose } from "./build/m.js";
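+// Fetch a URL as a Uint8Array, caching the bytes with the Cache API so model weights are downloaded only once.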
+async function fetchArrayBuffer(url) {
+ const cacheName = "yolo-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+
class Yolo {
static instance = {};
// Retrieve the YOLO model. When called for the first time,
@@ -11,9 +24,7 @@ class Yolo {
await init();
self.postMessage({ status: `loading model ${modelID}:${modelSize}` });
- const modelRes = await fetch(modelURL);
- const yoloArrayBuffer = await modelRes.arrayBuffer();
- const weightsArrayU8 = new Uint8Array(yoloArrayBuffer);
+ const weightsArrayU8 = await fetchArrayBuffer(modelURL);
if (/pose/.test(modelID)) {
// if pose model, use ModelPose
this.instance[modelID] = new ModelPose(weightsArrayU8, modelSize);