diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-27 16:59:32 +0200 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-27 16:59:32 +0200 |
commit | 952eca6b540078b1f30b58d9eb930f8e32d903cb (patch) | |
tree | 8eb5378393d9c64482638aae5dfeaaabb4ff0248 | |
parent | 25a2086e8f4cc23fada32a44607d3b8550916ebe (diff) | |
download | candle-952eca6b540078b1f30b58d9eb930f8e32d903cb.tar.gz candle-952eca6b540078b1f30b58d9eb930f8e32d903cb.tar.bz2 candle-952eca6b540078b1f30b58d9eb930f8e32d903cb.zip |
Fixing slice errors + comments.
-rw-r--r-- | candle-core/src/error.rs | 7 | ||||
-rw-r--r-- | candle-nn/src/var_builder.rs | 25 |
2 files changed, 29 insertions, 3 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index f9e69122..30d06239 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -79,6 +79,13 @@ pub enum Error { nth_shape: Shape, }, + #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")] + ShapeMismatchSplit { + shape: Shape, + dim: usize, + n_parts: usize, + }, + #[error("{op} can only be performed on a single dimension")] OnlySingleDimension { op: &'static str, dims: Vec<usize> }, diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 1466f6d0..3133f210 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> { } impl<'a> VarBuilder<'a> { + /// Get part of a tensor, typically used to do Tensor Parallelism sharding. + /// + /// If the tensor is of size (1024, 1024). + /// + /// `dim` corresponds to the dimension to slice into + /// `rank` is the rank of the current process + /// `world_size` is the total number of ranks in the process group + /// + /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))` + /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))` + /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))` pub fn get_sharded( &self, tensor_name: &str, @@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> { let dtype = view.dtype(); let mut shape = view.shape().to_vec(); let size = shape[dim]; + + if size % world_size != 0 { + return Err(Error::ShapeMismatchSplit { + shape: shape.into(), + dim, + n_parts: world_size, + }); + } let block_size = size / world_size; let start = rank * block_size; let stop = (rank + 1) * block_size; let iterator = if dim == 0 { - view.slice(start..stop).unwrap() + view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? } else if dim == 1 { - view.slice((.., start..stop)).unwrap() + view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? } else { - unimplemented!("Get sharded on dimensions != 0 or 1"); + candle::bail!("Get sharded on dimensions != 0 or 1") }; shape[dim] = block_size; |