use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;

const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator
// that intercepts OOM errors so we can return a proper error instead of panicking.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

#[derive(Debug, Clone)]
pub struct CpuDevice;

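/// Applies a dtype-polymorphic function to a single storage buffer while
/// preserving its dtype: `map` dispatches on the concrete `CpuStorage` variant
/// and re-wraps the result of `f` in the same variant.
///
/// A minimal sketch of an implementation (`Sqr` is a hypothetical example, not
/// part of this module):
///
/// ```ignore
/// struct Sqr;
/// impl Map1 for Sqr {
///     fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
///         // Walk the buffer according to `layout` and square each element.
///         Ok(unary_map(vs, layout, |v| v * v))
///     }
/// }
/// ```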
pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
        }
    }
}

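/// Like [`Map1`] but the implementation chooses the output storage itself: `f`
/// receives a `wrap` closure that re-wraps a `Vec<T>` in the input dtype's
/// variant, and may ignore it to return a different variant instead (e.g. the
/// `U32` indices produced by the arg-reductions in `ReduceIndex` below).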
pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        vs: &[T],
        layout: &Layout,
        wrap: W,
    ) -> Result<CpuStorage>;

    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
        }
    }
}

type C = CpuStorage;
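/// Applies a dtype-polymorphic binary function to two storages that must share
/// the same dtype; a mismatch yields a `DTypeMismatchBinaryOp` error tagged
/// with `OP`.
///
/// A minimal sketch of an implementation (`Plus` is a hypothetical example,
/// not part of this module):
///
/// ```ignore
/// struct Plus;
/// impl Map2 for Plus {
///     const OP: &'static str = "plus";
///     fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>> {
///         // Defer the strided/broadcast traversal to binary_map.
///         Ok(binary_map(l1, l2, v1, v2, |x, y| x + y))
///     }
/// }
/// ```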
pub trait Map2 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

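/// Like [`Map2`] but always produces a `U8` storage regardless of the input
/// dtype, which is what comparison operators need for their boolean masks.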
pub trait Map2U8 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

struct Cmp(CmpOp);
impl Map2U8 for Cmp {
    const OP: &'static str = "cmp";
    #[inline(always)]
    fn f<T: WithDType>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<u8>> {
        let dst = match self.0 {
            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
        };
        Ok(dst)
    }
}

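/// Element-wise select: `WCond(pred, pred_layout)` picks from the "true"
/// buffer where the integer predicate is non-zero and from the "false" buffer
/// elsewhere, with a fast path when all three layouts are contiguous.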
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<'a, I: IntDType> Map2 for WCond<'a, I> {
    const OP: &'static str = "where";
    #[inline(always)]
    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
        let vs = match (
            self.1.contiguous_offsets(),
            t_l.contiguous_offsets(),
            f_l.contiguous_offsets(),
        ) {
            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
                let pred = &self.0[o1..o2];
                let t = &t[o_t1..o_t2];
                let f = &f[o_f1..o_f2];
                pred.iter()
                    .zip(t.iter().zip(f.iter()))
                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                    .collect::<Vec<_>>()
            }
            _ => self
                .1
                .strided_index()
                .zip(t_l.strided_index().zip(f_l.strided_index()))
                .map(|(i_p, (i_t, i_f))| {
                    if self.0[i_p].is_true() {
                        t[i_t]
                    } else {
                        f[i_f]
                    }
                })
                .collect::<Vec<_>>(),
        };
        Ok(vs)
    }
}

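/// Min/max reduction over the dimension `reduce_dim_index`, returning either
/// the reduced values or, when `return_index` is set, the `u32` positions of
/// those values along the reduced dimension (argmin/argmax).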
struct ReduceIndex {
    reduce_dim_index: usize,
    use_min: bool,
    return_index: bool,
}

impl ReduceIndex {
    // The accumulated value (and its index) get replaced whenever
    // f(current_value, s[i]) returns true.
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
        let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        // start_src_i = p * reduce_dim_stride + q
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        // SAFETY: all dst_len values have been written by the loops above.
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}

impl Map1Any for ReduceIndex {
    #[inline(always)]
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        src: &[T],
        src_l: &Layout,
        wrap: W,
    ) -> Result<CpuStorage> {
        if src_l.shape().elem_count() == 0 {
            Err(Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dst = match (self.return_index, self.use_min) {
            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
            (true, true) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
            }
            (true, false) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
            }
        };
        Ok(dst)
    }
}

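/// Sum reduction over `reduce_dims`; the reduced dimensions keep size 1 in
/// `dst_shape`, and `reduce_dims_and_stride` pairs each reduced dimension's
/// size with its row-major stride in the source shape.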
struct ReduceSum<'a> {
    dst_shape: &'a Shape,
    reduce_dims: &'a [usize],
    reduce_dims_and_stride: Vec<(usize, usize)>,
}

impl<'a> ReduceSum<'a> {
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                // Handle the case where we reduce over the last dimensions separately as it is
                // fairly common and easy to optimize. This relies on the layout being contiguous!
                // reduce_dims is sorted; check whether it covers the trailing dims a..=n-1.
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following into a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following into a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}

impl<'a> Map1 for ReduceSum<'a> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}

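/// Maps `f` over `vs` following `layout`, producing the results in row-major
/// order; contiguous layouts are processed as a single slice while strided
/// layouts fall back to a block-by-block walk.
///
/// A usage sketch, assuming a contiguous layout for a length-3 vector:
///
/// ```ignore
/// let layout = Layout::contiguous((3,));
/// let doubled = unary_map(&[1f32, 2., 3.], &layout, |v| v * 2.);
/// assert_eq!(doubled, vec![2f32, 4., 6.]);
/// ```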
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}

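/// Like [`unary_map`] but takes an additional vectorized closure `f_vec` that
/// fills a whole output slice from an input slice in one call; `f_vec` must
/// write every element of its destination since the output buffer starts out
/// uninitialized. Single-element blocks still go through the scalar `f`.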
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}

/// Maps the binary function `f` over two buffers, following their respective
/// layouts; broadcast patterns are handled through `offsets_b` when one side
/// is contiguous and the other is a (possibly broadcast) strided view.
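///
/// A usage sketch, assuming two contiguous layouts of the same shape:
///
/// ```ignore
/// let l = Layout::contiguous((2,));
/// let sums = binary_map(&l, &l, &[1f32, 2.], &[10f32, 20.], |x, y| x + y);
/// assert_eq!(sums, vec![11f32, 22.]);
/// ```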
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

/// Similar to [`binary_map`] but with a vectorized variant `f_vec` used on the
/// contiguous fast paths; `f_vec(lhs, rhs, dst)` must fill `dst` entirely since
/// the output buffer starts out uninitialized.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

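/// Element-wise affine transform computing `v * mul + add`, with both `f64`
/// parameters converted to the storage dtype.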
struct Affine(f64, f64);

impl Map1 for Affine {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let mul = T::from_f64(self.0);
        let add = T::from_f64(self.1);
        Ok(unary_map(vs, layout, |v| v * mul + add))
    }
}

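/// 2D average pooling with kernel `(k_h, k_w)` and stride `(s_h, s_w)` and no
/// padding; each spatial output dimension is `(in - k) / s + 1`.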
struct AvgPool2D((usize, usize), (usize, usize));

impl Map1 for AvgPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        let scale = 1f64 / (k_h * k_w) as f64;
        let scale = T::from_f64(scale);
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut sum = T::zero();
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                sum += src[src_index + m * stride_h + n * stride_w]
                            }
                        }
                        dst[h_idx * w_out + w_idx] = sum * scale;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct MaxPool2D((usize, usize), (usize, usize));

impl Map1 for MaxPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut largest =
                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                if largest < src[src_index + m * stride_h + n * stride_w] {
                                    largest = src[src_index + m * stride_h + n * stride_w]
                                }
                            }
                        }
                        dst[h_idx * w_out + w_idx] = largest;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest1D(usize);

impl Map1 for UpsampleNearest1D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // TODO: Specialized implementation for the case dst_sz == 2 * src_sz?
        let dst_sz = self.0;
        let (b_sz, c, src_sz) = layout.shape().dims3()?;
        let stride = layout.stride();
        let stride_sz = stride[2];
        let src_index = layout.start_offset();
        let scale_sz = src_sz as f64 / dst_sz as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
        let src_idxs = (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_sz..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_sz..];
                let src_index = src_index + c_idx * stride[1];
                for (idx, src_idx) in src_idxs.iter().enumerate() {
                    dst[idx] = src[src_index + src_idx * stride_sz]
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest2D(usize, usize);

impl Map1 for UpsampleNearest2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // TODO: Specialized implementation for the case dst_h == 2 * src_h, dst_w == 2 * src_w?
        let (dst_h, dst_w) = (self.0, self.1);
        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let src_index = layout.start_offset();
        let scale_h = src_h as f64 / dst_h as f64;
        let scale_w = src_w as f64 / dst_w as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
        let src_h_idxs = (0..dst_h)
            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
            .collect::<Vec<_>>();
        let src_w_idxs = (0..dst_w)
            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_h * dst_w..];
                let src_index = src_index + c_idx * stride[1];
                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
                        dst[h_idx * dst_w + w_idx] = src[src_index]
                    }
                }
            }
        }
        Ok(dst)
    }
}

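/// Gather along `dim`: `dst[..., i, ...] = src[..., ids[..., i, ...], ...]`,
/// mirroring `torch.gather`; both `src` and `ids` must be contiguous.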
struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    let index = ids[dst_idx].as_usize();
                    if index >= src_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim_len,
                            op: "gather",
                        }
                        .bt())?
                    }
                    let src_idx = start_src_idx + index * src_right_len + right_i;
                    dst[dst_idx] = src[src_idx]
                }
            }
        }
        Ok(dst)
    }
}

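/// Index-select along `dim` using a 1D tensor of indices, mirroring
/// `torch.index_select`: output dimension `dim` has one entry per index.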
struct IndexSelect<'a, T: IntDType> {
    ids: &'a [T],
    ids_l: &'a Layout,
    dim: usize,
}

impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
                if index >= src_dim {
                    Err(Error::InvalidIndex {
                        index,
                        size: src_dim,
                        op: "index-select",
                    }
                    .bt())?
                }
                let start_src_idx = start_src_idx + index * right_len;
                let start_dst_idx = start_dst_idx + i * right_len;
                dst[start_dst_idx..start_dst_idx + right_len]
                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
            }
        }
        Ok(dst)
    }
}

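/// Scatter-add along `dim`, mirroring `torch.Tensor.scatter_add_`:
/// `dst[..., ids[..., i, ...], ...] += src[..., i, ...]`, with `ids` and
/// `src` sharing the same shape.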
struct ScatterAdd<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
    const OP: &'static str = "scatter-add";
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };

        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let dst_dims = l1.dims();
        let dst_dim_len = dst_dims[dim];
        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

        let ids_left_len: usize = ids_dims[..dim].iter().product();
        let ids_dim_len = ids_dims[dim];
        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..ids_dim_len {
                let start_ids_idx = start_ids_idx + i * ids_right_len;
                for right_i in 0..dst_right_len {
                    let ids_idx = start_ids_idx + right_i;
                    let index = ids[ids_idx].as_usize();
                    if index >= dst_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: dst_dim_len,
                            op: "gather",
                        }
                        .bt())?
                    }
                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
                    dst[dst_idx] += src[ids_idx]
                }
            }
        }

        Ok(dst)
    }
}

struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}

impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
    const OP: &'static str = "index-add";
    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
    // v1, l1 -> self
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
        let max_idx = l1.dims()[dim];
        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
        let src_dim_sz = src_l.dims()[dim];
        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
        if dim == 0 {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    }
                    .bt())?
                }
                let src_idx = src_idx * post_dim;
                let dst_idx = dst_idx * post_dim;
                let src = &src[src_idx..src_idx + post_dim];
                let dst = &mut dst[dst_idx..dst_idx + post_dim];
                for (d, &s) in dst.iter_mut().zip(src.iter()) {
                    *d += s
                }
            }
        } else {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    }
                    .bt())?
                }
                for pre_i in 0..pre_dim {
                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
                    let src = &src[pre_src_i..pre_src_i + post_dim];
                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
                        *d += s
                    }
                }
            }
        }
        Ok(dst)
    }
}

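/// Copies `src`, traversed according to `src_l`, into `dst` starting at
/// `dst_offset`; copying stops early once `dst` is full.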
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}

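/// Direct (non-im2col) 1D convolution over a `[b_size, c_in, l_in]` input,
/// parallelised over output channels with rayon; each kernel offset
/// accumulates into the shared output buffer.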
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl<'a> Map2 for Conv1D<'a> {
    const OP: &'static str = "conv1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        // Output shape: [b_size, c_out, l_out].
        let dst = vec![T::zero(); dst_elems];

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        // Safety: the dst_idx values are unique per dst_c_idx, which is the
                        // dimension used to parallelise the work, so no two threads ever write
                        // to the same location.
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

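/// im2col for 1D convolutions: unfolds the input into a `[b, l_out, c * l_k]`
/// buffer so that the convolution itself reduces to a matmul with the
/// flattened kernel.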
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
}

impl Map1 for Im2Col1D {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - l_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

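/// im2col for 2D convolutions: unfolds the input into a
/// `[b, h_out * w_out, c * h_k * w_k]` buffer so that the convolution reduces
/// to a matmul with the flattened kernel.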
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
        (h_out, w_out)
    }
}

impl Map1 for Im2Col {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - h_k = w_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl<'a> Map2 for ConvTranspose1D<'a> {
    const OP: &'static str = "conv_transpose1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();

        // Output shape: [b_size, c_out, l_out].
        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];
        let dst_s0 = p.c_out * l_out;
        let dst_s1 = l_out;
        let dst_s2 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        let cont_s0 = p.l_in * p.c_in;
        let cont_s1 = p.c_in;
        for b_idx in 0..p.b_size {
            for l_idx in 0..p.l_in {
                for c_idx in 0..p.c_in {
                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
                    inp_cont[dst_idx] = inp[src_idx]
                }
            }
        }

        for k_idx in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    for l_idx in 0..p.l_in {
                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
                        if out_idx < p.padding {
                            continue;
                        }
                        let out_idx = out_idx - p.padding;
                        if out_idx < l_out {
                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
                            let mut d = T::zero();
                            unsafe {
                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                            }
                            let dst_p = dst.as_ptr();
                            // Safety: the dst_idx values are unique per dst_c_idx, which is
                            // the dimension used to parallelise the work, so no two threads
                            // ever write to the same location.
                            unsafe {
                                let ptr = dst_p.add(dst_idx) as *mut T;
                                *ptr += d
                            }
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

impl<'a> Map2 for Conv2D<'a> {
    const OP: &'static str = "conv2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Output shape: [b_size, c_out, out_h, out_w].
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        for offset_h in 0..p.k_h {
            for offset_w in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let dst_idx = dst_c_idx * out_w * out_h;
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[dst_c_idx * k_s0
                                + c_in_idx * k_s1
                                + offset_h * k_s2
                                + offset_w * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
                        for dst_h in 0..out_h {
                            let dst_idx = dst_idx + dst_h * out_w;
                            let src_h = p.stride * dst_h + offset_h * p.dilation;
                            if src_h < p.padding || src_h >= p.i_h + p.padding {
                                continue;
                            }
                            let src_h = src_h - p.padding;
                            for dst_w in 0..out_w {
                                let dst_idx = dst_idx + dst_w;
                                let src_w = p.stride * dst_w + offset_w * p.dilation;
                                if src_w < p.padding || src_w >= p.i_w + p.padding {
                                    continue;
                                }
                                let src_w = src_w - p.padding;
                                let inp_cont = &inp_cont
                                    [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
                                assert!(inp_cont.len() >= p.c_in);
                                assert!(k_cont.len() >= p.c_in);
                                let mut d = T::zero();
                                unsafe {
                                    T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                                }
                                let dst_p = dst.as_ptr();
                                // Safety: the dst_idx values are unique per dst_c_idx, which is
                                // the dimension used to parallelise the work, so no two threads
                                // ever write to the same location.
                                unsafe {
                                    let ptr = dst_p.add(dst_idx) as *mut T;
                                    *ptr += d
                                }
                            }
                        }
                    }
                });
            }
        }

        Ok(dst)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl<'a> Map2 for ConvTranspose2D<'a> {
    const OP: &'static str = "conv_transpose2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Output shape: [b_size, c_out, out_h, out_w].
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
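        // Repack the input into a contiguous channels-last buffer (b, i_h, i_w, c_in) so
        // that the innermost c_in reads in the dot products below are contiguous.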
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

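        // Transposed convolution as a scatter: each input position (inp_y, inp_x)
        // contributes to output position (inp_y * stride + k_y * dilation - padding,
        // inp_x * stride + k_x * dilation - padding) for every kernel offset (k_y, k_x).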
        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
                                    // Safety: the dst_idx values are unique per dst_c_idx,
                                    // and dst_c_idx is the axis used to parallelise the
                                    // tasks, so no two threads ever write to the same
                                    // location.
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}

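/// Batched matrix multiplication parameterised by (b, m, n, k): the lhs has shape
/// (b, m, k), the rhs (b, k, n), and the output buffer holds b * m * n elements.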
struct MatMul((usize, usize, usize, usize));

impl MatMul {
    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
            lhs_l: lhs_l.clone(),
            rhs_l: rhs_l.clone(),
            bmnk: self.0,
            msg,
        }))
        .bt()
    }
}

impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";

    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};

        match T::DTYPE {
            DType::F16 | DType::F32 | DType::F64 => {}
            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
        }

        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

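        // Per-batch strides: a two-dim batch prefix is accepted only when it can be
        // fused into a single batch dimension, a one-dim prefix gives the batch stride
        // directly, and an empty prefix means a single (m, k) / (k, n) matrix.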
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        let mut dst = vec![T::zero(); b * m * n];
        let num_threads = crate::utils::get_num_threads();
        let parallelism = if num_threads > 1 {
            Parallelism::Rayon(num_threads)
        } else {
            Parallelism::None
        };
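        // gemm computes dst = alpha * dst + beta * lhs @ rhs; read_dst is false below,
        // so the alpha term is ignored and beta = 1 stores the plain product.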
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                gemm(
                    /* m: usize = */ m,
                    /* n: usize = */ n,
                    /* k: usize = */ k,
                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
                    /* dst_cs: isize = */ dst_cs as isize,
                    /* dst_rs: isize = */ dst_rs as isize,
                    /* read_dst: bool = */ false,
                    /* lhs: *const T = */ lhs_p.as_ptr(),
                    /* lhs_cs: isize = */ lhs_cs as isize,
                    /* lhs_rs: isize = */ lhs_rs as isize,
                    /* rhs: *const T = */ rhs_p.as_ptr(),
                    /* rhs_cs: isize = */ rhs_cs as isize,
                    /* rhs_rs: isize = */ rhs_rs as isize,
                    /* alpha: T = */ T::zero(),
                    /* beta: T = */ T::one(),
                    /* conj_dst: bool = */ false,
                    /* conj_lhs: bool = */ false,
                    /* conj_rhs: bool = */ false,
                    parallelism,
                )
            }
        }
        Ok(dst)
    }

    #[cfg(feature = "accelerate")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

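        // The BLAS routines are column-major. A row-major C = lhs @ rhs is computed as
        // the column-major product C^T = rhs^T @ lhs^T, so rhs is passed as operand `a`,
        // lhs as operand `b`, and the m/n arguments are swapped in the calls below.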
        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The BLAS `b` operand is the lhs tensor, with dims (batch, m, k).
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                crate::bail!("the accelerate backend does not support f16 matmul")
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        // `a` aliases rhs and `b` aliases lhs, so each slice gets the
                        // per-batch length of the buffer it actually points into.
                        let a = std::slice::from_raw_parts(a, b_skip);
                        let b = std::slice::from_raw_parts(b, a_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::sgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, b_skip);
                        let b = std::slice::from_raw_parts(b, a_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::dgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }

    #[cfg(feature = "mkl")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

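        // Same column-major trick as in the accelerate path: rhs is passed as `a`, lhs
        // as `b`, with the m and n arguments swapped.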
        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The BLAS `b` operand is the lhs tensor, with dims (batch, m, k).
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f16;
                        let b = lhs_p.as_ptr() as *const f16;
                        let c = dst_p.as_mut_ptr() as *mut f16;
                        // As in the accelerate path, `a` aliases rhs and `b` aliases lhs.
                        let a = std::slice::from_raw_parts(a, b_skip);
                        let b = std::slice::from_raw_parts(b, a_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::hgemm(
                            transa,
                            transb,
                            /* m= */ n as i32,
                            /* n= */ m as i32,
                            /* k= */ k as i32,
                            /* alpha= */ f16::ONE,
                            /* a= */ a,
                            /* lda= */ lda,
                            /* b= */ b,
                            /* ldb= */ ldb,
                            /* beta= */ f16::ZERO,
                            /* c= */ c,
                            /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, b_skip);
                        let b = std::slice::from_raw_parts(b, a_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::sgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, b_skip);
                        let b = std::slice::from_raw_parts(b, a_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::dgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }
}

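/// ELU activation: the identity for inputs with a positive sign bit,
/// `alpha * (exp(v) - 1)` otherwise.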
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
    if v.is_sign_positive() {
        v
    } else {
        (v.exp() - T::one()) * alpha
    }
}

impl CpuStorage {
    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
        D::cpu_storage_as_slice(self)
    }

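    /// Concatenate storages sharing a common dtype into a single storage; mixing
    /// dtypes returns a "dtype mismatch" error. For example, concatenating
    /// `U8(vec![1, 2])` with `U8(vec![3])` yields `U8(vec![1, 2, 3])`.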
    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
        let storage0 = &storages[0];
        let s = match storage0 {
            Self::U8(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U8(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U8(storages)
            }
            Self::U32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U32(storages)
            }
            Self::I64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::I64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::I64(storages)
            }
            Self::BF16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::BF16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::BF16(storages)
            }
            Self::F16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F16(storages)
            }
            Self::F32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F32(storages)
            }
            Self::F64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F64(storages)
            }
        };
        Ok(s)
    }
}

impl BackendStorage for CpuStorage {
    type Device = CpuDevice;

    fn dtype(&self) -> DType {
        match self {
            Self::U8(_) => DType::U8,
            Self::U32(_) => DType::U32,
            Self::I64(_) => DType::I64,
            Self::BF16(_) => DType::BF16,
            Self::F16(_) => DType::F16,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
        }
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        // TODO: find a way around the quadratic number of cases below.
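        // Note that float-to-integer conversions use Rust `as` casts, which saturate at
        // the bounds of the target type and map NaN to zero.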
        match (self, dtype) {
            (Self::U8(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::U32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::I64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::BF16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::BF16(data))
            }
            (Self::F16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f32);
                Ok(Self::BF16(data))
            }
            (Self::F64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f64);
                Ok(Self::BF16(data))
            }
            (Self::U8(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::U32(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::I64(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::BF16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F16(data))
            }
            (Self::F32(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f32);
                Ok(Self::F16(data))
            }
            (Self::F64(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f64);
                Ok(Self::F16(data))
            }
            (Self::U8(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::I64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::BF16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F32(data))
            }
            (Self::F64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U8(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U8(data))
            }
            (Self::BF16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::F64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::I64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U8(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U32(data))
            }
            (Self::I64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::BF16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::F64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U8(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::I64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I64(data))
            }
            (Self::BF16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::F64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U8(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::U32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::I64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::BF16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::F64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F64(data))
            }
        }
    }

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
        match op {
            ReduceOp::Sum => {
                let src_dims = layout.dims();
                let mut dst_dims = src_dims.to_vec();
                for &dim in reduce_dims.iter() {
                    dst_dims[dim] = 1;
                }
                let dst_shape = Shape::from(dst_dims);
                let mut reduce_dims = reduce_dims.to_vec();
                // Sort the reduce_dims as they have to be processed from left to right when converting the
                // indexes.
                reduce_dims.sort();
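                // For each reduced dim, record its size together with the stride it
                // would have in a contiguous layout (the product of the trailing dims).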
                let reduce_dims_and_stride: Vec<_> = reduce_dims
                    .iter()
                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
                    .collect();
                ReduceSum {
                    dst_shape: &dst_shape,
                    reduce_dims: &reduce_dims,
                    reduce_dims_and_stride,
                }
                .map(self, layout)
            }
            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
                let reduce_dim_index = match reduce_dims {
                    [reduce_dim_index] => *reduce_dim_index,
                    _ => {
                        let op = match op {
                            ReduceOp::Min => "min",
                            ReduceOp::ArgMin => "argmin",
                            ReduceOp::Max => "max",
                            ReduceOp::ArgMax => "argmax",
                            _ => unreachable!(),
                        };
                        let dims = reduce_dims.to_vec();
                        Err(Error::OnlySingleDimension { op, dims })?
                    }
                };
                let (use_min, return_index) = match op {
                    ReduceOp::Min => (true, false),
                    ReduceOp::ArgMin => (true, true),
                    ReduceOp::Max => (false, false),
                    ReduceOp::ArgMax => (false, true),
                    _ => unreachable!(),
                };
                ReduceIndex {
                    reduce_dim_index,
                    use_min,
                    return_index,
                }
                .map(self, layout)
            }
        }
    }

    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
        Cmp(op).map(self, lhs_l, rhs, rhs_l)
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        Affine(mul, add).map(self, layout)
    }

    fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        AvgPool2D(kernel_size, stride).map(self, layout)
    }

    fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        MaxPool2D(kernel_size, stride).map(self, layout)
    }

    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        UpsampleNearest1D(sz).map(self, layout)
    }

    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        UpsampleNearest2D(h, w).map(self, layout)
    }

    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
        use num_traits::Float;
        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e as f32));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
        }
    }

    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
        }
    }

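    // Unary ops use the op's vectorized kernel when one is advertised for the dtype
    // (the B::*_VEC flags), falling back to the scalar element-wise map otherwise;
    // binary_impl below follows the same pattern.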
    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        match self {
            Self::BF16(storage) => {
                if B::BF16_VEC {
                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
                    Ok(Self::BF16(data))
                } else {
                    let data = unary_map(storage, layout, B::bf16);
                    Ok(Self::BF16(data))
                }
            }
            Self::F16(storage) => {
                if B::F16_VEC {
                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
                    Ok(Self::F16(data))
                } else {
                    let data = unary_map(storage, layout, B::f16);
                    Ok(Self::F16(data))
                }
            }
            Self::F32(storage) => {
                if B::F32_VEC {
                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
                    Ok(Self::F32(data))
                } else {
                    let data = unary_map(storage, layout, B::f32);
                    Ok(Self::F32(data))
                }
            }
            Self::F64(storage) => {
                if B::F64_VEC {
                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
                    Ok(Self::F64(data))
                } else {
                    let data = unary_map(storage, layout, B::f64);
                    Ok(Self::F64(data))
                }
            }
            Self::U8(storage) => {
                let data = unary_map(storage, layout, B::u8);
                Ok(Self::U8(data))
            }
            Self::U32(storage) => {
                let data = unary_map(storage, layout, B::u32);
                Ok(Self::U32(data))
            }
            Self::I64(storage) => {
                let data = unary_map(storage, layout, B::i64);
                Ok(Self::I64(data))
            }
        }
    }

    fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        match (self, rhs) {
            (Self::BF16(lhs), Self::BF16(rhs)) => {
                let data = if B::BF16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
                };
                Ok(Self::BF16(data))
            }
            (Self::F16(lhs), Self::F16(rhs)) => {
                let data = if B::F16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
                };
                Ok(Self::F16(data))
            }
            (Self::F32(lhs), Self::F32(rhs)) => {
                let data = if B::F32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
                };
                Ok(Self::F32(data))
            }
            (Self::F64(lhs), Self::F64(rhs)) => {
                let data = if B::F64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
                };
                Ok(Self::F64(data))
            }
            (Self::U32(lhs), Self::U32(rhs)) => {
                let data = if B::U32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
                };
                Ok(Self::U32(data))
            }
            (Self::I64(lhs), Self::I64(rhs)) => {
                let data = if B::I64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
                };
                Ok(Self::I64(data))
            }
            (Self::U8(lhs), Self::U8(rhs)) => {
                let data = if B::U8_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
                };
                Ok(Self::U8(data))
            }
            _ => {
                // Dtype mismatches should have been caught by the dtype checks performed
                // upstream by the caller.
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: rhs.dtype(),
                    op: B::NAME,
                }
                .bt())
            }
        }
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (_, dst) => {
                // Dtype mismatches should have been caught by the dtype checks performed
                // upstream by the caller.
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        match self {
            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond").bt()),
        }
    }

    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV1D {
            return Conv1D(params).map(self, l, kernel, kernel_l);
        }
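        // im2col path: unfold the input into a (b, l_out, c_in * k_size) matrix,
        // multiply it by the kernel viewed as (c_in * k_size, c_out), then transpose
        // the result back to (b, c_out, l_out).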
        let op = Im2Col1D {
            l_k: params.k_size,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let l_out = params.l_out();
        let k = op.l_k * params.c_in;
        let m = l_out;
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous if it is not already the case.
            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            // The fresh copy is contiguous and starts at offset 0, so use the copied
            // kernel together with a zero-offset layout.
            let kernel_l = Layout::contiguous((1, n, k))
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
    }

    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV2D {
            return Conv2D(params).map(self, l, kernel, kernel_l);
        }
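        // im2col path: unfold the input into a (b, h_out * w_out, c_in * k_h * k_w)
        // matrix, matmul with the flattened kernel, then permute the resulting
        // (b, h_out, w_out, c_out) tensor back to NCHW.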
        let op = Im2Col {
            h_k: params.k_h,
            w_k: params.k_w,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let (h_out, w_out) = (params.out_h(), params.out_w());
        let k = op.h_k * op.w_k * params.c_in;
        let m = h_out * w_out;
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous if it is not already the case.
            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            // The fresh copy is contiguous and starts at offset 0, so use the copied
            // kernel together with a zero-offset layout.
            let kernel_l = Layout::contiguous((1, n, k))
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
            .transpose(1, 2)?
            .transpose(1, 3)?;
        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
    }

    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
        }
    }

    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
        }
    }

    fn scatter_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
        }
    }

    fn index_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
        }
    }

    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    fn device(&self) -> &Self::Device {
        &CpuDevice
    }

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }
}

impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    fn same_device(&self, _: &Self) -> bool {
        true
    }

    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min, max);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal =
                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
        };
        Ok(storage)
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
        };
        Ok(storage)
    }
}

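/// Apply `$fn` to the storage's underlying vector when its dtype matches one of the
/// listed variants, rewrapping the result in the same variant and returning an
/// `UnsupportedDTypeForOp` error otherwise. Illustrative use (hypothetical op name):
/// `map_dtype!("cos", s, |v| v, (F32, F64))`.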
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}