diff options
Diffstat (limited to 'candle-examples/examples/resnet/export_models.py')
-rw-r--r-- | candle-examples/examples/resnet/export_models.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/export_models.py b/candle-examples/examples/resnet/export_models.py new file mode 100644 index 00000000..74ef6e7d --- /dev/null +++ b/candle-examples/examples/resnet/export_models.py @@ -0,0 +1,12 @@ +# This script exports pre-trained model weights in the safetensors format. +import numpy as np +import torch +import torchvision +from safetensors import torch as stt + +m = torchvision.models.resnet50(pretrained=True) +stt.save_file(m.state_dict(), 'resnet50.safetensors') +m = torchvision.models.resnet101(pretrained=True) +stt.save_file(m.state_dict(), 'resnet101.safetensors') +m = torchvision.models.resnet152(pretrained=True) +stt.save_file(m.state_dict(), 'resnet152.safetensors') |