diff options
Diffstat (limited to 'candle-nn/src/func.rs')
-rw-r--r-- | candle-nn/src/func.rs | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index e7fd73ae..39311d45 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -1,10 +1,12 @@ //! Layers defined by closures. use candle::{Result, Tensor}; +use std::sync::Arc; /// A layer defined by a simple closure. +#[derive(Clone)] pub struct Func<'a> { #[allow(clippy::type_complexity)] - f: Box<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send>, + f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>, } impl<'a> std::fmt::Debug for Func<'a> { @@ -15,9 +17,9 @@ impl<'a> std::fmt::Debug for Func<'a> { pub fn func<'a, F>(f: F) -> Func<'a> where - F: 'a + Fn(&Tensor) -> Result<Tensor> + Send, + F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync, { - Func { f: Box::new(f) } + Func { f: Arc::new(f) } } impl<'a> super::Module for Func<'a> { @@ -29,8 +31,8 @@ impl<'a> super::Module for Func<'a> { impl<'a> Func<'a> { pub fn new<F>(f: F) -> Self where - F: 'a + Fn(&Tensor) -> Result<Tensor> + Send, + F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync, { - Self { f: Box::new(f) } + Self { f: Arc::new(f) } } } |