summaryrefslogtreecommitdiff
path: root/candle-core/src/pickle.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-12 10:26:56 +0100
committerGitHub <noreply@github.com>2024-02-12 10:26:56 +0100
commit274bf11633f609d92729e11e16618778bed4b868 (patch)
treef57f8cde7a50880e36bc231d6a5a2d2dea9cf15d /candle-core/src/pickle.rs
parent1e26d539d9f9574222e8d049fdbfadfa09e3ce2e (diff)
downloadcandle-274bf11633f609d92729e11e16618778bed4b868.tar.gz
candle-274bf11633f609d92729e11e16618778bed4b868.tar.bz2
candle-274bf11633f609d92729e11e16618778bed4b868.zip
Support defaultdict in PyTorch checkpoints. (#1696)
* Support defaultdict in PyTorch checkpoints. * Fix clippy lint.
Diffstat (limited to 'candle-core/src/pickle.rs')
-rw-r--r--candle-core/src/pickle.rs6
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index f6d80830..e3f1f81d 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -350,8 +350,10 @@ impl Stack {
module_name,
class_name,
} => {
- if module_name == "collections" && class_name == "OrderedDict" {
- // TODO: have a separate ordered dict.
+ if module_name == "collections"
+ && (class_name == "OrderedDict" || class_name == "defaultdict")
+ {
+ // TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![]))
} else {
None