summaryrefslogtreecommitdiff
path: root/candle-examples/examples/resnet/export_models.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-19 13:48:28 +0100
committerGitHub <noreply@github.com>2023-10-19 13:48:28 +0100
commit93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (patch)
treebbc00678fd08bbc8cfc392e5ff589477e80172b9 /candle-examples/examples/resnet/export_models.py
parentcd53c472df163b3baaf7c70ca5d4f8909af62324 (diff)
downloadcandle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.gz
candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.bz2
candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.zip
Expose the larger resnets (50/101/152) in the example. (#1131)
Diffstat (limited to 'candle-examples/examples/resnet/export_models.py')
-rw-r--r--candle-examples/examples/resnet/export_models.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/export_models.py b/candle-examples/examples/resnet/export_models.py
new file mode 100644
index 00000000..74ef6e7d
--- /dev/null
+++ b/candle-examples/examples/resnet/export_models.py
@@ -0,0 +1,12 @@
+# This script exports pre-trained model weights in the safetensors format.
+import numpy as np
+import torch
+import torchvision
+from safetensors import torch as stt
+
+m = torchvision.models.resnet50(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet50.safetensors')
+m = torchvision.models.resnet101(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet101.safetensors')
+m = torchvision.models.resnet152(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet152.safetensors')