summaryrefslogtreecommitdiff
path: root/candle-examples/examples/resnet/export_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/resnet/export_models.py')
-rw-r--r--candle-examples/examples/resnet/export_models.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/export_models.py b/candle-examples/examples/resnet/export_models.py
new file mode 100644
index 00000000..74ef6e7d
--- /dev/null
+++ b/candle-examples/examples/resnet/export_models.py
@@ -0,0 +1,12 @@
+# This script exports pre-trained model weights in the safetensors format.
+import numpy as np
+import torch
+import torchvision
+from safetensors import torch as stt
+
+m = torchvision.models.resnet50(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet50.safetensors')
+m = torchvision.models.resnet101(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet101.safetensors')
+m = torchvision.models.resnet152(pretrained=True)
+stt.save_file(m.state_dict(), 'resnet152.safetensors')