mirror of
https://github.com/augustin64/projet-tipe
synced 2025-02-02 19:39:39 +01:00
Add benchmark_binary.py
This commit is contained in:
parent
de21f865cb
commit
b3b918aa4f
84
src/scripts/benchmark_binary.py
Normal file
84
src/scripts/benchmark_binary.py
Normal file
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/python3
|
||||
#
|
||||
# Steps to use:
|
||||
# - modify src/scripts/generate_binaries.sh to suit your needs
|
||||
# - execute the following commands:
|
||||
# ```bash
|
||||
# src/scripts/generate_binaries.sh
|
||||
# python -i src/scripts/benchmark_binary.py
|
||||
# >>> compare_binaries(["binaries/"+i for i in os.listdir("binaries")], tries=4, dataset='train')
|
||||
# ```
|
||||
# tries is the number of epochs to train on and dataset is the dataset to use ('train' or 't10k')
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
def train(binary, base_network, tries=3, dataset="train"):
|
||||
temp_out = f".cache/tmp-{abs(hash(binary))}.bin"
|
||||
results = []
|
||||
fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 -o {temp_out} --recover {base_network}")
|
||||
output = subprocess.check_output([
|
||||
binary,
|
||||
'test',
|
||||
'--modele', temp_out,
|
||||
'-d', 'mnist',
|
||||
'-i', 'data/mnist/t10k-images-idx3-ubyte',
|
||||
'-l', "data/mnist/t10k-labels-idx1-ubyte"
|
||||
]).decode("utf-8")
|
||||
results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0]))
|
||||
for i in range(tries-1):
|
||||
if fail == 0:
|
||||
fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 --out {temp_out} --recover {temp_out}")
|
||||
output = subprocess.check_output([
|
||||
binary,
|
||||
'test',
|
||||
'--modele', temp_out,
|
||||
'-d', 'mnist',
|
||||
'-i', 'data/mnist/t10k-images-idx3-ubyte',
|
||||
'-l', "data/mnist/t10k-labels-idx1-ubyte"
|
||||
]).decode("utf-8")
|
||||
results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0]))
|
||||
else:
|
||||
results.append(results[-1])
|
||||
|
||||
return results
|
||||
|
||||
def create_base_network(binary, file):
|
||||
os.system(f"{binary} train --dataset mnist --images data/mnist/train-images-idx3-ubyte --labels data/mnist/train-labels-idx1-ubyte --epochs 0 --out {file}")
|
||||
|
||||
|
||||
def compare_binaries(binaries, tries=3, dataset="train"):
|
||||
print(f"========== {len(binaries)} Fichiers chargés ==========")
|
||||
base_net = f".cache/basenet-{abs(hash(''.join(binaries)))}.bin"
|
||||
create_base_network(binaries[0], base_net)
|
||||
results = []
|
||||
for i in range(len(binaries)):
|
||||
binary = binaries[i]
|
||||
print(f"========== Benchmmark de {binary} ({1+i}/{len(binaries)}) ==========")
|
||||
try:
|
||||
results.append(train(binary, base_net, tries, dataset=dataset))
|
||||
except:
|
||||
print(f"========== Erreur sur {binary} ==========")
|
||||
results.append(0)
|
||||
|
||||
x = [i for i in range(tries)]
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
res = []
|
||||
for i in range(len(binaries)):
|
||||
if results[i] != 0:
|
||||
res.append(ax.plot(x, results[i])[0])
|
||||
res[i].set_label(binaries[i])
|
||||
|
||||
ax.set_ylabel("Taux de réussite (%)")
|
||||
ax.set_xlabel("Nombre de batchs")
|
||||
|
||||
ax.legend()
|
||||
|
||||
plt.ylim(0, 100)
|
||||
plt.show()
|
||||
return results
|
||||
|
Loading…
Reference in New Issue
Block a user