diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-12 10:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-12 10:26:56 +0100 |
commit | 274bf11633f609d92729e11e16618778bed4b868 (patch) | |
tree | f57f8cde7a50880e36bc231d6a5a2d2dea9cf15d /candle-core/src/pickle.rs | |
parent | 1e26d539d9f9574222e8d049fdbfadfa09e3ce2e (diff) | |
download | candle-274bf11633f609d92729e11e16618778bed4b868.tar.gz candle-274bf11633f609d92729e11e16618778bed4b868.tar.bz2 candle-274bf11633f609d92729e11e16618778bed4b868.zip |
Support defaultdict in PyTorch checkpoints. (#1696)
* Support defaultdict in PyTorch checkpoints.
* Fix clippy lint.
Diffstat (limited to 'candle-core/src/pickle.rs')
-rw-r--r-- | candle-core/src/pickle.rs | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index f6d80830..e3f1f81d 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -350,8 +350,10 @@ impl Stack { module_name, class_name, } => { - if module_name == "collections" && class_name == "OrderedDict" { - // TODO: have a separate ordered dict. + if module_name == "collections" + && (class_name == "OrderedDict" || class_name == "defaultdict") + { + // TODO: have a separate ordered dict and a separate default dict. Some(Object::Dict(vec![])) } else { None |