summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--Cargo.toml5
-rw-r--r--LICENSE-APACHE201
-rw-r--r--LICENSE-MIT23
-rw-r--r--Makefile4
-rw-r--r--README.md7
-rw-r--r--candle-book/src/guide/hello_world.md6
-rw-r--r--candle-book/src/guide/installation.md4
-rw-r--r--candle-core/Cargo.toml2
-rw-r--r--candle-core/examples/basics.rs29
-rw-r--r--candle-core/src/accelerate.rs111
-rw-r--r--candle-core/src/backend.rs11
-rw-r--r--candle-core/src/backprop.rs17
-rw-r--r--candle-core/src/conv.rs29
-rw-r--r--candle-core/src/cpu_backend.rs288
-rw-r--r--candle-core/src/cuda_backend.rs18
-rw-r--r--candle-core/src/device.rs7
-rw-r--r--candle-core/src/dtype.rs6
-rw-r--r--candle-core/src/dummy_cuda_backend.rs18
-rw-r--r--candle-core/src/ggml.rs582
-rw-r--r--candle-core/src/lib.rs3
-rw-r--r--candle-core/src/op.rs18
-rw-r--r--candle-core/src/storage.rs58
-rw-r--r--candle-core/src/tensor.rs123
-rw-r--r--candle-core/src/utils.rs14
-rw-r--r--candle-core/tests/conv_tests.rs184
-rw-r--r--candle-core/tests/pool_tests.rs17
-rw-r--r--candle-datasets/Cargo.toml20
-rw-r--r--candle-datasets/src/batcher.rs (renamed from candle-nn/src/dataset.rs)0
-rw-r--r--candle-datasets/src/lib.rs6
-rw-r--r--candle-datasets/src/nlp/mod.rs1
-rw-r--r--candle-datasets/src/nlp/tinystories.rs122
-rw-r--r--candle-datasets/src/vision/cifar.rs (renamed from candle-nn/src/vision/cifar.rs)0
-rw-r--r--candle-datasets/src/vision/mnist.rs (renamed from candle-nn/src/vision/mnist.rs)0
-rw-r--r--candle-datasets/src/vision/mod.rs (renamed from candle-nn/src/vision/mod.rs)0
-rw-r--r--candle-examples/Cargo.toml8
-rw-r--r--candle-examples/examples/llama/main.rs17
-rw-r--r--candle-examples/examples/llama/model.rs43
-rw-r--r--candle-examples/examples/llama2-c/main.rs28
-rw-r--r--candle-examples/examples/llama2-c/training.rs124
-rw-r--r--candle-examples/examples/llama2-c/weights.rs25
-rw-r--r--candle-examples/examples/mnist-training/main.rs4
-rw-r--r--candle-examples/examples/stable-diffusion/attention.rs445
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs305
-rw-r--r--candle-examples/examples/stable-diffusion/ddim.rs181
-rw-r--r--candle-examples/examples/stable-diffusion/embeddings.rs65
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs273
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs129
-rw-r--r--candle-examples/examples/stable-diffusion/schedulers.rs45
-rw-r--r--candle-examples/examples/stable-diffusion/stable_diffusion.rs212
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs383
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs808
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs31
-rw-r--r--candle-examples/examples/stable-diffusion/vae.rs378
-rw-r--r--candle-flash-attn/Cargo.toml2
-rw-r--r--candle-kernels/Cargo.toml2
-rw-r--r--candle-kernels/src/unary.cu4
-rw-r--r--candle-nn/Cargo.toml2
-rw-r--r--candle-nn/src/conv.rs89
-rw-r--r--candle-nn/src/group_norm.rs83
-rw-r--r--candle-nn/src/lib.rs6
-rw-r--r--candle-nn/src/ops.rs10
-rw-r--r--candle-nn/tests/group_norm.rs103
-rw-r--r--candle-pyo3/Cargo.toml5
-rw-r--r--candle-pyo3/README.md10
-rw-r--r--candle-pyo3/build.rs3
-rw-r--r--candle-pyo3/src/lib.rs123
-rw-r--r--candle-pyo3/test.py20
-rw-r--r--candle-transformers/Cargo.toml2
69 files changed, 5684 insertions, 219 deletions
diff --git a/.gitignore b/.gitignore
index a7006d50..1585ed11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ Cargo.lock
perf.data
flamegraph.svg
+*.dylib
*.so
*.swp
trace-*.json
diff --git a/Cargo.toml b/Cargo.toml
index 301451a0..0a231df8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,7 @@
[workspace]
members = [
"candle-core",
+ "candle-datasets",
"candle-examples",
"candle-nn",
"candle-pyo3",
@@ -20,9 +21,10 @@ description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
[workspace.dependencies]
+accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
@@ -31,6 +33,7 @@ cudarc = { version = "0.9.13", features = ["f16"] }
gemm = { version = "0.15.5", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
+image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
diff --git a/LICENSE-APACHE b/LICENSE-APACHE
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/LICENSE-APACHE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/LICENSE-MIT b/LICENSE-MIT
new file mode 100644
index 00000000..31aa7938
--- /dev/null
+++ b/LICENSE-MIT
@@ -0,0 +1,23 @@
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/Makefile b/Makefile
index cb472d80..eba92821 100644
--- a/Makefile
+++ b/Makefile
@@ -9,4 +9,8 @@ clean:
test:
cargo test
+pyo3-test:
+ cargo build --profile=release-with-debug --package candle-pyo3
+ python3 candle-pyo3/test.py
+
all: test
diff --git a/README.md b/README.md
index 3232fe7a..c8622b88 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ And then browse to
- Distributed computing using NCCL.
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
- Embed user-defined ops/kernels, such as [flash-attention
- v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+ v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
<!--- ANCHOR_END: features --->
@@ -98,8 +98,9 @@ Cheatsheet:
- [candle-nn](./candle-nn/): Facilities to build real models
- [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
- [candle-kernels](./candle-kernels/): CUDA custom kernels
-
-
+- [candle-datasets](./candle-datasets/): Datasets and data loaders.
+- [candle-transformers](./candle-transformers): Transformer related utilities.
+- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
## FAQ
diff --git a/candle-book/src/guide/hello_world.md b/candle-book/src/guide/hello_world.md
index a9767d5f..5b32181d 100644
--- a/candle-book/src/guide/hello_world.md
+++ b/candle-book/src/guide/hello_world.md
@@ -128,17 +128,17 @@ fn main() -> Result<()> {
```
Now it works, it is a great way to create your own layers.
-But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
+But most of the classical layers are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
## Using `candle_nn`.
-For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
+For instance [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs) is already there.
This Linear is coded with PyTorch layout in mind, to reuse better existing models out there, so it uses the transpose of the weights and not the weights directly.
So instead we can simplify our example:
```bash
-cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
+cargo add --git https://github.com/huggingface/candle.git candle-nn
```
And rewrite our examples using it
diff --git a/candle-book/src/guide/installation.md b/candle-book/src/guide/installation.md
index c909a5df..d2086e0c 100644
--- a/candle-book/src/guide/installation.md
+++ b/candle-book/src/guide/installation.md
@@ -5,13 +5,13 @@ Start by creating a new app:
```bash
cargo new myapp
cd myapp
-cargo add --git https://github.com/LaurentMazare/candle.git candle
+cargo add --git https://github.com/huggingface/candle.git candle-core
```
At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature
```bash
-cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
+cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
```
You can check everything works properly:
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 5a59aedc..af77a0e0 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -10,6 +10,7 @@ license.workspace = true
readme = "README.md"
[dependencies]
+accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
@@ -32,3 +33,4 @@ anyhow = { workspace = true }
default = []
cuda = ["dep:cudarc", "dep:candle-kernels"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
+accelerate = ["dep:libc", "dep:accelerate-src"]
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index d028db66..efce913a 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -1,29 +1,18 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
- let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
- let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
- let c = a.matmul(&b)?;
- println!("{a} {b} {c}");
-
- let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
- let t1 = Tensor::new(data, &Device::Cpu)?;
- let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
- let t2 = Tensor::new(data2, &Device::Cpu)?;
- assert_eq!(
- Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
- .t()?
- .to_vec2::<f32>()?,
- [
- [3.0, 1.0, 4.0, 1.0, 5.0],
- [2.0, 7.0, 1.0, 8.0, 2.0],
- [5.0, 5.0, 5.0, 5.0, 5.0],
- [2.0, 7.0, 1.0, 8.0, 2.0]
- ]
- );
+ let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
+ let start = std::time::Instant::now();
+ let res = inp.conv2d(&w, 0, 1);
+ println!("{:?}", start.elapsed());
+ println!("{res:?}");
Ok(())
}
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
new file mode 100644
index 00000000..8b0df5c1
--- /dev/null
+++ b/candle-core/src/accelerate.rs
@@ -0,0 +1,111 @@
+#![allow(dead_code)]
+use libc::{c_char, c_double, c_float, c_int};
+
+mod ffi {
+ use super::*;
+ extern "C" {
+ // It would be nice to be able to switch to the NEWLAPACK version of the function but this
+ // seems to trigger some link error. Available function names can be seen here:
+ // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
+ #[link_name = "sgemm_"]
+ pub fn sgemm_ffi(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_float,
+ a: *const c_float,
+ lda: *const c_int,
+ b: *const c_float,
+ ldb: *const c_int,
+ beta: *const c_float,
+ c: *mut c_float,
+ ldc: *const c_int,
+ );
+ #[link_name = "dgemm_"]
+ pub fn dgemm_ffi(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_double,
+ a: *const c_double,
+ lda: *const c_int,
+ b: *const c_double,
+ ldb: *const c_int,
+ beta: *const c_double,
+ c: *mut c_double,
+ ldc: *const c_int,
+ );
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn sgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f32,
+ a: &[f32],
+ lda: i32,
+ b: &[f32],
+ ldb: i32,
+ beta: f32,
+ c: &mut [f32],
+ ldc: i32,
+) {
+ ffi::sgemm_ffi(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn dgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f64,
+ a: &[f64],
+ lda: i32,
+ b: &[f64],
+ ldb: i32,
+ beta: f64,
+ c: &mut [f64],
+ ldc: i32,
+) {
+ ffi::dgemm_ffi(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 345db0e5..a8e5ac52 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,6 +37,17 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self>;
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
+
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f5cc8191..0eab508e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
+ | Op::Conv2D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@@ -81,6 +86,8 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest2D(node)
+ | Op::AvgPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -163,6 +170,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+ Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest2d",
+ })?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@@ -291,6 +303,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
+ Op::Unary(arg, UnaryOp::Recip) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let grad = (grad / arg.sqr()?)?;
+ *sum_grad = sum_grad.sub(&grad)?
+ }
&Op::Narrow(ref arg, dim, start_idx, len) => {
let arg_dims = arg.dims();
let left_pad = if start_idx == 0 {
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 4cf9d0ad..30799459 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -25,3 +25,32 @@ impl ParamsConv1D {
}
}
}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConv2D {
+ pub(crate) b_size: usize,
+ pub(crate) i_h: usize,
+ pub(crate) i_w: usize,
+ pub(crate) k_h: usize,
+ pub(crate) k_w: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv2D {
+ pub(crate) fn out_h(&self) -> usize {
+ let dilation = 1;
+ (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_w(&self) -> usize {
+ let dilation = 1;
+ (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+ }
+}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8563721c..10c6cc4a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -633,6 +633,84 @@ impl Map1 for Affine {
}
}
+struct AvgPool2D((usize, usize), (usize, usize));
+
+impl Map1 for AvgPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ let scale = 1f64 / (k_h * k_w) as f64;
+ let scale = T::from_f64(scale);
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut sum = T::zero();
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = k_h * h_idx + m;
+ let n = k_w * w_idx + n;
+ sum += src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ dst[h_idx * w_out + w_idx] = sum * scale;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct UpsampleNearest2D(usize, usize);
+
+impl Map1 for UpsampleNearest2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*h, 2*w?
+ let (dst_h, dst_w) = (self.0, self.1);
+ let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let src_index = layout.start_offset();
+ let scale_h = src_h as f64 / dst_h as f64;
+ let scale_w = src_w as f64 / dst_w as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
+ let src_h_idxs = (0..src_h)
+ .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
+ .collect::<Vec<_>>();
+ let src_w_idxs = (0..src_w)
+ .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_h * dst_w..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_h * dst_w..];
+ let src_index = src_index + c_idx * stride[1];
+ for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
+ for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
+ let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
+ dst[h_idx * dst_w + w_idx] = src[src_index]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@@ -921,7 +999,6 @@ impl<'a> Map2 for Conv1D<'a> {
(0, inp_stride) // This value never gets used anyway
};
let k_stride = k_l.stride();
- let k_over_2 = p.k_size / 2;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems];
@@ -935,18 +1012,16 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + dst_l;
let mut d = T::zero();
for offset in 0..p.k_size {
- let src_l_plus = p.stride * dst_l + offset;
- // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
- if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
- let src_l = src_l_plus - k_over_2;
- for src_c_idx in 0..p.c_in {
- let inp_idx =
- inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset * k_stride[2];
- d += inp[inp_idx] * k[k_idx]
- }
+ let src_l = (p.stride * dst_l + offset)
+ .saturating_sub(p.padding)
+ .min(p.l_in - 1);
+ for src_c_idx in 0..p.c_in {
+ let inp_idx =
+ inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset * k_stride[2];
+ d += inp[inp_idx] * k[k_idx]
}
}
dst[dst_idx] = d
@@ -957,6 +1032,65 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
+struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
+
+impl<'a> Map2 for Conv2D<'a> {
+ const OP: &'static str = "conv2d";
+ fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
+ &self,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+ ) -> Result<Vec<T>> {
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let inp_stride = inp_l.stride();
+ let k = &k[k_l.start_offset()..];
+ let k_stride = k_l.stride();
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+ for b_idx in 0..p.b_size {
+ let inp_idx = b_idx * inp_stride[0];
+ let dst_idx = b_idx * p.c_out * out_h * out_w;
+ for dst_c_idx in 0..p.c_out {
+ let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
+ for dst_h in 0..out_h {
+ let dst_idx = dst_idx + dst_h * out_w;
+ for dst_w in 0..out_w {
+ let dst_idx = dst_idx + dst_w;
+ let mut d = T::zero();
+ for offset_h in 0..p.k_h {
+ let src_h = (p.stride * dst_h + offset_h)
+ .saturating_sub(p.padding)
+ .min(p.i_h - 1);
+ for offset_w in 0..p.k_w {
+ let src_w = (p.stride * dst_w + offset_w)
+ .saturating_sub(p.padding)
+ .min(p.i_w - 1);
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = inp_idx
+ + src_c_idx * inp_stride[1]
+ + src_h * inp_stride[2]
+ + src_w * inp_stride[3];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset_h * k_stride[2]
+ + offset_w * k_stride[3];
+ d += inp[inp_idx] * k[k_idx]
+ }
+ }
+ }
+ dst[dst_idx] = d
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct MatMul((usize, usize, usize, usize));
impl MatMul {
@@ -974,7 +1108,7 @@ impl MatMul {
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
- #[cfg(not(feature = "mkl"))]
+ #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
@@ -1053,6 +1187,109 @@ impl Map2 for MatMul {
Ok(dst)
}
+ #[cfg(feature = "accelerate")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+
+ let a_skip: usize = match lhs_stride[..rank - 2] {
+ [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+ [stride] => stride,
+ [] => m * k,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+ };
+ let b_skip: usize = match rhs_stride[..rank - 2] {
+ [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+ [stride] => stride,
+ [] => n * k,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+ };
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ let mut dst = vec![T::zero(); b * m * n];
+ match T::DTYPE {
+ DType::F16 => {
+ crate::bail!("the accelerate backend does not support f16 matmul")
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::sgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::dgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(dst)
+ }
+
#[cfg(feature = "mkl")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
@@ -1426,6 +1663,19 @@ impl BackendStorage for CpuStorage {
Affine(mul, add).map(self, layout)
}
+ fn avg_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ AvgPool2D(kernel_size, stride).map(self, layout)
+ }
+
+ fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
+ UpsampleNearest2D(h, w).map(self, layout)
+ }
+
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
@@ -1612,6 +1862,16 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l)
}
+ fn conv2d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ Conv2D(params).map(self, l, kernel, kernel_l)
+ }
+
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7b4b358d..727ea073 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1381,6 +1381,24 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+ todo!()
+ }
+
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 563d892b..65232839 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -101,6 +101,13 @@ impl Device {
}
}
+ pub fn is_cpu(&self) -> bool {
+ match self {
+ Self::Cpu => true,
+ Self::Cuda(_) => false,
+ }
+ }
+
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 0e906119..92929748 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -43,7 +43,7 @@ impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
- Self::U8 => 4,
+ Self::U8 => 1,
Self::U32 => 4,
Self::BF16 => 2,
Self::F16 => 2,
@@ -53,7 +53,9 @@ impl DType {
}
}
-pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
+pub trait WithDType:
+ Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
+{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 17d4a22e..ae4dd09f 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn conv2d(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@@ -119,6 +129,14 @@ impl crate::backend::BackendStorage for CudaStorage {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
}
impl crate::backend::BackendDevice for CudaDevice {
diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs
new file mode 100644
index 00000000..4a5d4fa0
--- /dev/null
+++ b/candle-core/src/ggml.rs
@@ -0,0 +1,582 @@
+//! Support for the GGML file format.
+
+use crate::{DType, Device, Result, Tensor};
+use byteorder::{LittleEndian, ReadBytesExt};
+use half::f16;
+
+// Default to QK_K 256 rather than 64.
+pub const QK_K: usize = 256;
+pub const K_SCALE_SIZE: usize = 12;
+
+pub const QK4_0: usize = 32;
+pub const QK4_1: usize = 32;
+pub const QK5_0: usize = 32;
+pub const QK5_1: usize = 32;
+pub const QK8_0: usize = 32;
+pub const QK8_1: usize = 32;
+
+#[repr(C)]
+struct BlockQ4_0 {
+ d: f16,
+ qs: [u8; QK4_0 / 2],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
+
+#[repr(C)]
+struct BlockQ4_1 {
+ d: f16,
+ m: f16,
+ qs: [u8; QK4_1 / 2],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
+
+#[repr(C)]
+struct BlockQ5_0 {
+ d: f16,
+ qh: [u8; 4],
+ qs: [u8; QK5_0 / 2],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
+
+#[repr(C)]
+struct BlockQ5_1 {
+ d: f16,
+ m: f16,
+ qh: [u8; 4],
+ qs: [u8; QK5_1 / 2],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
+
+#[repr(C)]
+struct BlockQ8_0 {
+ d: f16,
+ qs: [u8; QK8_0],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
+
+#[repr(C)]
+struct BlockQ8_1 {
+ d: f16,
+ s: f16,
+ qs: [u8; QK8_1],
+}
+const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
+
+#[repr(C)]
+struct BlockQ2K {
+ scales: [u8; QK_K / 16],
+ qs: [u8; QK_K / 4],
+ d: f16,
+ dmin: f16,
+}
+const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
+
+#[repr(C)]
+struct BlockQ3K {
+ hmask: [u8; QK_K / 8],
+ qs: [u8; QK_K / 4],
+ scales: [u8; 12],
+ d: f16,
+}
+const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
+
+// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
+#[repr(C)]
+struct BlockQ4K {
+ d: f16,
+ dmin: f16,
+ scales: [u8; K_SCALE_SIZE],
+ qs: [u8; QK_K / 2],
+}
+const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
+
+#[repr(C)]
+struct BlockQ5K {
+ d: f16,
+ dmin: f16,
+ scales: [u8; K_SCALE_SIZE],
+ qh: [u8; QK_K / 8],
+ qs: [u8; QK_K / 2],
+}
+const _: () =
+ assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
+
+#[repr(C)]
+struct BlockQ6K {
+ ql: [u8; QK_K / 2],
+ qh: [u8; QK_K / 4],
+ scales: [i8; QK_K / 16],
+ d: f16,
+}
+const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
+
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
+fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
+ let k = ys.len();
+ if k % QK_K != 0 {
+ crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
+ }
+ let mut ys_index = 0;
+ for x in xs {
+ let d = x.d.to_f32();
+ let min = x.dmin.to_f32();
+ let q = &x.qs;
+
+ let mut is = 0;
+ for n in (0..QK_K).step_by(128) {
+ // Step by 32 over q.
+ let q = &q[n / 4..];
+ let mut shift = 0;
+ for _j in 0..4 {
+ let sc = x.scales[is];
+ is += 1;
+ let dl = d * (sc & 0xF) as f32;
+ let ml = min * (sc >> 4) as f32;
+ for q in &q[..16] {
+ let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+
+ let sc = x.scales[is];
+ is += 1;
+ let dl = d * (sc & 0xF) as f32;
+ let ml = min * (sc >> 4) as f32;
+ for q in &q[16..32] {
+ let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+
+ shift += 2;
+ }
+ }
+ }
+ Ok(())
+}
+
+fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
+ if j < 4 {
+ let d = q[j] & 63;
+ let m = q[j + 4] & 63;
+ (d, m)
+ } else {
+ let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+ let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
+ (d, m)
+ }
+}
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
+fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
+ let k = ys.len();
+ if k % QK_K != 0 {
+ crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
+ }
+ let mut ys_index = 0;
+ for x in xs.iter() {
+ let d = x.d.to_f32();
+ let min = x.dmin.to_f32();
+ let q = &x.qs;
+ let mut is = 0;
+ for j in (0..QK_K).step_by(64) {
+ let q = &q[j / 2..j / 2 + 32];
+ let (sc, m) = get_scale_min_k4(is, &x.scales);
+ let d1 = d * sc as f32;
+ let m1 = min * m as f32;
+ let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
+ let d2 = d * sc as f32;
+ let m2 = min * m as f32;
+ for q in q {
+ let y = d1 * (q & 0xF) as f32 - m1;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+ for q in q {
+ let y = d2 * (q >> 4) as f32 - m2;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+ is += 2;
+ }
+ }
+ Ok(())
+}
+
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
+fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
+ todo!()
+}
+
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
+fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
+ let k = ys.len();
+ if k % QK_K != 0 {
+ crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
+ }
+ let mut ys_index = 0;
+ for x in xs.iter() {
+ let d = x.d.to_f32();
+ let min = x.dmin.to_f32();
+ let ql = &x.qs;
+ let qh = &x.qh;
+ let mut is = 0;
+ let mut u1 = 1;
+ let mut u2 = 2;
+ for j in (0..QK_K).step_by(64) {
+ let ql = &ql[j / 2..j / 2 + 32];
+ let (sc, m) = get_scale_min_k4(is, &x.scales);
+ let d1 = d * sc as f32;
+ let m1 = min * m as f32;
+ let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
+ let d2 = d * sc as f32;
+ let m2 = min * m as f32;
+ for (ql, qh) in ql.iter().zip(qh) {
+ let to_add = if qh & u1 != 0 { 16 } else { 1 };
+ let y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+ for (ql, qh) in ql.iter().zip(qh) {
+ let to_add = if qh & u2 != 0 { 16 } else { 1 };
+ let y = d2 * ((ql >> 4) + to_add) as f32 - m2;
+ ys[ys_index] = y;
+ ys_index += 1;
+ }
+ is += 2;
+ u1 <<= 2;
+ u2 <<= 2;
+ }
+ }
+ Ok(())
+}
+
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
+fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
+ let k = ys.len();
+ if k % QK_K != 0 {
+ crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
+ }
+ for x in xs.iter() {
+ let d = x.d.to_f32();
+ let ql = &x.ql;
+ let qh = &x.qh;
+ let sc = &x.scales;
+ for n in (0..QK_K).step_by(128) {
+ let idx = n / 128;
+ let ys = &mut ys[n..];
+ let sc = &sc[8 * idx..];
+ let ql = &ql[64 * idx..];
+ let qh = &qh[32 * idx..];
+ for l in 0..32 {
+ let is = l / 16;
+ let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
+ let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
+ let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
+ let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
+ ys[l] = d * sc[is] as f32 * q1 as f32;
+ ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
+ ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
+ ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
+ }
+ }
+ }
+ Ok(())
+}
+
+// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Magic {
+ Ggjt,
+ Ggla,
+ Ggmf,
+ Ggml,
+ Ggsn,
+}
+
+impl TryFrom<u32> for Magic {
+ type Error = crate::Error;
+ fn try_from(value: u32) -> Result<Self> {
+ let magic = match value {
+ 0x67676a74 => Self::Ggjt,
+ 0x67676c61 => Self::Ggla,
+ 0x67676d66 => Self::Ggmf,
+ 0x67676d6c => Self::Ggml,
+ 0x6767736e => Self::Ggsn,
+ _ => crate::bail!("unknown magic {value:08x}"),
+ };
+ Ok(magic)
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum VersionedMagic {
+ GgmlUnversioned,
+ GgmfV1,
+ GgjtV1,
+ GgjtV2,
+ GgjtV3,
+}
+
+impl VersionedMagic {
+ fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
+ let magic = reader.read_u32::<LittleEndian>()?;
+ let magic = Magic::try_from(magic)?;
+ if magic == Magic::Ggml {
+ return Ok(Self::GgmlUnversioned);
+ }
+ let version = reader.read_u32::<LittleEndian>()?;
+ let versioned_magic = match (magic, version) {
+ (Magic::Ggmf, 1) => Self::GgmfV1,
+ (Magic::Ggjt, 1) => Self::GgjtV1,
+ (Magic::Ggjt, 2) => Self::GgjtV2,
+ (Magic::Ggjt, 3) => Self::GgjtV3,
+ _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
+ };
+ Ok(versioned_magic)
+ }
+
+ fn align32(&self) -> bool {
+ match self {
+ Self::GgmlUnversioned | Self::GgmfV1 => false,
+ Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct HParams {
+ pub n_vocab: u32,
+ pub n_embd: u32,
+ pub n_mult: u32,
+ pub n_head: u32,
+ pub n_layer: u32,
+ pub n_rot: u32,
+ pub ftype: u32,
+}
+
+impl HParams {
+ fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
+ let n_vocab = reader.read_u32::<LittleEndian>()?;
+ let n_embd = reader.read_u32::<LittleEndian>()?;
+ let n_mult = reader.read_u32::<LittleEndian>()?;
+ let n_head = reader.read_u32::<LittleEndian>()?;
+ let n_layer = reader.read_u32::<LittleEndian>()?;
+ let n_rot = reader.read_u32::<LittleEndian>()?;
+ let ftype = reader.read_u32::<LittleEndian>()?;
+ Ok(Self {
+ n_vocab,
+ n_embd,
+ n_mult,
+ n_head,
+ n_layer,
+ n_rot,
+ ftype,
+ })
+ }
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct Vocab {
+ pub token_score_pairs: Vec<(Vec<u8>, f32)>,
+}
+
+impl Vocab {
+ fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
+ // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
+ let mut token_score_pairs = Vec::with_capacity(n_vocab);
+ for _index in 0..n_vocab {
+ let len = reader.read_u32::<LittleEndian>()? as usize;
+ let mut word = vec![0u8; len];
+ reader.read_exact(&mut word)?;
+ let score = reader.read_f32::<LittleEndian>()?;
+ token_score_pairs.push((word, score))
+ }
+ Ok(Self { token_score_pairs })
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum GgmlDType {
+ F32,
+ F16,
+ Q4_0,
+ Q4_1,
+ Q5_0,
+ Q5_1,
+ Q8_0,
+ Q8_1,
+ Q2K,
+ Q3K,
+ Q4K,
+ Q5K,
+ Q6K,
+}
+
+impl GgmlDType {
+ fn from_u32(u: u32) -> Result<Self> {
+ let dtype = match u {
+ 0 => Self::F32,
+ 1 => Self::F16,
+ 2 => Self::Q4_0,
+ 3 => Self::Q4_1,
+ 6 => Self::Q5_0,
+ 7 => Self::Q5_1,
+ 8 => Self::Q8_0,
+ 9 => Self::Q8_1,
+ 10 => Self::Q2K,
+ 11 => Self::Q3K,
+ 12 => Self::Q4K,
+ 13 => Self::Q5K,
+ 14 => Self::Q6K,
+ _ => crate::bail!("unknown dtype for tensor {u}"),
+ };
+ Ok(dtype)
+ }
+
+ fn type_size(&self) -> usize {
+ match self {
+ Self::F32 => 4,
+ Self::F16 => 2,
+ Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
+ Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
+ Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
+ Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
+ // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
+ Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
+ Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
+ Self::Q2K => std::mem::size_of::<BlockQ2K>(),
+ Self::Q3K => std::mem::size_of::<BlockQ3K>(),
+ Self::Q4K => std::mem::size_of::<BlockQ4K>(),
+ Self::Q5K => std::mem::size_of::<BlockQ5K>(),
+ Self::Q6K => std::mem::size_of::<BlockQ6K>(),
+ }
+ }
+
+ fn blck_size(&self) -> usize {
+ match self {
+ Self::F32 => 1,
+ Self::F16 => 1,
+ Self::Q4_0 => QK4_0,
+ Self::Q4_1 => QK4_1,
+ Self::Q5_0 => QK5_0,
+ Self::Q5_1 => QK5_1,
+ Self::Q8_0 => QK8_0,
+ Self::Q8_1 => QK8_1,
+ Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Content {
+ pub magic: VersionedMagic,
+ pub hparams: HParams,
+ pub vocab: Vocab,
+ pub tensors: Vec<(String, Tensor)>,
+}
+
+fn read_one_tensor<R: std::io::Seek + std::io::Read>(
+ reader: &mut R,
+ magic: VersionedMagic,
+ device: &Device,
+) -> Result<(String, Tensor)> {
+ let n_dims = reader.read_u32::<LittleEndian>()?;
+ let name_len = reader.read_u32::<LittleEndian>()?;
+ let dtype = reader.read_u32::<LittleEndian>()?;
+ let dtype = GgmlDType::from_u32(dtype)?;
+ let mut dims = vec![0u32; n_dims as usize];
+ reader.read_u32_into::<LittleEndian>(&mut dims)?;
+ let mut name = vec![0u8; name_len as usize];
+ reader.read_exact(&mut name)?;
+ let name = String::from_utf8_lossy(&name).into_owned();
+
+ if magic.align32() {
+ let pos = reader.stream_position()?;
+ reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
+ }
+ let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
+ let tensor_elems = dims.iter().product::<usize>();
+ let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
+ println!("{name} {dtype:?} {dims:?}");
+ // TODO: Mmap version to avoid copying the data around?
+ let mut raw_data = vec![0u8; size_in_bytes];
+ reader.read_exact(&mut raw_data)?;
+ let tensor = match dtype {
+ GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
+ GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
+ GgmlDType::Q2K => {
+ let mut f32_data = vec![0f32; tensor_elems];
+ let raw_data_ptr = raw_data.as_ptr();
+ let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
+ let raw_data =
+ unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
+ dequantize_row_q2k(raw_data, &mut f32_data)?;
+ // Maybe we should use bf16 instead?
+ Tensor::from_vec(f32_data, dims, device)?
+ }
+ GgmlDType::Q3K => {
+ let mut f32_data = vec![0f32; tensor_elems];
+ let raw_data_ptr = raw_data.as_ptr();
+ let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>();
+ let raw_data =
+ unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) };
+ dequantize_row_q3k(raw_data, &mut f32_data)?;
+ Tensor::from_vec(f32_data, dims, device)?
+ }
+ GgmlDType::Q4K => {
+ let mut f32_data = vec![0f32; tensor_elems];
+ let raw_data_ptr = raw_data.as_ptr();
+ let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
+ let raw_data =
+ unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
+ dequantize_row_q4k(raw_data, &mut f32_data)?;
+ Tensor::from_vec(f32_data, dims, device)?
+ }
+ GgmlDType::Q5K => {
+ let mut f32_data = vec![0f32; tensor_elems];
+ let raw_data_ptr = raw_data.as_ptr();
+ let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>();
+ let raw_data =
+ unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) };
+ dequantize_row_q5k(raw_data, &mut f32_data)?;
+ Tensor::from_vec(f32_data, dims, device)?
+ }
+ GgmlDType::Q6K => {
+ let mut f32_data = vec![0f32; tensor_elems];
+ let raw_data_ptr = raw_data.as_ptr();
+ let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
+ let raw_data =
+ unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
+ dequantize_row_q6k(raw_data, &mut f32_data)?;
+ Tensor::from_vec(f32_data, dims, device)?
+ }
+ _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
+ };
+ Ok((name, tensor))
+}
+
+impl Content {
+ pub fn read<R: std::io::Seek + std::io::Read>(
+ reader: &mut R,
+ device: &Device,
+ ) -> Result<Content> {
+ // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
+ let last_position = reader.seek(std::io::SeekFrom::End(0))?;
+ reader.seek(std::io::SeekFrom::Start(0))?;
+ let magic = VersionedMagic::read(reader)?;
+ let hparams = HParams::read(reader)?;
+ let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
+ let mut tensors = vec![];
+
+ while reader.stream_position()? != last_position {
+ let (name, tensor) = read_one_tensor(reader, magic, device)?;
+ tensors.push((name, tensor))
+ }
+ Ok(Self {
+ magic,
+ hparams,
+ vocab,
+ tensors,
+ })
+ }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index c374d245..016d3806 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -33,6 +33,8 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
+#[cfg(feature = "accelerate")]
+mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
@@ -45,6 +47,7 @@ pub mod display;
mod dtype;
mod dummy_cuda_backend;
pub mod error;
+pub mod ggml;
mod indexer;
pub mod layout;
#[cfg(feature = "mkl")]
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ba8d2fb4..aea8b733 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -51,6 +51,7 @@ pub enum UnaryOp {
Cos,
Abs,
Neg,
+ Recip,
Sqr,
Sqrt,
Gelu,
@@ -79,6 +80,21 @@ pub enum Op {
stride: usize,
},
+ #[allow(dead_code)]
+ Conv2D {
+ arg: Tensor,
+ kernel: Tensor,
+ padding: usize,
+ stride: usize,
+ },
+
+ AvgPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+ UpsampleNearest2D(Tensor),
+
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
@@ -264,6 +280,7 @@ pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
+pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
@@ -410,6 +427,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
+unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 1e1ef305..3ed38e6a 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -266,6 +266,64 @@ impl Storage {
}
}
+ pub(crate) fn conv2d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ self.same_device(kernel, "conv2d")?;
+ self.same_dtype(kernel, "conv2d")?;
+ match (self, &kernel) {
+ (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+ let s = inp.conv2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cpu(s))
+ }
+ (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+ let s = inp.conv2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cuda(s))
+ }
+ (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
+ op: "conv2d",
+ }
+ .bt()),
+ }
+ }
+
+ pub(crate) fn avg_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
+ pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.upsample_nearest2d(layout, h, w)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.upsample_nearest2d(layout, h, w)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn where_cond(
&self,
layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index b958e06d..adba7376 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -269,6 +269,10 @@ impl Tensor {
Self::rand_impl(lo, up, s, device, false)
}
+ pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
+ Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
+ }
+
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
@@ -296,6 +300,17 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}
+ pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
+ Tensor::randn_f64_impl(
+ mean,
+ stdev,
+ self.shape(),
+ self.dtype(),
+ self.device(),
+ false,
+ )
+ }
+
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@@ -474,6 +489,7 @@ impl Tensor {
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
+ unary_op!(recip, Recip);
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);
@@ -548,6 +564,32 @@ impl Tensor {
}
}
+ /// Split a tensor into the specified number of chunks, this may return less chunks than
+ /// specificed.
+ pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
+ let dim = dim.to_index(self.shape(), "chunk")?;
+ let size = self.dim(dim)?;
+ if size < chunks {
+ (0..size).map(|i| self.narrow(dim, i, 1)).collect()
+ } else {
+ let chunk_size = size / chunks;
+ let cnt_additional = size % chunks;
+ let mut tensors = vec![];
+ let mut sum_chunk_size = 0;
+ for i in 0..chunks {
+ let chunk_size = if i < cnt_additional {
+ chunk_size + 1
+ } else {
+ chunk_size
+ };
+ let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
+ tensors.push(tensor);
+ sum_chunk_size += chunk_size
+ }
+ Ok(tensors)
+ }
+ }
+
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
@@ -775,6 +817,61 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
+ pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ let params = crate::conv::ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
+ }
+
+ pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ let (n, c, _h, _w) = self.dims4()?;
+ let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+ let storage = self
+ .storage()
+ .upsample_nearest2d(self.layout(), target_h, target_w)?;
+ Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
+ }
+
+ pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let (n, c, h, w) = self.dims4()?;
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
+ let h_out = (h - kernel_size.0) / stride.0 + 1;
+ let w_out = (w - kernel_size.1) / stride.1 + 1;
+ let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
+ arg,
+ kernel_size,
+ stride,
+ });
+ let storage = self
+ .storage()
+ .avg_pool2d(self.layout(), kernel_size, stride)?;
+ Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+ }
+
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
@@ -1717,6 +1814,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
+ if left == 0 && right == 0 {
+ Ok(self.clone())
+ } else if left == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[self, &right], dim)
+ } else if right == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self], dim)
+ } else {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self, &right], dim)
+ }
+ }
+
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index 895c97e1..d3f5b50e 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -11,16 +11,14 @@ pub fn get_num_threads() -> usize {
}
}
+pub fn has_accelerate() -> bool {
+ cfg!(feature = "accelerate")
+}
+
pub fn has_mkl() -> bool {
- #[cfg(feature = "mkl")]
- return true;
- #[cfg(not(feature = "mkl"))]
- return false;
+ cfg!(feature = "mkl")
}
pub fn cuda_is_available() -> bool {
- #[cfg(feature = "cuda")]
- return true;
- #[cfg(not(feature = "cuda"))]
- return false;
+ cfg!(feature = "cuda")
}
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
new file mode 100644
index 00000000..7ec83592
--- /dev/null
+++ b/candle-core/tests/conv_tests.rs
@@ -0,0 +1,184 @@
+mod test_utils;
+use anyhow::Result;
+use candle_core::{Device, Tensor};
+
+/* This test is based on the following script.
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 4, 5))
+w = torch.randn((2, 4, 3))
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv1d(t, w)
+print(res.flatten())
+*/
+#[test]
+fn conv1d() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
+ 1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
+ ],
+ dev,
+ )?
+ .reshape((1, 4, 5))?;
+ let w = Tensor::new(
+ &[
+ -0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
+ -1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
+ -0.6451, -0.0840, -1.4247, 0.5512,
+ ],
+ dev,
+ )?
+ .reshape((2, 4, 3))?;
+ let res = t.conv1d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 2, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
+ );
+ let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ assert_eq!(res.dims(), [1, 2, 5]);
+ /* Note that the default for padding is different from PyTorch at the moment: instead of
+ padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to:
+ t = torch.nn.functional.pad(t, (1, 1), mode='replicate')
+ res = torch.nn.functional.conv1d(t, w, padding=0)
+ */
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562]
+ );
+ Ok(())
+}
+
+#[test]
+fn conv1d_small() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
+ let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
+ let res = t.conv1d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 1, 2]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [0.4056, -0.8689]
+ );
+ let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ assert_eq!(res.dims(), [1, 1, 4]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [0.4056, 0.4056, -0.8689, -0.0773],
+ );
+ Ok(())
+}
+
+/* This test is based on the following script.
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 4, 5, 5))
+w = torch.randn((2, 4, 3, 3))
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv2d(t, w)
+print(res.flatten())
+*/
+#[test]
+fn conv2d() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+ 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+ 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+ 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
+ 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
+ 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
+ 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
+ 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
+ -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
+ -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+ ],
+ dev,
+ )?;
+ let w = Tensor::new(
+ &[
+ -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
+ -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
+ -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
+ 0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
+ 0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
+ -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
+ 1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
+ 0.5583, 0.4623, 0.6026,
+ ],
+ dev,
+ )?;
+ let t = t.reshape((1, 4, 5, 5))?;
+ let w = w.reshape((2, 4, 3, 3))?;
+ let res = t.conv2d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 2, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [
+ -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
+ 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
+ ]
+ );
+ Ok(())
+}
+
+/* This test is based on the following script.
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 2, 3, 3))
+w = torch.randn((1, 2, 1, 1))
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv2d(t, w)
+print(res.flatten())
+*/
+#[test]
+fn conv2d_small() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
+ -0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
+ ],
+ dev,
+ )?;
+ let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
+ let t = t.reshape((1, 2, 3, 3))?;
+ let w = w.reshape((1, 2, 1, 1))?;
+ let res = t.conv2d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 1, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
+ );
+ Ok(())
+}
+
+#[test]
+fn conv2d_smaller() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
+ ],
+ dev,
+ )?;
+ let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
+ let t = t.reshape((1, 1, 3, 3))?;
+ let w = w.reshape((1, 1, 3, 3))?;
+ let res = t.conv2d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 1, 1, 1]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [-0.6197]
+ );
+ Ok(())
+}
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
new file mode 100644
index 00000000..574182ca
--- /dev/null
+++ b/candle-core/tests/pool_tests.rs
@@ -0,0 +1,17 @@
+mod test_utils;
+use candle_core::{Device, Tensor};
+
+// https://github.com/huggingface/candle/issues/364
+#[test]
+fn avg_pool2d() -> anyhow::Result<()> {
+ let device = Device::Cpu;
+
+ let data: Vec<f32> = vec![
+ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
+
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
+ Ok(())
+}
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
new file mode 100644
index 00000000..12169daf
--- /dev/null
+++ b/candle-datasets/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "candle-datasets"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+readme = "README.md"
+
+[dependencies]
+byteorder = { workspace = true }
+candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.0" }
+hf-hub = { workspace = true}
+intel-mkl-src = { workspace = true, optional = true }
+memmap2 = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
+rand = { workspace = true }
diff --git a/candle-nn/src/dataset.rs b/candle-datasets/src/batcher.rs
index b74f1417..b74f1417 100644
--- a/candle-nn/src/dataset.rs
+++ b/candle-datasets/src/batcher.rs
diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs
new file mode 100644
index 00000000..42ad5d62
--- /dev/null
+++ b/candle-datasets/src/lib.rs
@@ -0,0 +1,6 @@
+//! Datasets & Dataloaders for Candle
+pub mod batcher;
+pub mod nlp;
+pub mod vision;
+
+pub use batcher::Batcher;
diff --git a/candle-datasets/src/nlp/mod.rs b/candle-datasets/src/nlp/mod.rs
new file mode 100644
index 00000000..42e9d288
--- /dev/null
+++ b/candle-datasets/src/nlp/mod.rs
@@ -0,0 +1 @@
+pub mod tinystories;
diff --git a/candle-datasets/src/nlp/tinystories.rs b/candle-datasets/src/nlp/tinystories.rs
new file mode 100644
index 00000000..c657c9eb
--- /dev/null
+++ b/candle-datasets/src/nlp/tinystories.rs
@@ -0,0 +1,122 @@
+//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
+//! by the tools from https://github.com/karpathy/llama2.c
+use candle::{Device, Result, Tensor};
+
+pub struct Dataset {
+ valid_tokens: Vec<memmap2::Mmap>,
+ train_tokens: Vec<memmap2::Mmap>,
+}
+
+fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
+ let file = std::fs::File::open(p)?;
+ let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+ Ok(mmap)
+}
+
+impl Dataset {
+ pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
+ let dir = dir.as_ref();
+ let mut bin_files = vec![];
+ for file in std::fs::read_dir(dir)?.flatten() {
+ let file = file.path();
+ if let Some(extension) = file.extension() {
+ if extension == "bin" {
+ bin_files.push(file)
+ }
+ }
+ }
+ if bin_files.len() < 2 {
+ candle::bail!("found less than two bin files in {:?}", dir)
+ }
+ bin_files.sort();
+ let valid_tokens = mmap_file(&bin_files[0])?;
+ let train_tokens = bin_files[1..]
+ .iter()
+ .map(mmap_file)
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ valid_tokens: vec![valid_tokens],
+ train_tokens,
+ })
+ }
+
+ pub fn train_tokens(&self) -> usize {
+ self.train_tokens.len()
+ }
+
+ pub fn valid_tokens(&self) -> usize {
+ self.valid_tokens.len()
+ }
+}
+
+pub struct DatasetRandomIter<'a> {
+ all_tokens: &'a [memmap2::Mmap],
+ tokens: Vec<&'a memmap2::Mmap>,
+ current_tokens: &'a memmap2::Mmap,
+ indexes_in_bytes: Vec<usize>,
+ seq_len: usize,
+ device: Device,
+}
+
+impl<'a> DatasetRandomIter<'a> {
+ pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
+ use rand::seq::SliceRandom;
+ use rand::thread_rng;
+
+ let all_tokens = if valid {
+ &ds.valid_tokens
+ } else {
+ &ds.train_tokens
+ };
+ let mut tokens = all_tokens.iter().collect::<Vec<_>>();
+ tokens.shuffle(&mut thread_rng());
+ let current_tokens = tokens.pop().unwrap();
+ let seq_len_in_bytes = seq_len * 2;
+ let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
+ .step_by(seq_len_in_bytes)
+ .collect::<Vec<_>>();
+ indexes_in_bytes.shuffle(&mut thread_rng());
+ Self {
+ all_tokens,
+ tokens,
+ current_tokens,
+ indexes_in_bytes,
+ seq_len,
+ device,
+ }
+ }
+}
+
+impl<'a> Iterator for DatasetRandomIter<'a> {
+ type Item = Result<(Tensor, Tensor)>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ use byteorder::{LittleEndian, ReadBytesExt};
+ use rand::seq::SliceRandom;
+ use rand::thread_rng;
+
+ let seq_len = self.seq_len;
+ if self.indexes_in_bytes.is_empty() {
+ if self.tokens.is_empty() {
+ self.tokens = self.all_tokens.iter().collect();
+ self.tokens.shuffle(&mut thread_rng());
+ }
+ self.current_tokens = self.tokens.pop().unwrap();
+ let seq_len_in_bytes = self.seq_len * 2;
+ self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
+ .step_by(seq_len_in_bytes)
+ .collect::<Vec<_>>();
+ self.indexes_in_bytes.shuffle(&mut thread_rng());
+ }
+ let start_idx = self.indexes_in_bytes.pop().unwrap();
+ let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
+ let mut tokens = vec![0u16; bytes.len() / 2];
+ if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
+ return Some(Err(err.into()));
+ }
+ let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
+ let inputs = Tensor::new(&tokens[..seq_len], &self.device);
+ let targets = Tensor::new(&tokens[1..], &self.device);
+ Some(candle::error::zip(inputs, targets))
+ }
+}
diff --git a/candle-nn/src/vision/cifar.rs b/candle-datasets/src/vision/cifar.rs
index 0683c4d2..0683c4d2 100644
--- a/candle-nn/src/vision/cifar.rs
+++ b/candle-datasets/src/vision/cifar.rs
diff --git a/candle-nn/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 2267f9a0..2267f9a0 100644
--- a/candle-nn/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
diff --git a/candle-nn/src/vision/mod.rs b/candle-datasets/src/vision/mod.rs
index 6ce743eb..6ce743eb 100644
--- a/candle-nn/src/vision/mod.rs
+++ b/candle-datasets/src/vision/mod.rs
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index c4e34656..54eb0be6 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -10,7 +10,9 @@ license.workspace = true
readme = "README.md"
[dependencies]
+accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
@@ -21,6 +23,7 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
+image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -42,6 +45,7 @@ anyhow = { workspace = true }
[features]
default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@@ -50,3 +54,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
+
+[[example]]
+name = "stable-diffusion"
+required-features = ["image"]
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index f3cf17bc..b2c4e55a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -9,6 +9,9 @@
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -111,6 +114,10 @@ struct Args {
#[arg(long)]
use_f32: bool,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
#[arg(long)]
model_id: Option<String>,
@@ -123,8 +130,18 @@ struct Args {
fn main() -> Result<()> {
use tokenizers::Tokenizer;
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
let config = if args.v1 {
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index ae27afc1..f5ac587e 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -47,6 +47,21 @@ impl Config {
}
}
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
struct RmsNorm {
scale: Tensor,
eps: f64,
+ span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = vb.get(size, "weight")?;
- Ok(Self { scale, eps })
+ Ok(Self { scale, eps, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
@@ -155,6 +174,8 @@ struct CausalSelfAttention {
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
+ span: tracing::Span,
+ span_rot: tracing::Span,
}
#[cfg(feature = "flash-attn")]
@@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@@ -188,6 +210,7 @@ impl CausalSelfAttention {
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@@ -269,6 +292,8 @@ impl CausalSelfAttention {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@@ -286,6 +311,8 @@ impl CausalSelfAttention {
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
+ span,
+ span_rot,
})
}
}
@@ -301,15 +328,18 @@ struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
+ span: tracing::Span,
}
impl Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@@ -319,6 +349,7 @@ impl Mlp {
c_fc1,
c_fc2,
c_proj,
+ span,
})
}
}
@@ -328,10 +359,12 @@ struct Block {
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
+ span: tracing::Span,
}
impl Block {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@@ -341,6 +374,7 @@ impl Block {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +388,7 @@ impl Block {
attn,
rms_2,
mlp,
+ span,
})
}
}
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 8b64fdd2..418218b6 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -1,5 +1,8 @@
// https://github.com/karpathy/llama2.c
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -27,7 +30,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
- /// Config file in binary format.
+ /// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
@@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
- let mut file = std::fs::File::open(config_path)?;
- let config = Config::from_reader(&mut file)?;
- println!("{config:?}");
- let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
- let vb = weights.var_builder(&config, &device)?;
+ let is_safetensors = config_path
+ .extension()
+ .map_or(false, |v| v == "safetensors");
+ let (vb, config) = if is_safetensors {
+ let config = Config::tiny();
+ let tensors = candle::safetensors::load(config_path, &device)?;
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+ (vb, config)
+ } else {
+ let mut file = std::fs::File::open(config_path)?;
+ let config = Config::from_reader(&mut file)?;
+ println!("{config:?}");
+ let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+ let vb = weights.var_builder(&config, &device)?;
+ (vb, config)
+ };
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index e55c686c..3e93c786 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
-
-pub struct Dataset {
- valid_tokens: Vec<memmap2::Mmap>,
- train_tokens: Vec<memmap2::Mmap>,
-}
-
-fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
- let file = std::fs::File::open(p)?;
- let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
- Ok(mmap)
-}
-
-impl Dataset {
- pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
- let dir = dir.as_ref();
- let mut bin_files = vec![];
- for file in std::fs::read_dir(dir)?.flatten() {
- let file = file.path();
- if let Some(extension) = file.extension() {
- if extension == "bin" {
- bin_files.push(file)
- }
- }
- }
- if bin_files.len() < 2 {
- candle::bail!("found less than two bin files in {:?}", dir)
- }
- bin_files.sort();
- let valid_tokens = mmap_file(&bin_files[0])?;
- let train_tokens = bin_files[1..]
- .iter()
- .map(mmap_file)
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- valid_tokens: vec![valid_tokens],
- train_tokens,
- })
- }
-}
-
-struct DatasetRandomIter<'a> {
- all_tokens: &'a [memmap2::Mmap],
- tokens: Vec<&'a memmap2::Mmap>,
- current_tokens: &'a memmap2::Mmap,
- indexes_in_bytes: Vec<usize>,
- seq_len: usize,
- device: Device,
-}
-
-impl<'a> DatasetRandomIter<'a> {
- pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let all_tokens = if valid {
- &ds.valid_tokens
- } else {
- &ds.train_tokens
- };
- let mut tokens = all_tokens.iter().collect::<Vec<_>>();
- tokens.shuffle(&mut thread_rng());
- let current_tokens = tokens.pop().unwrap();
- let seq_len_in_bytes = seq_len * 2;
- let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- indexes_in_bytes.shuffle(&mut thread_rng());
- Self {
- all_tokens,
- tokens,
- current_tokens,
- indexes_in_bytes,
- seq_len,
- device,
- }
- }
-}
-
-impl<'a> Iterator for DatasetRandomIter<'a> {
- type Item = Result<(Tensor, Tensor)>;
-
- fn next(&mut self) -> Option<Self::Item> {
- use byteorder::{LittleEndian, ReadBytesExt};
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let seq_len = self.seq_len;
- if self.indexes_in_bytes.is_empty() {
- if self.tokens.is_empty() {
- self.tokens = self.all_tokens.iter().collect();
- self.tokens.shuffle(&mut thread_rng());
- }
- self.current_tokens = self.tokens.pop().unwrap();
- let seq_len_in_bytes = self.seq_len * 2;
- self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- self.indexes_in_bytes.shuffle(&mut thread_rng());
- }
- let start_idx = self.indexes_in_bytes.pop().unwrap();
- let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
- let mut tokens = vec![0u16; bytes.len() / 2];
- if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
- return Some(Err(err.into()));
- }
- let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
- let inputs = Tensor::new(&tokens[..seq_len], &self.device);
- let targets = Tensor::new(&tokens[1..], &self.device);
- Some(candle::error::zip(inputs, targets))
- }
-}
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
fn valid_loss(
dataset: &Dataset,
@@ -121,7 +9,7 @@ fn valid_loss(
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
- dataset.train_tokens.len(),
- dataset.valid_tokens.len()
+ dataset.train_tokens(),
+ dataset.valid_tokens()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index ae1fd6d9..b78418ce 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -104,7 +104,14 @@ impl TransformerWeights {
})
}
- pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+ // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+ // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+ // second matrix back. We detect this case here and as a temporary hack make the weight
+ // matrix column major rather than row major. This ends up speeding up text generation from
+ // 120 token/s to 220 token/s on a Ryzen 2600X.
+ let tr = device.is_cpu() && !candle::utils::has_mkl();
+ let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);
@@ -115,36 +122,36 @@ impl TransformerWeights {
"model.embed_tokens.weight",
self.token_embedding_table.clone(),
);
- insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
insert("model.norm.weight", self.rms_final_weight.clone());
for layer in 0..cfg.n_layers {
ws.insert(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
- self.wq.i(layer)?,
+ tr(self.wq.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
- self.wk.i(layer)?,
+ tr(self.wk.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
- self.wv.i(layer)?,
+ tr(self.wv.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
- self.wo.i(layer)?,
+ tr(self.wo.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
- self.w1.i(layer)?,
+ tr(self.w1.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.down_proj.weight"),
- self.w2.i(layer)?,
+ tr(self.w2.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.up_proj.weight"),
- self.w3.i(layer)?,
+ tr(self.w3.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.input_layernorm.weight"),
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index e251f6e9..d9e596ce 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -63,7 +63,7 @@ struct TrainingArgs {
}
fn training_loop<M: Model>(
- m: candle_nn::vision::Dataset,
+ m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
@@ -140,7 +140,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_nn::vision::mnist::load_dir("data")?;
+ let m = candle_datasets::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
new file mode 100644
index 00000000..83e7ef34
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -0,0 +1,445 @@
+#![allow(dead_code)]
+//! Attention Based Building Blocks
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct GeGlu {
+ proj: nn::Linear,
+}
+
+impl GeGlu {
+ fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
+ let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
+ Ok(Self { proj })
+ }
+}
+
+impl GeGlu {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
+ &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
+ }
+}
+
+/// A feed-forward layer.
+#[derive(Debug)]
+struct FeedForward {
+ project_in: GeGlu,
+ linear: nn::Linear,
+}
+
+impl FeedForward {
+ // The glu parameter in the python code is unused?
+ // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
+ /// Creates a new feed-forward layer based on some given input dimension, some
+ /// output dimension, and a multiplier to be used for the intermediary layer.
+ fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
+ let inner_dim = dim * mult;
+ let dim_out = dim_out.unwrap_or(dim);
+ let vs = vs.pp("net");
+ let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
+ let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
+ Ok(Self { project_in, linear })
+ }
+}
+
+impl FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.project_in.forward(xs)?;
+ self.linear.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct CrossAttention {
+ to_q: nn::Linear,
+ to_k: nn::Linear,
+ to_v: nn::Linear,
+ to_out: nn::Linear,
+ heads: usize,
+ scale: f64,
+ slice_size: Option<usize>,
+}
+
+impl CrossAttention {
+ // Defaults should be heads = 8, dim_head = 64, context_dim = None
+ fn new(
+ vs: nn::VarBuilder,
+ query_dim: usize,
+ context_dim: Option<usize>,
+ heads: usize,
+ dim_head: usize,
+ slice_size: Option<usize>,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let context_dim = context_dim.unwrap_or(query_dim);
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
+ let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
+ let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
+ let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ heads,
+ scale,
+ slice_size,
+ })
+ }
+
+ fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
+ .transpose(1, 2)?
+ .reshape((batch_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
+ .transpose(1, 2)?
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn sliced_attention(
+ &self,
+ query: &Tensor,
+ key: &Tensor,
+ value: &Tensor,
+ slice_size: usize,
+ ) -> Result<Tensor> {
+ let batch_size_attention = query.dim(0)?;
+ let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+
+ for i in 0..batch_size_attention / slice_size {
+ let start_idx = i * slice_size;
+ let end_idx = (i + 1) * slice_size;
+
+ let xs = query
+ .i(start_idx..end_idx)?
+ .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
+ hidden_states.push(xs)
+ }
+ let hidden_states = Tensor::stack(&hidden_states, 0)?;
+ self.reshape_batch_dim_to_heads(&hidden_states)
+ }
+
+ fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+ let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+ self.reshape_batch_dim_to_heads(&xs)
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let query = self.to_q.forward(xs)?;
+ let context = context.unwrap_or(xs);
+ let key = self.to_k.forward(context)?;
+ let value = self.to_v.forward(context)?;
+ let query = self.reshape_heads_to_batch_dim(&query)?;
+ let key = self.reshape_heads_to_batch_dim(&key)?;
+ let value = self.reshape_heads_to_batch_dim(&value)?;
+ let xs = match self.slice_size {
+ None => self.attention(&query, &key, &value)?,
+ Some(slice_size) => {
+ if query.dim(0)? / slice_size <= 1 {
+ self.attention(&query, &key, &value)?
+ } else {
+ self.sliced_attention(&query, &key, &value, slice_size)?
+ }
+ }
+ };
+ self.to_out.forward(&xs)
+ }
+}
+
+/// A basic Transformer block.
+#[derive(Debug)]
+struct BasicTransformerBlock {
+ attn1: CrossAttention,
+ ff: FeedForward,
+ attn2: CrossAttention,
+ norm1: nn::LayerNorm,
+ norm2: nn::LayerNorm,
+ norm3: nn::LayerNorm,
+}
+
+impl BasicTransformerBlock {
+ fn new(
+ vs: nn::VarBuilder,
+ dim: usize,
+ n_heads: usize,
+ d_head: usize,
+ context_dim: Option<usize>,
+ sliced_attention_size: Option<usize>,
+ ) -> Result<Self> {
+ let attn1 = CrossAttention::new(
+ vs.pp("attn1"),
+ dim,
+ None,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
+ let attn2 = CrossAttention::new(
+ vs.pp("attn2"),
+ dim,
+ context_dim,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
+ let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
+ let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+ Ok(Self {
+ attn1,
+ ff,
+ attn2,
+ norm1,
+ norm2,
+ norm3,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
+ let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
+ self.ff.forward(&self.norm3.forward(&xs)?)? + xs
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct SpatialTransformerConfig {
+ pub depth: usize,
+ pub num_groups: usize,
+ pub context_dim: Option<usize>,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for SpatialTransformerConfig {
+ fn default() -> Self {
+ Self {
+ depth: 1,
+ num_groups: 32,
+ context_dim: None,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+enum Proj {
+ Conv2d(nn::Conv2d),
+ Linear(nn::Linear),
+}
+
+// Aka Transformer2DModel
+#[derive(Debug)]
+pub struct SpatialTransformer {
+ norm: nn::GroupNorm,
+ proj_in: Proj,
+ transformer_blocks: Vec<BasicTransformerBlock>,
+ proj_out: Proj,
+ pub config: SpatialTransformerConfig,
+}
+
+impl SpatialTransformer {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ n_heads: usize,
+ d_head: usize,
+ config: SpatialTransformerConfig,
+ ) -> Result<Self> {
+ let inner_dim = n_heads * d_head;
+ let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
+ let proj_in = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ in_channels,
+ inner_dim,
+ 1,
+ Default::default(),
+ vs.pp("proj_in"),
+ )?)
+ };
+ let mut transformer_blocks = vec![];
+ let vs_tb = vs.pp("transformer_blocks");
+ for index in 0..config.depth {
+ let tb = BasicTransformerBlock::new(
+ vs_tb.pp(&index.to_string()),
+ inner_dim,
+ n_heads,
+ d_head,
+ config.context_dim,
+ config.sliced_attention_size,
+ )?;
+ transformer_blocks.push(tb)
+ }
+ let proj_out = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ inner_dim,
+ in_channels,
+ 1,
+ Default::default(),
+ vs.pp("proj_out"),
+ )?)
+ };
+ Ok(Self {
+ norm,
+ proj_in,
+ transformer_blocks,
+ proj_out,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let (batch, _channel, height, weight) = xs.dims4()?;
+ let residual = xs;
+ let xs = self.norm.forward(xs)?;
+ let (inner_dim, xs) = match &self.proj_in {
+ Proj::Conv2d(p) => {
+ let xs = p.forward(&xs)?;
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, xs)
+ }
+ Proj::Linear(p) => {
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, p.forward(&xs)?)
+ }
+ };
+ let mut xs = xs;
+ for block in self.transformer_blocks.iter() {
+ xs = block.forward(&xs, context)?
+ }
+ let xs = match &self.proj_out {
+ Proj::Conv2d(p) => p.forward(
+ &xs.reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ )?,
+ Proj::Linear(p) => p
+ .forward(&xs)?
+ .reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ };
+ xs + residual
+ }
+}
+
+/// Configuration for an attention block.
+#[derive(Debug, Clone, Copy)]
+pub struct AttentionBlockConfig {
+ pub num_head_channels: Option<usize>,
+ pub num_groups: usize,
+ pub rescale_output_factor: f64,
+ pub eps: f64,
+}
+
+impl Default for AttentionBlockConfig {
+ fn default() -> Self {
+ Self {
+ num_head_channels: None,
+ num_groups: 32,
+ rescale_output_factor: 1.,
+ eps: 1e-5,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct AttentionBlock {
+ group_norm: nn::GroupNorm,
+ query: nn::Linear,
+ key: nn::Linear,
+ value: nn::Linear,
+ proj_attn: nn::Linear,
+ channels: usize,
+ num_heads: usize,
+ config: AttentionBlockConfig,
+}
+
+impl AttentionBlock {
+ pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
+ let num_head_channels = config.num_head_channels.unwrap_or(channels);
+ let num_heads = channels / num_head_channels;
+ let group_norm =
+ nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
+ let query = nn::linear(channels, channels, vs.pp("query"))?;
+ let key = nn::linear(channels, channels, vs.pp("key"))?;
+ let value = nn::linear(channels, channels, vs.pp("value"))?;
+ let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+ Ok(Self {
+ group_norm,
+ query,
+ key,
+ value,
+ proj_attn,
+ channels,
+ num_heads,
+ config,
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
+ let (batch, t, h_times_d) = xs.dims3()?;
+ xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
+ .transpose(1, 2)
+ }
+}
+
+impl AttentionBlock {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let (batch, channel, height, width) = xs.dims4()?;
+ let xs = self
+ .group_norm
+ .forward(xs)?
+ .reshape((batch, channel, height * width))?
+ .transpose(1, 2)?;
+
+ let query_proj = self.query.forward(&xs)?;
+ let key_proj = self.key.forward(&xs)?;
+ let value_proj = self.value.forward(&xs)?;
+
+ let query_states = self.transpose_for_scores(query_proj)?;
+ let key_states = self.transpose_for_scores(key_proj)?;
+ let value_states = self.transpose_for_scores(value_proj)?;
+
+ let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
+ let attention_scores =
+ // TODO: Check that this needs two multiplication by `scale`.
+ (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+ let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
+
+ let xs = attention_probs.matmul(&value_states)?;
+ let xs = xs.transpose(1, 2)?.contiguous()?;
+ let xs = xs.flatten_from(D::Minus2)?;
+ let xs = self
+ .proj_attn
+ .forward(&xs)?
+ .t()?
+ .reshape((batch, channel, height, width))?;
+ (xs + residual)? / self.config.rescale_output_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
new file mode 100644
index 00000000..ca00b417
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -0,0 +1,305 @@
+#![allow(dead_code)]
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/openai/CLIP
+use candle::{Device, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+ QuickGelu,
+ Gelu,
+}
+
+impl Activation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::Gelu => xs.gelu(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ vocab_size: usize,
+ embed_dim: usize, // aka config.hidden_size
+ activation: Activation, // aka config.hidden_act
+ intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ // The character to use for padding, use EOS when not set.
+ pub pad_with: Option<String>,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ #[allow(dead_code)]
+ projection_dim: usize,
+}
+
+impl Config {
+ // The config details can be found in the "text_config" section of this json file:
+ // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+ pub fn v1_5() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
+ pub fn v2_1() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 23,
+ num_attention_heads: 16,
+ projection_dim: 512,
+ activation: Activation::Gelu,
+ }
+ }
+}
+
+// CLIP Text Model
+// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
+#[derive(Debug)]
+struct ClipTextEmbeddings {
+ token_embedding: candle_nn::Embedding,
+ position_embedding: candle_nn::Embedding,
+ position_ids: Tensor,
+}
+
+impl ClipTextEmbeddings {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let token_embedding =
+ candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+ let position_embedding = candle_nn::embedding(
+ c.max_position_embeddings,
+ c.embed_dim,
+ vs.pp("position_embedding"),
+ )?;
+ let position_ids =
+ Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
+ Ok(ClipTextEmbeddings {
+ token_embedding,
+ position_embedding,
+ position_ids,
+ })
+ }
+}
+
+impl ClipTextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let token_embedding = self.token_embedding.forward(xs)?;
+ let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+ token_embedding.broadcast_add(&position_embedding)
+ }
+}
+
+#[derive(Debug)]
+struct ClipAttention {
+ k_proj: candle_nn::Linear,
+ v_proj: candle_nn::Linear,
+ q_proj: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+ scale: f64,
+ num_attention_heads: usize,
+}
+
+impl ClipAttention {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let embed_dim = c.embed_dim;
+ let num_attention_heads = c.num_attention_heads;
+ let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
+ let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
+ let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+ let head_dim = embed_dim / num_attention_heads;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(ClipAttention {
+ k_proj,
+ v_proj,
+ q_proj,
+ out_proj,
+ head_dim,
+ scale,
+ num_attention_heads,
+ })
+ }
+
+ fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
+ xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len, embed_dim) = xs.dims3()?;
+ let query_states = (self.q_proj.forward(xs)? * self.scale)?;
+ let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
+ let query_states = self
+ .shape(&query_states, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let key_states = self
+ .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let value_states = self
+ .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+
+ let src_len = key_states.dim(1)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
+ let attn_weights =
+ attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+ let attn_output = attn_weights.matmul(&value_states)?;
+ let attn_output = attn_output
+ .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
+ .transpose(1, 2)?
+ .reshape((bsz, seq_len, embed_dim))?;
+ self.out_proj.forward(&attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct ClipMlp {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ activation: Activation,
+}
+
+impl ClipMlp {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
+ Ok(ClipMlp {
+ fc1,
+ fc2,
+ activation: c.activation,
+ })
+ }
+}
+
+impl ClipMlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?;
+ self.fc2.forward(&self.activation.forward(&xs)?)
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoderLayer {
+ self_attn: ClipAttention,
+ layer_norm1: candle_nn::LayerNorm,
+ mlp: ClipMlp,
+ layer_norm2: candle_nn::LayerNorm,
+}
+
+impl ClipEncoderLayer {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
+ let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
+ let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
+ let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
+ Ok(ClipEncoderLayer {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.layer_norm1.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = self.layer_norm2.forward(&xs)?;
+ let xs = self.mlp.forward(&xs)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoder {
+ layers: Vec<ClipEncoderLayer>,
+}
+
+impl ClipEncoder {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("layers");
+ let mut layers: Vec<ClipEncoderLayer> = Vec::new();
+ for index in 0..c.num_hidden_layers {
+ let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ layers.push(layer)
+ }
+ Ok(ClipEncoder { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, causal_attention_mask)?;
+ }
+ Ok(xs)
+ }
+}
+
+/// A CLIP transformer based model.
+#[derive(Debug)]
+pub struct ClipTextTransformer {
+ embeddings: ClipTextEmbeddings,
+ encoder: ClipEncoder,
+ final_layer_norm: candle_nn::LayerNorm,
+}
+
+impl ClipTextTransformer {
+ pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("text_model");
+ let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
+ let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
+ let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
+ Ok(ClipTextTransformer {
+ embeddings,
+ encoder,
+ final_layer_norm,
+ })
+ }
+
+ // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
+ fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, seq_len, seq_len))
+ }
+}
+
+impl ClipTextTransformer {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
+ let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+ self.final_layer_norm.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-examples/examples/stable-diffusion/ddim.rs
new file mode 100644
index 00000000..6eb6df44
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/ddim.rs
@@ -0,0 +1,181 @@
+#![allow(dead_code)]
+//! # Denoising Diffusion Implicit Models
+//!
+//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
+//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
+//! generative process is the reverse of a Markovian process, DDIM generalizes
+//! this to non-Markovian guidance.
+//!
+//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
+//! https://arxiv.org/abs/2010.02502
+use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+/// The configuration for the DDIM scheduler.
+#[derive(Debug, Clone, Copy)]
+pub struct DDIMSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// The amount of noise to be added at each step.
+ pub eta: f64,
+ /// Adjust the indexes of the inference schedule by this value.
+ pub steps_offset: usize,
+ /// prediction type of the scheduler function, one of `epsilon` (predicting
+ /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
+ /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model
+ pub train_timesteps: usize,
+}
+
+impl Default for DDIMSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085f64,
+ beta_end: 0.012f64,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ eta: 0.,
+ steps_offset: 1,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ }
+ }
+}
+
+/// The DDIM scheduler.
+#[derive(Debug, Clone)]
+pub struct DDIMScheduler {
+ timesteps: Vec<usize>,
+ alphas_cumprod: Vec<f64>,
+ step_ratio: usize,
+ init_noise_sigma: f64,
+ pub config: DDIMSchedulerConfig,
+}
+
+// clip_sample: False, set_alpha_to_one: False
+impl DDIMScheduler {
+ /// Creates a new DDIM scheduler given the number of steps to be
+ /// used for inference as well as the number of steps that was used
+ /// during training.
+ pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect();
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => crate::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+ Ok(Self {
+ alphas_cumprod,
+ timesteps,
+ step_ratio,
+ init_noise_sigma: 1.,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+ Ok(sample)
+ }
+
+ /// Performs a backward step during inference.
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
+ let prev_timestep = if timestep > self.step_ratio {
+ timestep - self.step_ratio
+ } else {
+ 0
+ };
+
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
+ let beta_prod_t = 1. - alpha_prod_t;
+ let beta_prod_t_prev = 1. - alpha_prod_t_prev;
+
+ let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
+ PredictionType::Epsilon => {
+ let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
+ * (1. / alpha_prod_t.sqrt()))?;
+ (pred_original_sample, model_output.clone())
+ }
+ PredictionType::VPrediction => {
+ let pred_original_sample =
+ ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
+ let pred_epsilon =
+ ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
+ (pred_original_sample, pred_epsilon)
+ }
+ PredictionType::Sample => {
+ let pred_original_sample = model_output.clone();
+ let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
+ * (1. / beta_prod_t.sqrt()))?;
+ (pred_original_sample, pred_epsilon)
+ }
+ };
+
+ let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
+ let std_dev_t = self.config.eta * variance.sqrt();
+
+ let pred_sample_direction =
+ (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
+ let prev_sample =
+ ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
+ if self.config.eta > 0. {
+ &prev_sample
+ + Tensor::randn(
+ 0f32,
+ std_dev_t as f32,
+ prev_sample.shape(),
+ prev_sample.device(),
+ )?
+ } else {
+ Ok(prev_sample)
+ }
+ }
+
+ pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+ let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+ (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
new file mode 100644
index 00000000..e3a339f5
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+pub struct TimestepEmbedding {
+ linear_1: nn::Linear,
+ linear_2: nn::Linear,
+}
+
+impl TimestepEmbedding {
+ // act_fn: "silu"
+ pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
+ let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
+ let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
+ Ok(Self { linear_1, linear_2 })
+ }
+}
+
+impl TimestepEmbedding {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
+ self.linear_2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Timesteps {
+ num_channels: usize,
+ flip_sin_to_cos: bool,
+ downscale_freq_shift: f64,
+}
+
+impl Timesteps {
+ pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
+ Self {
+ num_channels,
+ flip_sin_to_cos,
+ downscale_freq_shift,
+ }
+ }
+}
+
+impl Timesteps {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let half_dim = (self.num_channels / 2) as u32;
+ let exponent =
+ (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+ let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
+ let emb = exponent.exp()?;
+ // emb = timesteps[:, None].float() * emb[None, :]
+ let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let (cos, sin) = (emb.cos()?, emb.sin()?);
+ let emb = if self.flip_sin_to_cos {
+ Tensor::cat(&[&cos, &sin], D::Minus1)?
+ } else {
+ Tensor::cat(&[&sin, &cos], D::Minus1)?
+ };
+ if self.num_channels % 2 == 1 {
+ emb.pad_with_zeros(D::Minus2, 0, 1)
+ } else {
+ Ok(emb)
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
new file mode 100644
index 00000000..8ce0c234
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -0,0 +1,273 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+mod attention;
+mod clip;
+mod ddim;
+mod embeddings;
+mod resnet;
+mod schedulers;
+mod stable_diffusion;
+mod unet_2d;
+mod unet_2d_blocks;
+mod utils;
+mod vae;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use clap::Parser;
+use tokenizers::Tokenizer;
+
+const GUIDANCE_SCALE: f64 = 7.5;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// The prompt to be used for image generation.
+ #[arg(
+ long,
+ default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+ )]
+ prompt: String,
+
+ #[arg(long, default_value = "")]
+ uncond_prompt: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// The height in pixels of the generated image.
+ #[arg(long)]
+ height: Option<usize>,
+
+ /// The width in pixels of the generated image.
+ #[arg(long)]
+ width: Option<usize>,
+
+ /// The UNet weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ unet_weights: Option<String>,
+
+ /// The CLIP weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ clip_weights: Option<String>,
+
+ /// The VAE weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ vae_weights: Option<String>,
+
+ #[arg(long, value_name = "FILE")]
+ /// The file specifying the tokenizer to used for tokenization.
+ tokenizer: String,
+
+ /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
+ #[arg(long)]
+ sliced_attention_size: Option<usize>,
+
+ /// The number of steps to run the diffusion for.
+ #[arg(long, default_value_t = 30)]
+ n_steps: usize,
+
+ /// The number of samples to generate.
+ #[arg(long, default_value_t = 1)]
+ num_samples: i64,
+
+ /// The name of the final image to generate.
+ #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
+ final_image: String,
+
+ #[arg(long, value_enum, default_value = "v2-1")]
+ sd_version: StableDiffusionVersion,
+
+ /// Generate intermediary images at each step.
+ #[arg(long, action)]
+ intermediary_images: bool,
+}
+
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+enum StableDiffusionVersion {
+ V1_5,
+ V2_1,
+}
+
+impl Args {
+ fn clip_weights(&self) -> String {
+ match &self.clip_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+
+ fn vae_weights(&self) -> String {
+ match &self.vae_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+
+ fn unet_weights(&self) -> String {
+ match &self.unet_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+}
+
+fn output_filename(
+ basename: &str,
+ sample_idx: i64,
+ num_samples: i64,
+ timestep_idx: Option<usize>,
+) -> String {
+ let filename = if num_samples > 1 {
+ match basename.rsplit_once('.') {
+ None => format!("{basename}.{sample_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}.{sample_idx}.{extension}")
+ }
+ }
+ } else {
+ basename.to_string()
+ };
+ match timestep_idx {
+ None => filename,
+ Some(timestep_idx) => match filename.rsplit_once('.') {
+ None => format!("{filename}-{timestep_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}-{timestep_idx}.{extension}")
+ }
+ },
+ }
+}
+
+fn run(args: Args) -> Result<()> {
+ let clip_weights = args.clip_weights();
+ let vae_weights = args.vae_weights();
+ let unet_weights = args.unet_weights();
+ let Args {
+ prompt,
+ uncond_prompt,
+ cpu,
+ height,
+ width,
+ n_steps,
+ tokenizer,
+ final_image,
+ sliced_attention_size,
+ num_samples,
+ sd_version,
+ ..
+ } = args;
+ let sd_config = match sd_version {
+ StableDiffusionVersion::V1_5 => {
+ stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
+ }
+ StableDiffusionVersion::V2_1 => {
+ stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
+ }
+ };
+
+ let scheduler = sd_config.build_scheduler(n_steps)?;
+ let device = candle_examples::device(cpu)?;
+
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+
+ println!("Building the autoencoder.");
+ let vae = sd_config.build_vae(&vae_weights, &device)?;
+ println!("Building the unet.");
+ let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+
+ let bsize = 1;
+ for idx in 0..num_samples {
+ let mut latents = Tensor::randn(
+ 0f32,
+ 1f32,
+ (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+ &device,
+ )?;
+
+ // scale the initial noise by the standard deviation required by the scheduler
+ latents = (latents * scheduler.init_noise_sigma())?;
+
+ for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+ println!("Timestep {timestep_index}/{n_steps}");
+ let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+
+ let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
+ let noise_pred =
+ unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+ let noise_pred =
+ (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+ latents = scheduler.step(&noise_pred, timestep, &latents)?;
+
+ if args.intermediary_images {
+ let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename =
+ output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
+ crate::utils::save_image(&image, image_filename)?
+ }
+ }
+
+ println!(
+ "Generating the final image for sample {}/{}.",
+ idx + 1,
+ num_samples
+ );
+ let image = vae.decode(&(&latents / 0.18215)?)?;
+ // TODO: Add the clamping between 0 and 1.
+ let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+ crate::utils::save_image(&image, image_filename)?
+ }
+ Ok(())
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ run(args)
+}
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
new file mode 100644
index 00000000..7790dcf9
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -0,0 +1,129 @@
+#![allow(dead_code)]
+//! ResNet Building Blocks
+//!
+//! Some Residual Network blocks used in UNet models.
+//!
+//! Denoising Diffusion Implicit Models, K. He and al, 2015.
+//! https://arxiv.org/abs/1512.03385
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+/// Configuration for a ResNet block.
+#[derive(Debug, Clone, Copy)]
+pub struct ResnetBlock2DConfig {
+ /// The number of output channels, defaults to the number of input channels.
+ pub out_channels: Option<usize>,
+ pub temb_channels: Option<usize>,
+ /// The number of groups to use in group normalization.
+ pub groups: usize,
+ pub groups_out: Option<usize>,
+ /// The epsilon to be used in the group normalization operations.
+ pub eps: f64,
+ /// Whether to use a 2D convolution in the skip connection. When using None,
+ /// such a convolution is used if the number of input channels is different from
+ /// the number of output channels.
+ pub use_in_shortcut: Option<bool>,
+ // non_linearity: silu
+ /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64,
+}
+
+impl Default for ResnetBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ out_channels: None,
+ temb_channels: Some(512),
+ groups: 32,
+ groups_out: None,
+ eps: 1e-6,
+ use_in_shortcut: None,
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ResnetBlock2D {
+ norm1: nn::GroupNorm,
+ conv1: nn::Conv2d,
+ norm2: nn::GroupNorm,
+ conv2: nn::Conv2d,
+ time_emb_proj: Option<nn::Linear>,
+ conv_shortcut: Option<nn::Conv2d>,
+ config: ResnetBlock2DConfig,
+}
+
+impl ResnetBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ config: ResnetBlock2DConfig,
+ ) -> Result<Self> {
+ let out_channels = config.out_channels.unwrap_or(in_channels);
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
+ let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+ let groups_out = config.groups_out.unwrap_or(config.groups);
+ let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
+ let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+ let use_in_shortcut = config
+ .use_in_shortcut
+ .unwrap_or(in_channels != out_channels);
+ let conv_shortcut = if use_in_shortcut {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 0,
+ };
+ Some(nn::conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ conv_cfg,
+ vs.pp("conv_shortcut"),
+ )?)
+ } else {
+ None
+ };
+ let time_emb_proj = match config.temb_channels {
+ None => None,
+ Some(temb_channels) => Some(nn::linear(
+ temb_channels,
+ out_channels,
+ vs.pp("time_emb_proj"),
+ )?),
+ };
+ Ok(Self {
+ norm1,
+ conv1,
+ norm2,
+ conv2,
+ time_emb_proj,
+ config,
+ conv_shortcut,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let shortcut_xs = match &self.conv_shortcut {
+ Some(conv_shortcut) => conv_shortcut.forward(xs)?,
+ None => xs.clone(),
+ };
+ let xs = self.norm1.forward(xs)?;
+ let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
+ let xs = match (temb, &self.time_emb_proj) {
+ (Some(temb), Some(time_emb_proj)) => time_emb_proj
+ .forward(&nn::ops::silu(temb)?)?
+ .unsqueeze(D::Minus1)?
+ .unsqueeze(D::Minus1)?
+ .broadcast_add(&xs)?,
+ _ => xs,
+ };
+ let xs = self
+ .conv2
+ .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
+ (shortcut_xs + xs)? / self.config.output_scale_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-examples/examples/stable-diffusion/schedulers.rs
new file mode 100644
index 00000000..3f6a1d72
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/schedulers.rs
@@ -0,0 +1,45 @@
+#![allow(dead_code)]
+//! # Diffusion pipelines and models
+//!
+//! Noise schedulers can be used to set the trade-off between
+//! inference speed and quality.
+
+use candle::{Result, Tensor};
+
+/// This represents how beta ranges from its minimum value to the maximum
+/// during training.
+#[derive(Debug, Clone, Copy)]
+pub enum BetaSchedule {
+ /// Linear interpolation.
+ Linear,
+ /// Linear interpolation of the square root of beta.
+ ScaledLinear,
+ /// Glide cosine schedule
+ SquaredcosCapV2,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum PredictionType {
+ Epsilon,
+ VPrediction,
+ Sample,
+}
+
+/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+/// `(1-beta)` over time from `t = [0,1]`.
+///
+/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)`
+/// up to that part of the diffusion process.
+pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
+ let alpha_bar = |time_step: usize| {
+ f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
+ };
+ let mut betas = Vec::with_capacity(num_diffusion_timesteps);
+ for i in 0..num_diffusion_timesteps {
+ let t1 = i / num_diffusion_timesteps;
+ let t2 = (i + 1) / num_diffusion_timesteps;
+ betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
+ }
+ let betas_len = betas.len();
+ Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
+}
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
new file mode 100644
index 00000000..c250ed56
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -0,0 +1,212 @@
+#![allow(dead_code)]
+use crate::schedulers::PredictionType;
+use crate::{clip, ddim, unet_2d, vae};
+use candle::{DType, Device, Result};
+use candle_nn as nn;
+
+#[derive(Clone, Debug)]
+pub struct StableDiffusionConfig {
+ pub width: usize,
+ pub height: usize,
+ pub clip: clip::Config,
+ autoencoder: vae::AutoEncoderKLConfig,
+ unet: unet_2d::UNet2DConditionModelConfig,
+ scheduler: ddim::DDIMSchedulerConfig,
+}
+
+impl StableDiffusionConfig {
+ pub fn v1_5(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, true, 8),
+ bc(640, true, 8),
+ bc(1280, true, 8),
+ bc(1280, false, 8),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 768,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: false,
+ };
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ height
+ } else {
+ 512
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 512
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v1_5(),
+ autoencoder,
+ scheduler: Default::default(),
+ unet,
+ }
+ }
+
+ fn v2_1_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, true, 5),
+ bc(640, true, 10),
+ bc(1280, true, 20),
+ bc(1280, false, 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 1024,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ height
+ } else {
+ 768
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 768
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v2_1(),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn v2_1(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
+ Self::v2_1_(
+ sliced_attention_size,
+ height,
+ width,
+ PredictionType::VPrediction,
+ )
+ }
+
+ pub fn v2_1_inpaint(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
+ // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
+ // type being "epsilon" by default and not "v_prediction".
+ Self::v2_1_(
+ sliced_attention_size,
+ height,
+ width,
+ PredictionType::Epsilon,
+ )
+ }
+
+ pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
+ Ok(autoencoder)
+ }
+
+ pub fn build_unet(
+ &self,
+ unet_weights: &str,
+ device: &Device,
+ in_channels: usize,
+ ) -> Result<unet_2d::UNet2DConditionModel> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+ Ok(unet)
+ }
+
+ pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
+ ddim::DDIMScheduler::new(n_steps, self.scheduler)
+ }
+
+ pub fn build_clip_transformer(
+ &self,
+ clip_weights: &str,
+ device: &Device,
+ ) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
+ Ok(text_model)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
new file mode 100644
index 00000000..8ebd1876
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -0,0 +1,383 @@
+#![allow(dead_code)]
+//! 2D UNet Denoising Models
+//!
+//! The 2D Unet models take as input a noisy sample and the current diffusion
+//! timestep and return a denoised version of the input.
+use crate::embeddings::{TimestepEmbedding, Timesteps};
+use crate::unet_2d_blocks::*;
+use candle::{DType, Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone, Copy)]
+pub struct BlockConfig {
+ pub out_channels: usize,
+ pub use_cross_attn: bool,
+ pub attention_head_dim: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct UNet2DConditionModelConfig {
+ pub center_input_sample: bool,
+ pub flip_sin_to_cos: bool,
+ pub freq_shift: f64,
+ pub blocks: Vec<BlockConfig>,
+ pub layers_per_block: usize,
+ pub downsample_padding: usize,
+ pub mid_block_scale_factor: f64,
+ pub norm_num_groups: usize,
+ pub norm_eps: f64,
+ pub cross_attention_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ blocks: vec![
+ BlockConfig {
+ out_channels: 320,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 640,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: false,
+ attention_head_dim: 8,
+ },
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+#[derive(Debug)]
+enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+#[derive(Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: nn::Conv2d,
+ time_proj: Timesteps,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec<UNetDownBlock>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec<UNetUpBlock>,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ config: UNet2DConditionModelConfig,
+}
+
+impl UNet2DConditionModel {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UNet2DConditionModelConfig,
+ ) -> Result<Self> {
+ let n_blocks = config.blocks.len();
+ let b_channels = config.blocks[0].out_channels;
+ let bl_channels = config.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+
+ let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
+ let time_embedding =
+ TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
+
+ let vs_db = vs.pp("down_blocks");
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let in_channels = if i > 0 {
+ config.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+ let db_cfg = DownBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: i < n_blocks - 1,
+ downsample_padding: config.downsample_padding,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig {
+ downblock: db_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnDownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetDownBlock::CrossAttn(block))
+ } else {
+ let block = DownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ db_cfg,
+ )?;
+ Ok(UNetDownBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let mid_cfg = UNetMidBlock2DCrossAttnConfig {
+ resnet_eps: config.norm_eps,
+ output_scale_factor: config.mid_block_scale_factor,
+ cross_attn_dim: config.cross_attention_dim,
+ attn_num_head_channels: bl_attention_head_dim,
+ resnet_groups: Some(config.norm_num_groups),
+ use_linear_projection: config.use_linear_projection,
+ ..Default::default()
+ };
+ let mid_block = UNetMidBlock2DCrossAttn::new(
+ vs.pp("mid_block"),
+ bl_channels,
+ Some(time_embed_dim),
+ mid_cfg,
+ )?;
+
+ let vs_ub = vs.pp("up_blocks");
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[n_blocks - 1 - i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let prev_out_channels = if i > 0 {
+ config.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+ let in_channels = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ config.blocks[index].out_channels
+ };
+ let ub_cfg = UpBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: i < n_blocks - 1,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig {
+ upblock: ub_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnUpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetUpBlock::CrossAttn(block))
+ } else {
+ let block = UpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ ub_cfg,
+ )?;
+ Ok(UNetUpBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ b_channels,
+ config.norm_eps,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+ Ok(Self {
+ conv_in,
+ time_proj,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl UNet2DConditionModel {
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (bsize, _channels, height, width) = xs.dims4()?;
+ let device = xs.device();
+ let n_blocks = self.config.blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+ // 0. center input if necessary
+ let xs = if self.config.center_input_sample {
+ ((xs * 2.0)? - 1.0)?
+ } else {
+ xs.clone()
+ };
+ // 1. time
+ let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+ let emb = self.time_proj.forward(&emb)?;
+ let emb = self.time_embedding.forward(&emb)?;
+ // 2. pre-process
+ let xs = self.conv_in.forward(&xs)?;
+ // 3. down
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = _xs;
+ }
+
+ let new_down_block_res_xs =
+ if let Some(down_block_additional_residuals) = down_block_additional_residuals {
+ let mut v = vec![];
+ // A previous version of this code had a bug because of the addition being made
+ // in place via += hence modifying the input of the mid block.
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
+ v.push((&down_block_res_xs[i] + residuals)?)
+ }
+ v
+ } else {
+ down_block_res_xs
+ };
+ let mut down_block_res_xs = new_down_block_res_xs;
+
+ // 4. mid
+ let xs = self
+ .mid_block
+ .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
+ let xs = match mid_block_additional_residual {
+ None => xs,
+ Some(m) => (m + xs)?,
+ };
+ // 5. up
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
+ let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
+ if i < n_blocks - 1 && forward_upsample_size {
+ let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
+ upsample_size = Some((h, w))
+ }
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ &xs,
+ &res_xs,
+ Some(&emb),
+ upsample_size,
+ Some(encoder_hidden_states),
+ )?,
+ };
+ }
+ // 6. post-process
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
new file mode 100644
index 00000000..82d5fad5
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -0,0 +1,808 @@
+#![allow(dead_code)]
+//! 2D UNet Building Blocks
+//!
+use crate::attention::{
+ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+};
+use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct Downsample2D {
+ conv: Option<nn::Conv2d>,
+ padding: usize,
+}
+
+impl Downsample2D {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ use_conv: bool,
+ out_channels: usize,
+ padding: usize,
+ ) -> Result<Self> {
+ let conv = if use_conv {
+ let config = nn::Conv2dConfig { stride: 2, padding };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Some(conv)
+ } else {
+ None
+ };
+ Ok(Downsample2D { conv, padding })
+ }
+}
+
+impl Downsample2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match &self.conv {
+ None => xs.avg_pool2d((2, 2), (2, 2)),
+ Some(conv) => {
+ if self.padding == 0 {
+ let xs = xs
+ .pad_with_zeros(D::Minus1, 0, 1)?
+ .pad_with_zeros(D::Minus2, 0, 1)?;
+ conv.forward(&xs)
+ } else {
+ conv.forward(xs)
+ }
+ }
+ }
+ }
+}
+
+// This does not support the conv-transpose mode.
+#[derive(Debug)]
+struct Upsample2D {
+ conv: nn::Conv2d,
+}
+
+impl Upsample2D {
+ fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
+ let config = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl Upsample2D {
+ fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+ let xs = match size {
+ None => {
+ let (_bsize, _channels, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(2 * h, 2 * w)?
+ }
+ Some((h, w)) => xs.upsample_nearest2d(h, w)?,
+ };
+ self.conv.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownEncoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownEncoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownEncoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownEncoderBlock2DConfig,
+}
+
+impl DownEncoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DownEncoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ out_channels: Some(out_channels),
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let downsampler = if config.add_downsample {
+ let downsample = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+}
+
+impl DownEncoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.downsampler {
+ Some(downsampler) => downsampler.forward(&xs),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpDecoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpDecoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpDecoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpDecoderBlock2DConfig,
+}
+
+impl UpDecoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UpDecoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let upsampler = if config.add_upsample {
+ let upsample =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+}
+
+impl UpDecoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, None),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: Option<usize>,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+}
+
+impl Default for UNetMidBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: Some(1),
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2D {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DConfig,
+}
+
+impl UNetMidBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let attn_cfg = AttentionBlockConfig {
+ num_head_channels: config.attn_num_head_channels,
+ num_groups: resnet_groups,
+ rescale_output_factor: config.output_scale_factor,
+ eps: config.resnet_eps,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: usize,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+ pub cross_attn_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNetMidBlock2DCrossAttnConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: 1,
+ output_scale_factor: 1.,
+ cross_attn_dim: 1280,
+ sliced_attention_size: None, // Sliced attention disabled
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2DCrossAttn {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DCrossAttnConfig,
+}
+
+impl UNetMidBlock2DCrossAttn {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DCrossAttnConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let n_heads = config.attn_num_head_channels;
+ let attn_cfg = SpatialTransformerConfig {
+ depth: 1,
+ num_groups: resnet_groups,
+ context_dim: Some(config.cross_attn_dim),
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = SpatialTransformer::new(
+ vs_attns.pp(&index.to_string()),
+ in_channels,
+ n_heads,
+ in_channels / n_heads,
+ attn_cfg,
+ )?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownBlock2DConfig,
+}
+
+impl DownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: DownBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let downsampler = if config.add_downsample {
+ let downsampler = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut xs = xs.clone();
+ let mut output_states = vec![];
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, temb)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnDownBlock2DConfig {
+ pub downblock: DownBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnDownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ downblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnDownBlock2D {
+ downblock: DownBlock2D,
+ attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnDownBlock2DConfig,
+}
+
+impl CrossAttnDownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnDownBlock2DConfig,
+ ) -> Result<Self> {
+ let downblock = DownBlock2D::new(
+ vs.clone(),
+ in_channels,
+ out_channels,
+ temb_channels,
+ config.downblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.downblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.downblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ downblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut output_states = vec![];
+ let mut xs = xs.clone();
+ for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+ xs = resnet.forward(&xs, temb)?;
+ xs = attn.forward(&xs, encoder_hidden_states)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downblock.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpBlock2D {
+ pub resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpBlock2DConfig,
+}
+
+impl UpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: UpBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ temb_channels,
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let res_skip_channels = if i == config.num_layers - 1 {
+ in_channels
+ } else {
+ out_channels
+ };
+ let resnet_in_channels = if i == 0 {
+ prev_output_channels
+ } else {
+ out_channels
+ };
+ let in_channels = resnet_in_channels + res_skip_channels;
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let upsampler = if config.add_upsample {
+ let upsampler =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnUpBlock2DConfig {
+ pub upblock: UpBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnUpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ upblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+ pub upblock: UpBlock2D,
+ pub attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnUpBlock2DConfig,
+ ) -> Result<Self> {
+ let upblock = UpBlock2D::new(
+ vs.clone(),
+ in_channels,
+ prev_output_channels,
+ out_channels,
+ temb_channels,
+ config.upblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.upblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.upblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ upblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+ }
+ match &self.upblock.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
new file mode 100644
index 00000000..ef4dd956
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -0,0 +1,31 @@
+use candle::{Device, Result, Tensor};
+
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+ if steps < 1 {
+ candle::bail!("cannot use linspace with steps {steps} <= 1")
+ }
+ let delta = (stop - start) / (steps - 1) as f64;
+ let vs = (0..steps)
+ .map(|step| start + step as f64 * delta)
+ .collect::<Vec<_>>();
+ Tensor::from_vec(vs, steps, &Device::Cpu)
+}
+
+/// Saves an image to disk using the image crate, this expects an input with shape
+/// (c, width, height).
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, width, height) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, width, height)")
+ }
+ let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
new file mode 100644
index 00000000..7a10d932
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -0,0 +1,378 @@
+#![allow(dead_code)]
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input to a usually smaller latent space
+//! before expanding it back to its original shape. This results in the latent values
+//! compressing the original information.
+use crate::unet_2d_blocks::{
+ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
+ UpDecoderBlock2D, UpDecoderBlock2DConfig,
+};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone)]
+struct EncoderConfig {
+ // down_block_types: DownEncoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+ double_z: bool,
+}
+
+impl Default for EncoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ double_z: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Encoder {
+ conv_in: nn::Conv2d,
+ down_blocks: Vec<DownEncoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: EncoderConfig,
+}
+
+impl Encoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: EncoderConfig,
+ ) -> Result<Self> {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ config.block_out_channels[0],
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mut down_blocks = vec![];
+ let vs_down_blocks = vs.pp("down_blocks");
+ for index in 0..config.block_out_channels.len() {
+ let out_channels = config.block_out_channels[index];
+ let in_channels = if index > 0 {
+ config.block_out_channels[index - 1]
+ } else {
+ config.block_out_channels[0]
+ };
+ let is_final = index + 1 == config.block_out_channels.len();
+ let cfg = DownEncoderBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: !is_final,
+ downsample_padding: 0,
+ ..Default::default()
+ };
+ let down_block = DownEncoderBlock2D::new(
+ vs_down_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ down_blocks.push(down_block)
+ }
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ last_block_out_channels,
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out_channels = if config.double_z {
+ 2 * out_channels
+ } else {
+ out_channels
+ };
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ last_block_out_channels,
+ conv_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ down_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Encoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.conv_in.forward(xs)?;
+ for down_block in self.down_blocks.iter() {
+ xs = down_block.forward(&xs)?
+ }
+ let xs = self.mid_block.forward(&xs, None)?;
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderConfig {
+ // up_block_types: UpDecoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+}
+
+impl Default for DecoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Decoder {
+ conv_in: nn::Conv2d,
+ up_blocks: Vec<UpDecoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: DecoderConfig,
+}
+
+impl Decoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DecoderConfig,
+ ) -> Result<Self> {
+ let n_block_out_channels = config.block_out_channels.len();
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ last_block_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let mut up_blocks = vec![];
+ let vs_up_blocks = vs.pp("up_blocks");
+ let reversed_block_out_channels: Vec<_> =
+ config.block_out_channels.iter().copied().rev().collect();
+ for index in 0..n_block_out_channels {
+ let out_channels = reversed_block_out_channels[index];
+ let in_channels = if index > 0 {
+ reversed_block_out_channels[index - 1]
+ } else {
+ reversed_block_out_channels[0]
+ };
+ let is_final = index + 1 == n_block_out_channels;
+ let cfg = UpDecoderBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: !is_final,
+ ..Default::default()
+ };
+ let up_block = UpDecoderBlock2D::new(
+ vs_up_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ up_blocks.push(up_block)
+ }
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ config.block_out_channels[0],
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ config.block_out_channels[0],
+ out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ up_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Decoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
+ for up_block in self.up_blocks.iter() {
+ xs = up_block.forward(&xs)?
+ }
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+ pub block_out_channels: Vec<usize>,
+ pub layers_per_block: usize,
+ pub latent_channels: usize,
+ pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 1,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+pub struct DiagonalGaussianDistribution {
+ mean: Tensor,
+ std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+ pub fn new(parameters: &Tensor) -> Result<Self> {
+ let mut parameters = parameters.chunk(2, 1)?.into_iter();
+ let mean = parameters.next().unwrap();
+ let logvar = parameters.next().unwrap();
+ let std = (logvar * 0.5)?.exp()?;
+ Ok(DiagonalGaussianDistribution { mean, std })
+ }
+
+ pub fn sample(&self) -> Result<Tensor> {
+ let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+ &self.mean + &self.std * sample
+ }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+ encoder: Encoder,
+ decoder: Decoder,
+ quant_conv: nn::Conv2d,
+ post_quant_conv: nn::Conv2d,
+ pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: AutoEncoderKLConfig,
+ ) -> Result<Self> {
+ let latent_channels = config.latent_channels;
+ let encoder_cfg = EncoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ double_z: true,
+ };
+ let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+ let decoder_cfg = DecoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ };
+ let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+ let conv_cfg = Default::default();
+ let quant_conv = nn::conv2d(
+ 2 * latent_channels,
+ 2 * latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("quant_conv"),
+ )?;
+ let post_quant_conv = nn::conv2d(
+ latent_channels,
+ latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("post_quant_conv"),
+ )?;
+ Ok(Self {
+ encoder,
+ decoder,
+ quant_conv,
+ post_quant_conv,
+ config,
+ })
+ }
+
+ /// Returns the distribution in the latent space.
+ pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+ let xs = self.encoder.forward(xs)?;
+ let parameters = self.quant_conv.forward(&xs)?;
+ DiagonalGaussianDistribution::new(&parameters)
+ }
+
+ /// Takes as input some sampled values.
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.post_quant_conv.forward(xs)?;
+ self.decoder.forward(&xs)
+ }
+}
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index ee6ce90a..8d7afa69 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -7,7 +7,7 @@ description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index 079c60fd..9978b25e 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -7,7 +7,7 @@ description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
[dependencies]
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 0aab40cb..85d74b82 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -80,6 +80,7 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
+UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
@@ -95,6 +96,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
+UNARY_OP(__half, urecip_f16, recipg(x))
UNARY_OP(__half, uexp_f16, expg(x))
UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
@@ -113,6 +115,8 @@ UNARY_OP(float, ucopy_f32, x)
UNARY_OP(double, ucopy_f64, x)
UNARY_OP(float, uneg_f32, -x)
UNARY_OP(double, uneg_f64, -x)
+UNARY_OP(float, urecip_f32, recipg(x))
+UNARY_OP(double, urecip_f64, recipg(x))
UNARY_OP(float, uexp_f32, expg(x))
UNARY_OP(double, uexp_f64, expg(x))
UNARY_OP(float, ulog_f32, logg(x))
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index bb44acd3..6db9ccab 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -10,6 +10,7 @@ license.workspace = true
readme = "README.md"
[dependencies]
+accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@@ -20,5 +21,6 @@ anyhow = { workspace = true }
[features]
default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 8fbe7659..67a80417 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -48,3 +48,92 @@ impl Conv1d {
}
}
}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct Conv2dConfig {
+ pub padding: usize,
+ pub stride: usize,
+}
+
+impl Default for Conv2dConfig {
+ fn default() -> Self {
+ Self {
+ padding: 0,
+ stride: 1,
+ }
+ }
+}
+
+#[allow(dead_code)]
+#[derive(Debug)]
+pub struct Conv2d {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ config: Conv2dConfig,
+}
+
+impl Conv2d {
+ pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
+ Self {
+ weight,
+ bias,
+ config,
+ }
+ }
+
+ pub fn config(&self) -> &Conv2dConfig {
+ &self.config
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1, 1))?;
+ Ok(x.broadcast_add(&bias)?)
+ }
+ }
+ }
+}
+
+pub fn conv1d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv1dConfig,
+ vs: crate::VarBuilder,
+) -> Result<Conv1d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?;
+ let bound = 1. / (in_channels as f64).sqrt();
+ let init_bs = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
+ Ok(Conv1d::new(ws, Some(bs), cfg))
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv2dConfig,
+ vs: crate::VarBuilder,
+) -> Result<Conv2d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vs.get_or_init(
+ (out_channels, in_channels, kernel_size, kernel_size),
+ "weight",
+ init_ws,
+ )?;
+ let bound = 1. / (in_channels as f64).sqrt();
+ let init_bs = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
+ Ok(Conv2d::new(ws, Some(bs), cfg))
+}
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
new file mode 100644
index 00000000..ac77db4b
--- /dev/null
+++ b/candle-nn/src/group_norm.rs
@@ -0,0 +1,83 @@
+//! Group Normalization.
+//!
+//! This layer applies Group Normalization over a mini-batch of inputs.
+use candle::{DType, Result, Tensor};
+
+// This group norm version handles both weight and bias so removes the mean.
+#[derive(Debug)]
+pub struct GroupNorm {
+ weight: Tensor,
+ bias: Tensor,
+ eps: f64,
+ num_channels: usize,
+ num_groups: usize,
+}
+
+impl GroupNorm {
+ pub fn new(
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ num_groups: usize,
+ eps: f64,
+ ) -> Result<Self> {
+ if num_channels % num_groups != 0 {
+ candle::bail!(
+ "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+ )
+ }
+ Ok(Self {
+ weight,
+ bias,
+ eps,
+ num_channels,
+ num_groups,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x_shape = x.dims();
+ if x_shape.len() <= 2 {
+ candle::bail!("input rank for GroupNorm should be at least 3");
+ }
+ let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
+ let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+ if n_channels != self.num_channels {
+ candle::bail!(
+ "unexpected num-channels in GroupNorm ({n_channels} <> {}",
+ self.num_channels
+ )
+ }
+ let x_dtype = x.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+ let x = x.to_dtype(internal_dtype)?;
+ let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ let x = x.broadcast_sub(&mean_x)?;
+ let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+ let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let mut w_dims = vec![1; x_shape.len()];
+ w_dims[1] = n_channels;
+ let weight = self.weight.reshape(w_dims.clone())?;
+ let bias = self.bias.reshape(w_dims)?;
+ x_normed
+ .to_dtype(x_dtype)?
+ .reshape(x_shape)?
+ .broadcast_mul(&weight)?
+ .broadcast_add(&bias)
+ }
+}
+
+pub fn group_norm(
+ num_groups: usize,
+ num_channels: usize,
+ eps: f64,
+ vb: crate::VarBuilder,
+) -> Result<GroupNorm> {
+ let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
+ let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
+ GroupNorm::new(weight, bias, num_channels, num_groups, eps)
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 46a83800..ae955f56 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,8 +2,8 @@
// error type if needed or add some specialized cases on the candle-core side.
pub mod activation;
pub mod conv;
-pub mod dataset;
pub mod embedding;
+pub mod group_norm;
pub mod init;
pub mod layer_norm;
pub mod linear;
@@ -11,11 +11,11 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod var_builder;
-pub mod vision;
pub use activation::Activation;
-pub use conv::{Conv1d, Conv1dConfig};
+pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
pub use embedding::{embedding, Embedding};
+pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, LayerNorm};
pub use linear::{linear, linear_no_bias, Linear};
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 611c66d8..397674f3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -32,3 +32,13 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
Ok(log_sm)
}
+
+pub fn silu(xs: &Tensor) -> Result<Tensor> {
+ // TODO: Should we have a specialized op for this?
+ xs / (xs.neg()?.exp()? + 1.0)?
+}
+
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+ // TODO: Should we have a specialized op for this?
+ (xs.neg()?.exp()? + 1.0)?.recip()
+}
diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs
new file mode 100644
index 00000000..f3ef2455
--- /dev/null
+++ b/candle-nn/tests/group_norm.rs
@@ -0,0 +1,103 @@
+/* Equivalent PyTorch code.
+import torch
+from torch.nn.functional import group_norm
+t = torch.tensor(
+ [[[-0.3034, 0.2726, -0.9659],
+ [-1.1845, -1.3236, 0.0172],
+ [ 1.9507, 1.2554, -0.8625],
+ [ 1.0682, 0.3604, 0.3985],
+ [-0.4957, -0.4461, -0.9721],
+ [ 1.5157, -0.1546, -0.5596]],
+
+ [[-1.6698, -0.4040, -0.7927],
+ [ 0.3736, -0.0975, -0.1351],
+ [-0.9461, 0.5461, -0.6334],
+ [-1.0919, -0.1158, 0.1213],
+ [-0.9535, 0.1281, 0.4372],
+ [-0.2845, 0.3488, 0.5641]]])
+print(group_norm(t, num_groups=2))
+print(group_norm(t, num_groups=3))
+*/
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle::{Device, Tensor};
+use candle_nn::GroupNorm;
+mod test_utils;
+use test_utils::to_vec3_round;
+
+#[test]
+fn group_norm() -> Result<()> {
+ let device = &Device::Cpu;
+ let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
+ let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
+ let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
+ let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;
+
+ let input = Tensor::new(
+ &[
+ [
+ [-0.3034f32, 0.2726, -0.9659],
+ [-1.1845, -1.3236, 0.0172],
+ [1.9507, 1.2554, -0.8625],
+ [1.0682, 0.3604, 0.3985],
+ [-0.4957, -0.4461, -0.9721],
+ [1.5157, -0.1546, -0.5596],
+ ],
+ [
+ [-1.6698, -0.4040, -0.7927],
+ [0.3736, -0.0975, -0.1351],
+ [-0.9461, 0.5461, -0.6334],
+ [-1.0919, -0.1158, 0.1213],
+ [-0.9535, 0.1281, 0.4372],
+ [-0.2845, 0.3488, 0.5641],
+ ],
+ ],
+ device,
+ )?;
+ assert_eq!(
+ to_vec3_round(gn2.forward(&input)?, 4)?,
+ &[
+ [
+ [-0.1653, 0.3748, -0.7866],
+ [-0.9916, -1.1220, 0.1353],
+ [1.9485, 1.2965, -0.6896],
+ [1.2769, 0.3628, 0.4120],
+ [-0.7427, -0.6786, -1.3578],
+ [1.8547, -0.3022, -0.8252]
+ ],
+ [
+ [-1.9342, 0.0211, -0.5793],
+ [1.2223, 0.4945, 0.4365],
+ [-0.8163, 1.4887, -0.3333],
+ [-1.7960, -0.0392, 0.3875],
+ [-1.5469, 0.3998, 0.9561],
+ [-0.3428, 0.7970, 1.1845]
+ ]
+ ]
+ );
+ assert_eq!(
+ to_vec3_round(gn3.forward(&input)?, 4)?,
+ &[
+ [
+ [0.4560, 1.4014, -0.6313],
+ [-0.9901, -1.2184, 0.9822],
+ [1.4254, 0.6360, -1.7682],
+ [0.4235, -0.3800, -0.3367],
+ [-0.3890, -0.3268, -0.9862],
+ [2.1325, 0.0386, -0.4691]
+ ],
+ [
+ [-1.8797, 0.0777, -0.5234],
+ [1.2802, 0.5517, 0.4935],
+ [-1.0102, 1.5327, -0.4773],
+ [-1.2587, 0.4047, 0.8088],
+ [-1.9074, 0.1691, 0.7625],
+ [-0.6230, 0.5928, 1.0061]
+ ]
+ ]
+ );
+
+ Ok(())
+}
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index e5ebe953..89263fe0 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -16,8 +16,11 @@ doc = false
[dependencies]
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
-pyo3 = { version = "0.19.0", features = ["extension-module"] }
half = { workspace = true }
+pyo3 = { version = "0.19.0", features = ["extension-module"] }
+
+[build-dependencies]
+pyo3-build-config = "0.19"
[features]
default = []
diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md
index 1887f269..f716b092 100644
--- a/candle-pyo3/README.md
+++ b/candle-pyo3/README.md
@@ -1,5 +1,11 @@
-From the top level directory run:
+From the top level directory run the following for linux.
```
-cargo build --release --package candle-pyo3 && cp -f ./target/release/libcandle.so candle.so
+cargo build --profile=release-with-debug --package candle-pyo3 && cp -f ./target/release-with-debug/libcandle.so candle.so
+PYTHONPATH=. python3 candle-pyo3/test.py
+```bash
+
+ Or for macOS users:
+```bash
+cargo build --profile=release-with-debug --package candle-pyo3 && cp -f ./target/release-with-debug/libcandle.dylib candle.so
PYTHONPATH=. python3 candle-pyo3/test.py
```
diff --git a/candle-pyo3/build.rs b/candle-pyo3/build.rs
new file mode 100644
index 00000000..dace4a9b
--- /dev/null
+++ b/candle-pyo3/build.rs
@@ -0,0 +1,3 @@
+fn main() {
+ pyo3_build_config::add_extension_module_link_args();
+}
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 136f8a4f..1ff4db06 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,3 +1,4 @@
+// TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
@@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
-#[derive(Clone)]
+#[derive(Clone, Debug)]
+struct PyShape(Vec<usize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShape {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
+ Ok(PyShape(dims))
+ }
+}
+
+impl From<PyShape> for ::candle::Shape {
+ fn from(val: PyShape) -> Self {
+ val.0.into()
+ }
+}
+
+#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
@@ -23,21 +40,30 @@ impl std::ops::Deref for PyTensor {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[pyclass(name = "DType")]
struct PyDType(DType);
-impl<'source> FromPyObject<'source> for PyDType {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
- use std::str::FromStr;
- let dtype: &str = ob.extract()?;
- let dtype = DType::from_str(dtype)
- .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
- Ok(Self(dtype))
+#[pymethods]
+impl PyDType {
+ fn __repr__(&self) -> String {
+ format!("{:?}", self.0)
+ }
+
+ fn __str__(&self) -> String {
+ self.__repr__()
}
}
-impl ToPyObject for PyDType {
- fn to_object(&self, py: Python<'_>) -> PyObject {
- self.0.as_str().to_object(py)
+impl PyDType {
+ fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
+ use std::str::FromStr;
+ if let Ok(dtype) = ob.extract::<&str>(py) {
+ let dtype = DType::from_str(dtype)
+ .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
+ Ok(Self(dtype))
+ } else {
+ ob.extract(py)
+ }
}
}
@@ -206,8 +232,8 @@ impl PyTensor {
}
#[getter]
- fn dtype(&self, py: Python<'_>) -> PyObject {
- PyDType(self.0.dtype()).to_object(py)
+ fn dtype(&self) -> PyDType {
+ PyDType(self.0.dtype())
}
#[getter]
@@ -279,16 +305,15 @@ impl PyTensor {
Ok(Self(tensor))
}
- // TODO: Add a PyShape type?
- fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
- fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
}
- fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
@@ -351,7 +376,8 @@ impl PyTensor {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
- fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
+ fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
+ let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
@@ -381,11 +407,72 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, vs)
}
+#[pyfunction]
+#[pyo3(signature = (shape, *, device=None))]
+fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, device=None))]
+fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, dtype=None, device=None))]
+fn ones(
+ py: Python<'_>,
+ shape: PyShape,
+ dtype: Option<PyObject>,
+ device: Option<PyDevice>,
+) -> PyResult<PyTensor> {
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, dtype=None, device=None))]
+fn zeros(
+ py: Python<'_>,
+ shape: PyShape,
+ dtype: Option<PyObject>,
+ device: Option<PyDevice>,
+) -> PyResult<PyTensor> {
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
+ m.add_class::<PyDType>()?;
+ m.add("u8", PyDType(DType::U8))?;
+ m.add("u32", PyDType(DType::U32))?;
+ m.add("bf16", PyDType(DType::BF16))?;
+ m.add("f16", PyDType(DType::F16))?;
+ m.add("f32", PyDType(DType::F32))?;
+ m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
+ m.add_function(wrap_pyfunction!(ones, m)?)?;
+ m.add_function(wrap_pyfunction!(rand, m)?)?;
+ m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
+ m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())
}
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 8f906060..1711cdad 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,3 +1,18 @@
+import os
+import sys
+
+# The "import candle" statement below works if there is a "candle.so" file in sys.path.
+# Here we check for shared libraries that can be used in the build directory.
+BUILD_DIR = "./target/release-with-debug"
+so_file = BUILD_DIR + "/candle.so"
+if os.path.islink(so_file): os.remove(so_file)
+for lib_file in ["libcandle.dylib", "libcandle.so"]:
+ lib_file_ = BUILD_DIR + "/" + lib_file
+ if os.path.isfile(lib_file_):
+ os.symlink(lib_file, so_file)
+ sys.path.insert(0, BUILD_DIR)
+ break
+
import candle
t = candle.Tensor(42.0)
@@ -12,4 +27,9 @@ print(t+t)
t = t.reshape([2, 4])
print(t.matmul(t.t()))
+print(t.to_dtype(candle.u8))
print(t.to_dtype("u8"))
+
+t = candle.randn((5, 3))
+print(t)
+print(t.dtype)
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index a37cc12a..457c0776 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -10,6 +10,7 @@ license.workspace = true
readme = "README.md"
[dependencies]
+accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
hf-hub = { workspace = true}
candle-nn = { path = "../candle-nn", version = "0.1.0" }
@@ -20,5 +21,6 @@ wav = { workspace = true }
[features]
default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]