summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-07-27 16:59:32 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-07-27 16:59:32 +0200
commit952eca6b540078b1f30b58d9eb930f8e32d903cb (patch)
tree8eb5378393d9c64482638aae5dfeaaabb4ff0248
parent25a2086e8f4cc23fada32a44607d3b8550916ebe (diff)
downloadcandle-952eca6b540078b1f30b58d9eb930f8e32d903cb.tar.gz
candle-952eca6b540078b1f30b58d9eb930f8e32d903cb.tar.bz2
candle-952eca6b540078b1f30b58d9eb930f8e32d903cb.zip
Fixing slice errors + comments.
-rw-r--r--candle-core/src/error.rs7
-rw-r--r--candle-nn/src/var_builder.rs25
2 files changed, 29 insertions, 3 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index f9e69122..30d06239 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -79,6 +79,13 @@ pub enum Error {
nth_shape: Shape,
},
+ #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
+ ShapeMismatchSplit {
+ shape: Shape,
+ dim: usize,
+ n_parts: usize,
+ },
+
#[error("{op} can only be performed on a single dimension")]
OnlySingleDimension { op: &'static str, dims: Vec<usize> },
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 1466f6d0..3133f210 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> {
}
impl<'a> VarBuilder<'a> {
+ /// Get part of a tensor, typically used to do Tensor Parallelism sharding.
+ ///
+ /// If the tensor is of size (1024, 1024).
+ ///
+ /// `dim` corresponds to the dimension to slice into
+ /// `rank` is the rank of the current process
+ /// `world_size` is the total number of ranks in the process group
+ ///
+ /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
+ /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
+ /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
pub fn get_sharded(
&self,
tensor_name: &str,
@@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> {
let dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
+
+ if size % world_size != 0 {
+ return Err(Error::ShapeMismatchSplit {
+ shape: shape.into(),
+ dim,
+ n_parts: world_size,
+ });
+ }
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
let iterator = if dim == 0 {
- view.slice(start..stop).unwrap()
+ view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else if dim == 1 {
- view.slice((.., start..stop)).unwrap()
+ view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else {
- unimplemented!("Get sharded on dimensions != 0 or 1");
+ candle::bail!("Get sharded on dimensions != 0 or 1")
};
shape[dim] = block_size;