From 3aac1047fec43a4d756ae4e60a8ae82f7c3e636e Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 10:52:34 +0100
Subject: Sketch the conv1d op.

---
 candle-core/src/backprop.rs           |  8 +++++++-
 candle-core/src/cpu_backend.rs        | 11 +++++++++++
 candle-core/src/cuda_backend.rs       | 11 +++++++++++
 candle-core/src/dummy_cuda_backend.rs | 11 +++++++++++
 candle-core/src/op.rs                 |  8 ++++++++
 candle-core/src/storage.rs            | 27 +++++++++++++++++++++++++++
 candle-core/src/tensor.rs             | 22 ++++++++++++++++++++++
 7 files changed, 97 insertions(+), 1 deletion(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 45448505..a44f732f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -33,7 +33,12 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
-                Op::Add(lhs, rhs)
+                Op::Conv1D {
+                    arg: lhs,
+                    kernel: rhs,
+                    ..
+                }
+                | Op::Add(lhs, rhs)
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
@@ -147,6 +152,7 @@ impl Tensor {
                         let f_grad = pred.where_cond(&zeros, &grad)?;
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
+                    Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
                     }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0871175f..ed3a5998 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -627,6 +627,17 @@ impl CpuStorage {
         WCond(pred, layout).map(t, t_l, f, f_l)
     }

+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
         let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 0c87004b..ec69688c 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -801,6 +801,17 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }

+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index b025eeab..eca5961b 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -100,6 +100,17 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _padding: usize,
+        _stride: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 860be0b3..ee57b325 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -12,6 +12,14 @@ pub(crate) enum Op {
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),

+    #[allow(dead_code)]
+    Conv1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        stride: usize,
+    },
+
     Cat(Vec<Tensor>, usize),

     #[allow(dead_code)] // add is currently unused.
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 4e630a58..235080c0 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -144,6 +144,33 @@ impl Storage {
         }
     }

+    pub(crate) fn conv1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        padding: usize,
+        stride: usize,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv1d")?;
+        self.same_dtype(kernel, "conv1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv1d",
+            }),
+        }
+    }
+
     pub(crate) fn where_cond(
         &self,
         layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a468d879..26d44718 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -432,6 +432,28 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }

+    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+        let storage = self.storage.conv1d(
+            self.layout(),
+            &kernel.storage,
+            kernel.layout(),
+            padding,
+            stride,
+        )?;
+        let op = if self.track_op() || kernel.track_op() {
+            Some(Op::Conv1D {
+                arg: self.clone(),
+                kernel: kernel.clone(),
+                padding,
+                stride,
+            })
+        } else {
+            None
+        };
+        let dims = self.dims();
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
-- cgit v1.2.3
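This first commit is only the plumbing: `Tensor::conv1d` records an `Op::Conv1D` node so the backward pass can later be wired up, `Storage::conv1d` dispatches on the device, and the CPU/CUDA backends are still `todo!()` stubs; the output even reuses the input dims for now. The intended call shape, as a minimal sketch on top of this API (the `(batch, c_in, l_in)` / `(c_out, c_in, k_size)` layout is only formalized in the next commit, and constructors such as `Tensor::zeros` are assumed from the rest of the crate):

    use candle::{DType, Device, Result, Tensor};

    fn conv1d_smoke_test() -> Result<()> {
        let dev = Device::Cpu;
        // (batch, c_in, l_in) input and (c_out, c_in, k_size) kernel.
        let x = Tensor::zeros((2, 3, 1000), DType::F32, &dev)?;
        let w = Tensor::zeros((4, 3, 5), DType::F32, &dev)?;
        // padding = 2, stride = 1: with the stubs above this still panics on todo!(),
        // but the expected output shape is (2, 4, 1000) once the kernels land.
        let y = x.conv1d(&w, 2, 1)?;
        println!("{:?}", y.dims());
        Ok(())
    }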
From a424d95473ea9268ffb1dde4d73ce0cff9904845 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 11:15:45 +0100
Subject: Add more of the conv1d op.

---
 candle-core/src/conv.rs                  | 24 ++++++++++++++++++++++++
 candle-core/src/cpu_backend.rs           |  3 +--
 candle-core/src/cuda_backend.rs          |  3 +--
 candle-core/src/dummy_cuda_backend.rs    |  3 +--
 candle-core/src/lib.rs                   |  1 +
 candle-core/src/storage.rs               |  7 +++----
 candle-core/src/tensor.rs                | 27 ++++++++++++++++++++-------
 candle-examples/examples/whisper/main.rs |  3 +--
 8 files changed, 52 insertions(+), 19 deletions(-)
 create mode 100644 candle-core/src/conv.rs

(limited to 'candle-core/src')

diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
new file mode 100644
index 00000000..90bb5229
--- /dev/null
+++ b/candle-core/src/conv.rs
@@ -0,0 +1,24 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+    pub(crate) b_size: Option<usize>,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+    pub(crate) fn l_out(&self, l_in: usize) -> usize {
+        let dilation = 1;
+        (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+    }
+
+    pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
+        let l_out = self.l_out(l_in);
+        match self.b_size {
+            None => vec![self.c_out, l_out],
+            Some(n) => vec![n, self.c_out, l_out],
+        }
+    }
+}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index ed3a5998..54002184 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -632,8 +632,7 @@ impl CpuStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index ec69688c..917655fc 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -806,8 +806,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         todo!()
     }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index eca5961b..0dbd8d54 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -105,8 +105,7 @@ impl CudaStorage {
         _l: &Layout,
         _kernel: &Self,
         _kernel_l: &Layout,
-        _padding: usize,
-        _stride: usize,
+        _params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0d4c2a8d..2365a34d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -1,4 +1,5 @@
 mod backprop;
+mod conv;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 235080c0..53ea1544 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -149,18 +149,17 @@ impl Storage {
         l: &Layout,
         kernel: &Self,
         kernel_l: &Layout,
-        padding: usize,
-        stride: usize,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         self.same_device(kernel, "conv1d")?;
         self.same_dtype(kernel, "conv1d")?;
         match (self, &kernel) {
             (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cpu(s))
             }
             (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 26d44718..590b81c4 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -433,13 +433,26 @@ impl Tensor {
     }

     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let storage = self.storage.conv1d(
-            self.layout(),
-            &kernel.storage,
-            kernel.layout(),
+        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+        let (b_size, c_in, l_in) = match *self.dims() {
+            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+            [c_in, l_in] => (None, c_in, l_in),
+            _ => todo!("proper error message"),
+        };
+        if c_in != c_in_k {
+            todo!("proper error message")
+        }
+        let params = crate::conv::ParamsConv1D {
+            b_size,
+            c_out,
+            c_in,
+            k_size,
             padding,
             stride,
-        )?;
+        };
+        let storage =
+            self.storage
+                .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -450,8 +463,8 @@ impl Tensor {
         } else {
             None
         };
-        let dims = self.dims();
-        Ok(from_storage(storage, dims, op, false))
+        let out_dims = params.out_dims(l_in);
+        Ok(from_storage(storage, out_dims, op, false))
     }

     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 75ab2189..a380d30e 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -236,8 +236,7 @@ impl Conv1D {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let (bsize, _, _) = x.shape().r3()?;
         let w = self.weight.broadcast_left(bsize)?.t()?;
-        // TODO: Add the conv1d operation
-        let x = x.matmul(&w)?;
+        let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),
-- cgit v1.2.3
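The new `ParamsConv1D` gathers everything the backends will need, and `l_out` is the standard convolution output-length formula `(l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1` with dilation pinned to 1. `Tensor::conv1d` now checks the `(c_out, c_in, k_size)` kernel against the input channels and derives the real output shape from the params instead of reusing the input dims. A quick, self-contained check of the formula with the kind of kernel the Whisper encoder uses (size-3 kernels with padding 1; the concrete lengths below are illustrative, not taken from the patch):

    fn l_out(l_in: usize, k_size: usize, padding: usize, stride: usize) -> usize {
        let dilation = 1;
        (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1
    }

    fn main() {
        // stride 1 preserves the length: 3000 frames in, 3000 out.
        assert_eq!(l_out(3000, 3, 1, 1), 3000);
        // stride 2 halves it: 3000 frames become 1500.
        assert_eq!(l_out(3000, 3, 1, 2), 1500);
    }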
From 950b4af49e56b640b87eb273e839b2fd466e1424 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 11:29:28 +0100
Subject: Proper conv1d dispatch.

---
 candle-core/src/conv.rs        | 11 +++++++----
 candle-core/src/cpu_backend.rs | 30 +++++++++++++++++++++++++-----
 candle-core/src/tensor.rs      |  3 ++-
 3 files changed, 34 insertions(+), 10 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 90bb5229..041bb6fb 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,9 @@
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub(crate) struct ParamsConv1D {
     pub(crate) b_size: Option<usize>,
+    // Maybe we should have a version without l_in as this bit depends on the input and not only on
+    // the weights.
+    pub(crate) l_in: usize,
     pub(crate) c_out: usize,
     pub(crate) c_in: usize,
     pub(crate) k_size: usize,
@@ -9,13 +12,13 @@ pub(crate) struct ParamsConv1D {
 }

 impl ParamsConv1D {
-    pub(crate) fn l_out(&self, l_in: usize) -> usize {
+    pub(crate) fn l_out(&self) -> usize {
         let dilation = 1;
-        (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
     }

-    pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
-        let l_out = self.l_out(l_in);
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
         match self.b_size {
             None => vec![self.c_out, l_out],
             Some(n) => vec![n, self.c_out, l_out],
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 54002184..718b071c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -202,6 +202,26 @@ fn copy_strided_src_(
     }
 }

+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+
+impl<'a> Map2 for Conv1D<'a> {
+    const OP: &'static str = "conv1d";
+    fn f<T: 'static + num_traits::Num + Copy>(
+        &self,
+        _inp: &[T],
+        _inp_l: &Layout,
+        _k: &[T],
+        _k_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let p = self.0;
+        let l_out = p.l_out();
+        let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+        let dst = vec![T::zero(); out_elems];
+        // TODO: actually implement the ops.
+        Ok(dst)
+    }
+}
+
 struct MatMul((usize, usize, usize, usize));

 impl Map2 for MatMul {
@@ -629,12 +649,12 @@ impl CpuStorage {

     pub(crate) fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConv1D,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        todo!()
+        Conv1D(params).map(self, l, kernel, kernel_l)
     }

     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 590b81c4..25ab0a9b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -444,6 +444,7 @@ impl Tensor {
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
+            l_in,
             c_out,
             c_in,
             k_size,
@@ -463,7 +464,7 @@ impl Tensor {
         } else {
             None
         };
-        let out_dims = params.out_dims(l_in);
+        let out_dims = params.out_dims();
         Ok(from_storage(storage, out_dims, op, false))
     }

-- cgit v1.2.3

From b3d4d0fd0f7c0115cf5c8cea60094abef5536e56 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 13:50:41 +0100
Subject: Very inefficient conv1d implementation.

---
 candle-core/src/cpu_backend.rs | 53 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 45 insertions(+), 8 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 718b071c..4eb57bc7 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -206,18 +206,55 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

 impl<'a> Map2 for Conv1D<'a> {
     const OP: &'static str = "conv1d";
-    fn f<T: 'static + num_traits::Num + Copy>(
+    fn f<T: 'static + num_traits::NumAssign + Copy>(
         &self,
-        _inp: &[T],
-        _inp_l: &Layout,
-        _k: &[T],
-        _k_l: &Layout,
+        inp: &[T],
+        inp_l: &Layout,
+        k: &[T],
+        k_l: &Layout,
     ) -> Result<Vec<T>> {
+        // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
         let p = self.0;
+        let inp = &inp[inp_l.start_offset()..];
+        let k = &k[k_l.start_offset()..];
+        let inp_stride = inp_l.stride();
+        let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
+            (inp_stride[0], &inp_stride[1..])
+        } else {
+            (0, inp_stride) // This value never gets used anyway
+        };
+        let k_stride = k_l.stride();
+        let k_over_2 = p.k_size / 2;
         let l_out = p.l_out();
-        let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
-        let dst = vec![T::zero(); out_elems];
-        // TODO: actually implement the ops.
+        let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+        let mut dst = vec![T::zero(); dst_elems];
+        // The output shape is [b_size, c_out, l_out]
+        for b_idx in 0..p.b_size.unwrap_or(1) {
+            let inp_idx = b_idx * inp_stride0;
+            let dst_idx = b_idx * p.c_out * l_out;
+            for dst_c_idx in 0..p.c_out {
+                let dst_idx = dst_idx + dst_c_idx * l_out;
+                for dst_l in 0..l_out {
+                    let dst_idx = dst_idx + dst_l;
+                    let mut d = T::zero();
+                    for offset in 0..p.k_size {
+                        // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
+                        if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
+                            let src_l = dst_l + offset - k_over_2;
+                            for src_c_idx in 0..p.c_in {
+                                let inp_idx =
+                                    inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
+                                let k_idx = dst_c_idx * k_stride[0]
+                                    + src_c_idx * k_stride[1]
+                                    + offset * k_stride[2];
+                                d += inp[inp_idx] * k[k_idx]
+                            }
+                        }
+                    }
+                    dst[dst_idx] = d
+                }
+            }
+        }
         Ok(dst)
     }
 }
-- cgit v1.2.3
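The loop above is deliberately naive: for every output position it walks all kernel taps and input channels, using the layout strides so non-contiguous inputs still work, and it skips taps that fall outside `[0, l_in)`, which behaves like zero padding. Two things are worth noting: the window is centred with `k_size / 2` rather than with `p.padding`, so this version effectively assumes `padding == k_size / 2` (true for Whisper's size-3 kernels with padding 1), and the window start ignores `p.stride`, which the next commit fixes. A stripped-down sketch of the same algorithm on contiguous buffers, with the stride handling already folded in (illustrative only, not the candle code):

    // Naive conv1d: `inp` is (c_in, l_in) row-major, `k` is (c_out, c_in, k_size),
    // the result is (c_out, l_out). Taps outside the input are skipped, which
    // matches zero padding of k_size / 2.
    fn conv1d_naive(
        inp: &[f32],
        k: &[f32],
        (c_in, l_in): (usize, usize),
        (c_out, k_size): (usize, usize),
        stride: usize,
    ) -> Vec<f32> {
        let padding = k_size / 2;
        let l_out = (l_in + 2 * padding - (k_size - 1) - 1) / stride + 1;
        let mut dst = vec![0f32; c_out * l_out];
        for dst_c in 0..c_out {
            for dst_l in 0..l_out {
                let mut d = 0f32;
                for offset in 0..k_size {
                    // Position of this tap in the (virtually) padded input.
                    let src_l_plus = stride * dst_l + offset;
                    if src_l_plus >= padding && src_l_plus < padding + l_in {
                        let src_l = src_l_plus - padding;
                        for src_c in 0..c_in {
                            d += inp[src_c * l_in + src_l]
                                * k[(dst_c * c_in + src_c) * k_size + offset];
                        }
                    }
                }
                dst[dst_c * l_out + dst_l] = d;
            }
        }
        dst
    }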
From 459e2e1ae34c624ae9f54aac5430f9ed7e1666f3 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 15:05:04 +0100
Subject: Properly handle the stride in conv1d.

---
 candle-core/src/cpu_backend.rs           |  5 +++--
 candle-examples/examples/whisper/main.rs | 13 ++++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4eb57bc7..b2345756 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -238,9 +238,10 @@ impl<'a> Map2 for Conv1D<'a> {
                     let dst_idx = dst_idx + dst_l;
                     let mut d = T::zero();
                     for offset in 0..p.k_size {
+                        let src_l_plus = p.stride * dst_l + offset;
                         // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
-                        if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
-                            let src_l = dst_l + offset - k_over_2;
+                        if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
+                            let src_l = src_l_plus - k_over_2;
                             for src_c_idx in 0..p.c_in {
                                 let inp_idx =
                                     inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 839dfc13..d119b6a7 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -435,12 +435,7 @@ impl AudioEncoder {
         };
         let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
         let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
-        let positional_embedding = if true {
-            vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
-        } else {
-            /* The positional embeddings could be regenerated via the following. */
-            sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
-        };
+        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
         let blocks = (0..cfg.n_audio_layer)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
@@ -567,7 +562,11 @@ fn main() -> Result<()> {
     let model = Whisper::load(&vb, &cfg)?;

     let logits = model.forward(&mel, &tokens)?;
-    println!("{logits}");
+    println!("tokens\n{tokens}");
+    println!("logits:\n{logits}");
     println!("python logits: {}", input.tensor("dec", &device)?);
+    let enc = model.encoder.forward(&mel)?;
+    println!("encoder:\n{enc}");
+    println!("python enc: {}", input.tensor("enc", &device)?);
     Ok(())
 }
-- cgit v1.2.3

From 6d1e79d3782c93c7f6a097ce71a91a9a277e52ed Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 5 Jul 2023 06:42:29 +0100
Subject: Bugfix for to_scalar (use the proper start offset).

---
 candle-core/src/tensor.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 25ab0a9b..95f663f0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -326,7 +326,7 @@ impl Tensor {
         }
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
-            Ok::<_, Error>(data[0])
+            Ok::<_, Error>(data[self.layout().start_offset()])
         };
         match self.storage.as_ref() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
-- cgit v1.2.3
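The `to_scalar` fix is small but easy to trip over: a candle tensor is a `Layout` view over shared storage, so a rank-0 tensor obtained by slicing a larger one can start at a non-zero offset into the buffer, and reading `data[0]` silently returns whatever element happens to sit at the front of the allocation. A minimal model of the bug, in plain Rust rather than the candle API:

    // A tensor view reduced to its essentials: a shared buffer plus a start offset.
    struct View<'a> {
        buffer: &'a [f32],
        start: usize,
    }

    fn to_scalar_buggy(v: &View) -> f32 {
        v.buffer[0] // always the first element of the allocation
    }

    fn to_scalar_fixed(v: &View) -> f32 {
        v.buffer[v.start] // honour the view's start offset
    }

    fn main() {
        let data = [0f32, 1.0, 2.0, 3.0];
        let scalar_view = View { buffer: &data, start: 2 };
        assert_eq!(to_scalar_buggy(&scalar_view), 0.0); // wrong element, no error
        assert_eq!(to_scalar_fixed(&scalar_view), 2.0);
    }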