diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-19 13:48:28 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-19 13:48:28 +0100 |
commit | 93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (patch) | |
tree | bbc00678fd08bbc8cfc392e5ff589477e80172b9 /candle-examples/examples/resnet/export_models.py | |
parent | cd53c472df163b3baaf7c70ca5d4f8909af62324 (diff) | |
download | candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.gz candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.bz2 candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.zip |
Expose the larger resnets (50/101/152) in the example. (#1131)
Diffstat (limited to 'candle-examples/examples/resnet/export_models.py')
-rw-r--r-- | candle-examples/examples/resnet/export_models.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/export_models.py b/candle-examples/examples/resnet/export_models.py new file mode 100644 index 00000000..74ef6e7d --- /dev/null +++ b/candle-examples/examples/resnet/export_models.py @@ -0,0 +1,12 @@ +# This script exports pre-trained model weights in the safetensors format. +import numpy as np +import torch +import torchvision +from safetensors import torch as stt + +m = torchvision.models.resnet50(pretrained=True) +stt.save_file(m.state_dict(), 'resnet50.safetensors') +m = torchvision.models.resnet101(pretrained=True) +stt.save_file(m.state_dict(), 'resnet101.safetensors') +m = torchvision.models.resnet152(pretrained=True) +stt.save_file(m.state_dict(), 'resnet152.safetensors') |