summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-23 12:58:55 +0100
committerGitHub <noreply@github.com>2023-08-23 12:58:55 +0100
commitaba1e90797e430f28eec13b14b76dd5355876f9c (patch)
tree16bcf7fb151715d3bcdbec2b5263922bd0bdd35a
parent4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff)
downloadcandle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions. * Avoid some unnecessary groups checks. * Move the tensor convolution bits. * Properh handling of groups. * Bump the crate version. * And add a changelog.
-rw-r--r--CHANGELOG.md13
-rw-r--r--Cargo.toml2
-rw-r--r--candle-core/Cargo.toml2
-rw-r--r--candle-core/examples/basics.rs2
-rw-r--r--candle-core/examples/cpu_benchmarks.rs4
-rw-r--r--candle-core/examples/cuda_basics.rs2
-rw-r--r--candle-core/src/conv.rs112
-rw-r--r--candle-core/src/tensor.rs70
-rw-r--r--candle-core/tests/conv_tests.rs14
-rw-r--r--candle-datasets/Cargo.toml4
-rw-r--r--candle-examples/Cargo.toml10
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs12
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs6
-rw-r--r--candle-examples/examples/stable-diffusion/vae.rs4
-rw-r--r--candle-examples/examples/whisper/model.rs2
-rw-r--r--candle-examples/examples/yolo-v3/darknet.rs6
-rw-r--r--candle-examples/examples/yolo-v8/main.rs6
-rw-r--r--candle-flash-attn/Cargo.toml6
-rw-r--r--candle-kernels/Cargo.toml2
-rw-r--r--candle-nn/Cargo.toml2
-rw-r--r--candle-nn/src/conv.rs18
-rw-r--r--candle-pyo3/Cargo.toml2
-rw-r--r--candle-transformers/Cargo.toml4
-rw-r--r--candle-wasm-examples/llama2-c/Cargo.toml4
-rw-r--r--candle-wasm-examples/whisper/Cargo.toml4
-rw-r--r--candle-wasm-examples/whisper/src/model.rs2
-rw-r--r--candle-wasm-examples/yolo/Cargo.toml4
-rw-r--r--candle-wasm-examples/yolo/src/model.rs6
30 files changed, 216 insertions, 113 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 00000000..7f997cb0
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,13 @@
+# Changelog
+This documents the main changes to the `candle` crate.
+
+## Unreleased
+### Added
+- Add a group parameter to convolutions
+ [566](https://github.com/huggingface/candle/pull/566).
+- New dtype: int64
+ [563](https://github.com/huggingface/candle/pull/563).
+- Handling of the GGUF file format.
+ [559](https://github.com/huggingface/candle/pull/559).
+
+## v0.1.2 - 2023-08-21
diff --git a/Cargo.toml b/Cargo.toml
index 7957c038..d391ee7a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,7 +16,7 @@ exclude = [
]
[workspace.package]
-version = "0.1.2"
+version = "0.1.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index b190c55e..3b3e4eb7 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.1.2", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.1.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index efce913a..9d4734de 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
- let res = inp.conv2d(&w, 0, 1);
+ let res = inp.conv2d(&w, 0, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index d7f60f81..1ebd9b75 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1)
+ d.0.conv1d(&d.1, 0, 1, 1)
}
const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1)
+ d.0.conv2d(&d.1, 0, 1, 1)
}
const ITERS: usize = 1;
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index 12febb60..ac435488 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
- let res = t.conv2d(&w, 1, 1)?;
+ let res = t.conv2d(&w, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index e3fea861..d4b7a76d 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,3 +1,5 @@
+use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: usize,
@@ -51,3 +53,113 @@ impl ParamsConv2D {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
+
+impl Tensor {
+ fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 1D convolution over the input tensor.
+ pub fn conv1d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
+ if c_in != c_in_k * groups {
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ }
+ .bt())?
+ }
+
+ let params = ParamsConv1D {
+ b_size,
+ l_in,
+ c_out,
+ c_in,
+ k_size,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv1d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv1d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 2D convolution over the input tensor.
+ pub fn conv2d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k * groups {
+ crate::bail!(
+ "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
+ )
+ }
+ let params = ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv2d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv2d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a4b9795b..46f9c53f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op {
}
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage<S: Into<Shape>>(
+pub(crate) fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: BackpropOp,
@@ -787,72 +787,6 @@ impl Tensor {
self.cmp(rhs, CmpOp::Le)
}
- /// Applies a 1D convolution over the input tensor.
- pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = self.dims3()?;
- if c_in != c_in_k {
- Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "the number of in-channels on the input doesn't match the kernel size",
- }
- .bt())?
- }
- let params = crate::conv::ParamsConv1D {
- b_size,
- l_in,
- c_out,
- c_in,
- k_size,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
- pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (b_size, c_in, i_h, i_w) = self.dims4()?;
- let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
- if c_in != c_in_k {
- crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
- }
- let params = crate::conv::ParamsConv2D {
- b_size,
- i_h,
- i_w,
- k_h,
- k_w,
- c_out,
- c_in,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
@@ -1920,7 +1854,7 @@ impl Tensor {
}
}
- fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+ pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index c777fec7..d09fa344 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -33,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
- let res = t.conv1d(&w, 0, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
@@ -52,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> {
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
- let res = t.conv1d(&w, 0, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -109,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -143,7 +143,7 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -162,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index 88b81311..d4a34b01 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 24ad47f2..bbd7c3b0 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.1.2" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
-candle-transformers = { path = "../candle-transformers", version = "0.1.2" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.2", optional = true }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.3" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
+candle-transformers = { path = "../candle-transformers", version = "0.1.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.3", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 9c966497..e7712bf3 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -274,14 +274,22 @@ impl EncodecConv1d {
in_c,
out_c,
kernel_size,
- Conv1dConfig { padding: 0, stride },
+ Conv1dConfig {
+ padding: 0,
+ stride,
+ groups: 1,
+ },
vb.pp("conv"),
)?,
NormType::None => conv1d(
in_c,
out_c,
kernel_size,
- Conv1dConfig { padding: 0, stride },
+ Conv1dConfig {
+ padding: 0,
+ stride,
+ groups: 1,
+ },
vb.pp("conv"),
)?,
};
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
index 94f436c8..172a9359 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -66,6 +66,7 @@ impl ResnetBlock2D {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
+ groups: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -79,6 +80,7 @@ impl ResnetBlock2D {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 0,
+ groups: 1,
};
Some(conv2d(
in_channels,
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 6f568113..eb2dbf10 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -112,8 +112,8 @@ impl UNet2DConditionModel {
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index b7adb2c0..65341e74 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -24,7 +24,11 @@ impl Downsample2D {
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
- let config = nn::Conv2dConfig { stride: 2, padding };
+ let config = nn::Conv2dConfig {
+ stride: 2,
+ padding,
+ ..Default::default()
+ };
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
index abba39fa..aa8e13a0 100644
--- a/candle-examples/examples/stable-diffusion/vae.rs
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -51,8 +51,8 @@ impl Encoder {
config: EncoderConfig,
) -> Result<Self> {
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = nn::conv2d(
in_channels,
@@ -182,8 +182,8 @@ impl Decoder {
let n_block_out_channels = config.block_out_channels.len();
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = nn::conv2d(
in_channels,
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 553bd93b..4ccc79f7 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -308,10 +308,12 @@ impl AudioEncoder {
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
+ groups: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
+ groups: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index d0392308..de8fcf09 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -128,7 +128,11 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
}
Some(_) | None => (None, true),
};
- let conv_cfg = candle_nn::Conv2dConfig { stride, padding };
+ let conv_cfg = candle_nn::Conv2dConfig {
+ stride,
+ padding,
+ groups: 1,
+ };
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
} else {
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index 616e04ed..3b9c1ce9 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -101,7 +101,11 @@ impl ConvBlock {
padding: Option<usize>,
) -> Result<Self> {
let padding = padding.unwrap_or(k / 2);
- let cfg = Conv2dConfig { padding, stride };
+ let cfg = Conv2dConfig {
+ padding,
+ stride,
+ groups: 1,
+ };
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn })
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index f88a88d5..b0efaf52 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
-version = "0.1.2"
+version = "0.1.3"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.1.3", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.1.2", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.1.3", features = ["cuda"] }
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index a3f55c3d..6144e2d5 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
-version = "0.1.2"
+version = "0.1.3"
edition = "2021"
description = "CUDA kernels for Candle"
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index b3e9c0bf..7cd1d7a2 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
safetensors = { workspace = true }
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index df9818ab..204402c3 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv1dConfig {
@@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -39,7 +41,12 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv1d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {
@@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv2dConfig {
@@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -90,7 +99,12 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv2d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 1a64cc17..45ab38c0 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -15,7 +15,7 @@ crate-type = ["cdylib"]
doc = false
[dependencies]
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 92b1137a..5c4c8860 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.1.2" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.3" }
intel-mkl-src = { workspace = true, optional = true }
rand = { workspace = true }
wav = { workspace = true }
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
index a43578cd..370708bd 100644
--- a/candle-wasm-examples/llama2-c/Cargo.toml
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 5d777011..f404af55 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 3470c3d6..aea993f5 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -295,10 +295,12 @@ impl AudioEncoder {
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
+ groups: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
+ groups: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml
index ef9498ee..b565c04b 100644
--- a/candle-wasm-examples/yolo/Cargo.toml
+++ b/candle-wasm-examples/yolo/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.1.2" }
+candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.1.3" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index 50fd100c..7e40fcfc 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -97,7 +97,11 @@ impl ConvBlock {
padding: Option<usize>,
) -> Result<Self> {
let padding = padding.unwrap_or(k / 2);
- let cfg = Conv2dConfig { padding, stride };
+ let cfg = Conv2dConfig {
+ padding,
+ stride,
+ groups: 1,
+ };
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn })