Update benchmark scripts

This commit is contained in:
augustin64 2023-02-28 13:16:39 +01:00
parent 1db6c6824d
commit 87b37aee80
2 changed files with 119 additions and 32 deletions

View File

@ -15,11 +15,34 @@ import os
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
try:
from tqdm import tqdm
progress = tqdm
except ModuleNotFoundError:
progress = lambda x : x
def train(binary, base_network, tries=3, dataset="train"): def train(binary, base_network, tries=3, dataset="train"):
temp_out = f".cache/tmp-{abs(hash(binary))}.bin" temp_out = f".cache/tmp-{abs(hash(binary))}.bin"
results = [] results = []
fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 -o {temp_out} --recover {base_network}")
output = subprocess.check_output([ # 1e époque training
try:
train_output = subprocess.check_output([
binary,
"train",
"--dataset", "mnist",
"--images", f"data/mnist/{dataset}-images-idx3-ubyte",
"--labels", f"data/mnist/{dataset}-labels-idx1-ubyte",
"--epochs", "1",
"-o", temp_out,
"--recover", base_network
]).decode("utf-8")
fail = 0
except:
fail = 1
# 1e époque testing
test_output = subprocess.check_output([
binary, binary,
'test', 'test',
'--modele', temp_out, '--modele', temp_out,
@ -27,11 +50,39 @@ def train(binary, base_network, tries=3, dataset="train"):
'-i', 'data/mnist/t10k-images-idx3-ubyte', '-i', 'data/mnist/t10k-images-idx3-ubyte',
'-l', "data/mnist/t10k-labels-idx1-ubyte" '-l', "data/mnist/t10k-labels-idx1-ubyte"
]).decode("utf-8") ]).decode("utf-8")
results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0]))
for i in range(tries-1): results.append({
"train": {
"accuracy": float(train_output.split("Accuracy: \x1b[32m")[-1].split("%")[0]), # \x1b[32m est la couleur verte pour le terminal
"loss": float(train_output.split("Loss: ")[-1].split("\t")[0])
},
"test": {
"accuracy": float(test_output.split("Taux de réussite: ")[-1].split("%")[0]),
"loss": float(test_output.split("Loss: ")[-1])
}
})
for i in progress(range(tries-1)):
# On ne continue pas si on a déjà eu une saturation du réseau, on ajoute juste des valeurs
if fail == 0: if fail == 0:
fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 --out {temp_out} --recover {temp_out}") # i-ème époque training
output = subprocess.check_output([ try:
train_output = subprocess.check_output([
binary,
"train",
"--dataset", "mnist",
"--images", f"data/mnist/{dataset}-images-idx3-ubyte",
"--labels", f"data/mnist/{dataset}-labels-idx1-ubyte",
"--epochs", "1",
"-o", temp_out,
"--recover", temp_out
]).decode("utf-8")
fail = 0
except:
fail = 1
# i-ème époque testing
test_output = subprocess.check_output([
binary, binary,
'test', 'test',
'--modele', temp_out, '--modele', temp_out,
@ -39,9 +90,30 @@ def train(binary, base_network, tries=3, dataset="train"):
'-i', 'data/mnist/t10k-images-idx3-ubyte', '-i', 'data/mnist/t10k-images-idx3-ubyte',
'-l', "data/mnist/t10k-labels-idx1-ubyte" '-l', "data/mnist/t10k-labels-idx1-ubyte"
]).decode("utf-8") ]).decode("utf-8")
results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0]))
# Ajout des résultats
results.append({
"train": {
"accuracy": float(train_output.split("Accuracy: \x1b[32m")[-1].split("%")[0]), # \x1b[32m est la couleur verte pour le terminal
"loss": float(train_output.split("Loss: ")[-1].split("\t")[0])
},
"test": {
"accuracy": float(test_output.split("Taux de réussite: ")[-1].split("%")[0]),
"loss": float(test_output.split("Loss: ")[-1])
}
})
else: else:
results.append(results[-1]) # Le réseau a saturé
results.append({
"train": {
"accuracy": 0,
"loss": 0
},
"test": {
"accuracy": 0,
"loss": 0
}
})
return results return results
@ -49,35 +121,50 @@ def create_base_network(binary, file):
os.system(f"{binary} train --dataset mnist --images data/mnist/train-images-idx3-ubyte --labels data/mnist/train-labels-idx1-ubyte --epochs 0 --out {file}") os.system(f"{binary} train --dataset mnist --images data/mnist/train-images-idx3-ubyte --labels data/mnist/train-labels-idx1-ubyte --epochs 0 --out {file}")
def compare_binaries(binaries, tries=3, dataset="train"): """
print(f"========== {len(binaries)} Fichiers chargés ==========") binaries: list of files
base_net = f".cache/basenet-{abs(hash(''.join(binaries)))}.bin" tries: number of epochs to train on
create_base_network(binaries[0], base_net) dataset: must be "train" or "t10k"
results = [] metric: "accuracy" or "loss"
for i in range(len(binaries)): """
binary = binaries[i] def compare_binaries(binaries, tries=3, dataset="train", metric="accuracy", values=None):
print(f"========== Benchmmark de {binary} ({1+i}/{len(binaries)}) ==========") if values is None:
try: print(f"========== {len(binaries)} Fichiers chargés ==========")
results.append(train(binary, base_net, tries, dataset=dataset)) base_net = f".cache/basenet-{abs(hash(''.join(binaries)))}.bin"
except: create_base_network(binaries[0], base_net)
print(f"========== Erreur sur {binary} ==========") results = {}
results.append([0.]*tries) for i in range(len(binaries)):
binary = binaries[i]
print(f"========== Benchmark de {binary} ({1+i}/{len(binaries)}) ==========")
try:
results[binaries[i]] = (train(binary, base_net, tries, dataset=dataset))
except Exception as e:
print(f"========== Erreur sur {binary} ==========")
print(e)
# Delete value if nothing happened
# results.append([{'train': {'accuracy': 0., 'loss': 0.}, 'test': {'accuracy': 0., 'loss': 0.}}]*tries)
else:
results = values
binaries = [key for key in values.keys()]
x = [i for i in range(tries)] x = [i+1 for i in range(tries)]
fig, ax = plt.subplots() fig, ax = plt.subplots()
res = [] for binary in results.keys():
for i in range(len(binaries)): for key in results[binary][0].keys():
res.append(ax.plot(x, results[i])[0]) key_values = [j[key][metric] for j in results[binary]]
res[i].set_label(binaries[i])
courbe = ax.plot(x, key_values)[0]
courbe.set_label(f"{key}/{binary}")
ax.set_ylabel("Taux de réussite (%)") ax.set_ylabel(f"{metric}")
ax.set_xlabel("Nombre de batchs") ax.set_xlabel("Nombre d'époques")
ax.legend() ax.legend()
plt.ylim(0, 100) # plt.ylim(0, 100)
plt.show() plt.show()
return binaries, results return results

View File

@ -16,8 +16,8 @@ values="0 5 25 50" # Example values
for val in $values; do for val in $values; do
# For a variable # For a variable
# sed -i 's/'"$VARIABLE_TO_MODIFY"'=.*/'"$VARIABLE_TO_MODIFY'='$val"';/g' "$FILE_TO_MODIFY" # sed -i 's/'"$VARIABLE_TO_MODIFY"'=.*/'"$VARIABLE_TO_MODIFY'='$val"';/g' "$FILE_TO_MODIFY"
# For a define # For a #define
sed -i 's/#define '"$VARIABLE_TO_MODIFY"' .*$/#define '"$VARIABLE_TO_MODIFY"' '"$val"'/g' "$FILE_TO_MODIFY" sed -i 's/#define '"$VARIABLE_TO_MODIFY"' .*$/#define '"$VARIABLE_TO_MODIFY"' '"$val"'/g' "$FILE_TO_MODIFY"
make all make -j$(nproc) all
cp build/cnn-main "$BIN_OUT"/"$VARIABLE_TO_MODIFY=$val" cp build/cnn-main "$BIN_OUT"/"$VARIABLE_TO_MODIFY=$val"
done done