diff options
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 41 | ||||
-rw-r--r-- | candle-core/tests/pool_tests.rs | 10 | ||||
-rw-r--r-- | candle-examples/examples/bert/main.rs | 5 | ||||
-rw-r--r-- | candle-examples/examples/bigcode/main.rs | 5 | ||||
-rw-r--r-- | candle-examples/examples/falcon/main.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/llama/main.rs | 4 | ||||
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 5 | ||||
-rw-r--r-- | candle-kernels/src/conv.cu | 62 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/worker.rs | 5 |
10 files changed, 119 insertions, 26 deletions
@@ -48,7 +48,7 @@ safetensors = "0.3.1" serde = { version = "1.0.171", features = ["derive"] } serde_json = "1.0.99" thiserror = "1" -tokenizers = { version = "0.13.3", default-features = false } +tokenizers = { version = "0.13.4", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 90d3ee6d..0a73c023 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -980,7 +980,7 @@ impl Map1 for Pool2D { dev: &CudaDevice, inp_l: &Layout, ) -> Result<CudaSlice<T>> { - // Kernel shape: (c_out, c_in_k, w_k, h_k) + // Input shape: (b_size, c, h, w) let inp = &inp.slice(inp_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); @@ -1018,6 +1018,39 @@ impl Map1 for Pool2D { } } +struct UpsampleNearest2D(usize, usize); +impl Map1 for UpsampleNearest2D { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + inp: &CudaSlice<T>, + dev: &CudaDevice, + inp_l: &Layout, + ) -> Result<CudaSlice<T>> { + // Input shape: (b_size, c, h, w) + let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let ds = if dims.len() == 4 { + [dims, inp_l.stride()].concat() + } else { + panic!("unexpected input shape for conv1d {dims:?}") + }; + let (out_w, out_h) = (self.0, self.1); + let dst_el = out_w * out_h * dims[0] * dims[1]; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let ds = dev.htod_copy(ds).w()?; + let scale_w = dims[2] as f64 / out_w as f64; + let scale_h = dims[3] as f64 / out_h as f64; + let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + struct WhereCond<'a>(&'a CudaStorage, &'a Layout); impl<'a> Map2 for WhereCond<'a> { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( @@ -1513,8 +1546,10 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { - todo!() + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> { + let device = self.device().clone(); + let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; + Ok(Self { slice, device }) } fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 009564fa..d2eb8f3f 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -56,9 +56,8 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { Ok(()) } -#[test] -fn upsample_nearest2d() -> Result<()> { - let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?; +fn upsample_nearest2d(dev: &Device) -> Result<()> { + let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?; let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?; assert_eq!( t.i(0)?.i(0)?.to_vec2::<f32>()?, @@ -83,3 +82,8 @@ test_device!( avg_pool2d_pytorch_gpu ); test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu); +test_device!( + upsample_nearest2d, + upsample_nearest2d_cpu, + upsample_nearest2d_gpu +); diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 574755ed..7f0ae7b1 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -111,7 +111,10 @@ fn main() -> Result<()> { let device = &model.device; if let Some(prompt) = args.prompt { - let tokenizer = tokenizer.with_padding(None).with_truncation(None); + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; let tokens = tokenizer .encode(prompt, true) .map_err(E::msg)? diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index ac9c63c7..39b1de27 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -65,10 +65,7 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); - let token = self - .tokenizer - .decode(vec![next_token], true) - .map_err(E::msg)?; + let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?; print!("{token}"); std::io::stdout().flush()?; } diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index c37d9a96..0df3a001 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -72,16 +72,14 @@ impl TextGeneration { "{} token: {} '{}'", index + 1, next_token, - self.tokenizer - .decode(vec![next_token], true) - .map_err(E::msg)? + self.tokenizer.decode(&[next_token], true).map_err(E::msg)? ); } let dt = start_gen.elapsed(); println!( "{sample_len} tokens generated ({} token/s)\n----\n{}\n----", sample_len as f64 / dt.as_secs_f64(), - self.tokenizer.decode(new_tokens, true).map_err(E::msg)? + self.tokenizer.decode(&new_tokens, true).map_err(E::msg)? ); Ok(()) } diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 9a62eba5..b1e112fd 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -223,7 +223,7 @@ fn main() -> Result<()> { "{} token: {} '{}'", index + 1, next_token, - tokenizer.decode(vec![next_token], true).map_err(E::msg)? + tokenizer.decode(&[next_token], true).map_err(E::msg)? ); } let dt = start_gen.elapsed(); @@ -231,7 +231,7 @@ fn main() -> Result<()> { "{} tokens generated ({} token/s)\n----\n{}\n----", args.sample_len, args.sample_len as f64 / dt.as_secs_f64(), - tokenizer.decode(new_tokens, true).map_err(E::msg)? + tokenizer.decode(&new_tokens, true).map_err(E::msg)? ); Ok(()) } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 99919f8d..5c58c002 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -169,10 +169,7 @@ impl Decoder { } sum_logprob += prob.ln(); } - let text = self - .tokenizer - .decode(tokens.clone(), true) - .map_err(E::msg)?; + let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?; let avg_logprob = sum_logprob / tokens.len() as f64; Ok(DecodingResult { diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 2da4d401..afda7d1d 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -220,6 +220,48 @@ __device__ void max_pool2d( dst[dst_i] = d; } +template <typename T> +__device__ void upsample_nearest2d( + const size_t w_out, + const size_t h_out, + const double w_scale, + const double h_scale, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + size_t src_w = static_cast<size_t>(dst_w * w_scale); + size_t src_h = static_cast<size_t>(dst_h * h_scale); + if (src_w >= w_in) { + src_w = w_in - 1; + } + if (src_h >= h_in) { + src_h = h_in - 1; + } + + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + dst[dst_i] = src[src_i]; +} + #define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \ extern "C" __global__ void FN_NAME( \ @@ -278,11 +320,25 @@ extern "C" __global__ void FN_NAME( \ max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ } \ +#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t w_out, \ + const size_t h_out, \ + const double w_scale, \ + const double h_scale, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, info, src, dst); \ +} \ + #if __CUDA_ARCH__ >= 800 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) +UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) #endif #if __CUDA_ARCH__ >= 530 @@ -290,6 +346,7 @@ CONV1D_OP(__half, float, conv1d_f16) CONV2D_OP(__half, float, conv2d_f16) AVG_POOL2D_OP(__half, float, avg_pool2d_f16) MAX_POOL2D_OP(__half, max_pool2d_f16) +UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -311,3 +368,8 @@ MAX_POOL2D_OP(float, max_pool2d_f32) MAX_POOL2D_OP(double, max_pool2d_f64) MAX_POOL2D_OP(uint8_t, max_pool2d_u8) MAX_POOL2D_OP(uint32_t, max_pool2d_u32) + +UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) +UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64) +UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) +UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index 139755cb..d77d3e32 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -159,10 +159,7 @@ impl Decoder { } sum_logprob += prob.ln(); } - let text = self - .tokenizer - .decode(tokens.clone(), true) - .map_err(E::msg)?; + let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?; let avg_logprob = sum_logprob / tokens.len() as f64; Ok(DecodingResult { |