summaryrefslogtreecommitdiff
path: root/candle-examples/examples/resnet/export_models.py
blob: 74ef6e7d3947b977dbdb4fb9b4e0234a1686f6e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt

# Each name is both the torchvision builder function and the stem of the
# output file, e.g. "resnet50" -> torchvision.models.resnet50 -> resnet50.safetensors.
_RESNET_VARIANTS = ("resnet50", "resnet101", "resnet152")

# torchvision >= 0.13 replaced the deprecated `pretrained=True` flag with the
# `weights=` API; the `ResNet50_Weights` enum only exists from that version on.
_HAS_WEIGHTS_API = hasattr(torchvision.models, "ResNet50_Weights")

for _name in _RESNET_VARIANTS:
    _builder = getattr(torchvision.models, _name)
    if _HAS_WEIGHTS_API:
        # "IMAGENET1K_V1" is the checkpoint that `pretrained=True` has always
        # loaded; "DEFAULT" would pick newer (different) weights for resnet50.
        m = _builder(weights="IMAGENET1K_V1")
    else:
        m = _builder(pretrained=True)
    # Write the raw parameter/buffer tensors keyed by their PyTorch names.
    stt.save_file(m.state_dict(), f"{_name}.safetensors")