summaryrefslogtreecommitdiff
path: root/candle-examples/examples/yolo-v3/extract-weights.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-20 23:19:15 +0100
committerGitHub <noreply@github.com>2023-08-20 23:19:15 +0100
commit11c7e7bd672bea0da05207d8fdea0dfe8bb14e46 (patch)
tree5090e6f687fb8e3819867d00224cb13556036bfd /candle-examples/examples/yolo-v3/extract-weights.py
parenta1812f934f4e0830ed3c2f147d13837ccf67f2bd (diff)
downloadcandle-11c7e7bd672bea0da05207d8fdea0dfe8bb14e46.tar.gz
candle-11c7e7bd672bea0da05207d8fdea0dfe8bb14e46.tar.bz2
candle-11c7e7bd672bea0da05207d8fdea0dfe8bb14e46.zip
Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3. * Use the running stats for inference in the batch-norm layer. * Get some proper predictions for yolo. * Avoid the quadratic insertion.
Diffstat (limited to 'candle-examples/examples/yolo-v3/extract-weights.py')
-rw-r--r--candle-examples/examples/yolo-v3/extract-weights.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-examples/examples/yolo-v3/extract-weights.py b/candle-examples/examples/yolo-v3/extract-weights.py
new file mode 100644
index 00000000..4e523ee6
--- /dev/null
+++ b/candle-examples/examples/yolo-v3/extract-weights.py
@@ -0,0 +1,7 @@
+def remove_prefix(text, prefix):
+ return text[text.startswith(prefix) and len(prefix):]
+nps = {}
+for k, v in model.state_dict().items():
+ k = remove_prefix(k, 'module_list.')
+ nps[k] = v.detach().numpy()
+np.savez('yolo-v3.ot', **nps)