From 87b37aee80b1750250687104e48502ef7d024935 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Tue, 28 Feb 2023 13:16:39 +0100 Subject: [PATCH] Update benchmark scripts --- src/scripts/benchmark_binary.py | 147 ++++++++++++++++++++++++------- src/scripts/generate_binaries.sh | 4 +- 2 files changed, 119 insertions(+), 32 deletions(-) diff --git a/src/scripts/benchmark_binary.py b/src/scripts/benchmark_binary.py index 75e6a9a..19f4522 100644 --- a/src/scripts/benchmark_binary.py +++ b/src/scripts/benchmark_binary.py @@ -15,11 +15,34 @@ import os from matplotlib import pyplot as plt +try: + from tqdm import tqdm + progress = tqdm +except ModuleNotFoundError: + progress = lambda x : x + def train(binary, base_network, tries=3, dataset="train"): temp_out = f".cache/tmp-{abs(hash(binary))}.bin" results = [] - fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 -o {temp_out} --recover {base_network}") - output = subprocess.check_output([ + + # 1e époque training + try: + train_output = subprocess.check_output([ + binary, + "train", + "--dataset", "mnist", + "--images", f"data/mnist/{dataset}-images-idx3-ubyte", + "--labels", f"data/mnist/{dataset}-labels-idx1-ubyte", + "--epochs", "1", + "-o", temp_out, + "--recover", base_network + ]).decode("utf-8") + fail = 0 + except: + fail = 1 + + # 1e époque testing + test_output = subprocess.check_output([ binary, 'test', '--modele', temp_out, @@ -27,11 +50,39 @@ def train(binary, base_network, tries=3, dataset="train"): '-i', 'data/mnist/t10k-images-idx3-ubyte', '-l', "data/mnist/t10k-labels-idx1-ubyte" ]).decode("utf-8") - results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0])) - for i in range(tries-1): + + results.append({ + "train": { + "accuracy": float(train_output.split("Accuracy: \x1b[32m")[-1].split("%")[0]), # \x1b[32m est la couleur verte pour le terminal + "loss": float(train_output.split("Loss: ")[-1].split("\t")[0]) + }, + "test": { + "accuracy": float(test_output.split("Taux de réussite: ")[-1].split("%")[0]), + "loss": float(test_output.split("Loss: ")[-1]) + } + }) + + for i in progress(range(tries-1)): + # On ne continue pas si on a déjà eu une saturation du réseau, on ajoute juste des valeurs if fail == 0: - fail = os.system(f"{binary} train --dataset mnist --images data/mnist/{dataset}-images-idx3-ubyte --labels data/mnist/{dataset}-labels-idx1-ubyte --epochs 1 --out {temp_out} --recover {temp_out}") - output = subprocess.check_output([ + # i-ème époque training + try: + train_output = subprocess.check_output([ + binary, + "train", + "--dataset", "mnist", + "--images", f"data/mnist/{dataset}-images-idx3-ubyte", + "--labels", f"data/mnist/{dataset}-labels-idx1-ubyte", + "--epochs", "1", + "-o", temp_out, + "--recover", temp_out + ]).decode("utf-8") + fail = 0 + except: + fail = 1 + + # i-ème époque testing + test_output = subprocess.check_output([ binary, 'test', '--modele', temp_out, @@ -39,9 +90,30 @@ def train(binary, base_network, tries=3, dataset="train"): '-i', 'data/mnist/t10k-images-idx3-ubyte', '-l', "data/mnist/t10k-labels-idx1-ubyte" ]).decode("utf-8") - results.append(float(output.split('\r')[-1].split(" ")[-1].split("%")[0])) + + # Ajout des résultats + results.append({ + "train": { + "accuracy": float(train_output.split("Accuracy: \x1b[32m")[-1].split("%")[0]), # \x1b[32m est la couleur verte pour le terminal + "loss": float(train_output.split("Loss: ")[-1].split("\t")[0]) + }, + "test": { + "accuracy": float(test_output.split("Taux de réussite: ")[-1].split("%")[0]), + "loss": float(test_output.split("Loss: ")[-1]) + } + }) else: - results.append(results[-1]) + # Le réseau a saturé + results.append({ + "train": { + "accuracy": 0, + "loss": 0 + }, + "test": { + "accuracy": 0, + "loss": 0 + } + }) return results @@ -49,35 +121,50 @@ def create_base_network(binary, file): os.system(f"{binary} train --dataset mnist --images data/mnist/train-images-idx3-ubyte --labels data/mnist/train-labels-idx1-ubyte --epochs 0 --out {file}") -def compare_binaries(binaries, tries=3, dataset="train"): - print(f"========== {len(binaries)} Fichiers chargés ==========") - base_net = f".cache/basenet-{abs(hash(''.join(binaries)))}.bin" - create_base_network(binaries[0], base_net) - results = [] - for i in range(len(binaries)): - binary = binaries[i] - print(f"========== Benchmmark de {binary} ({1+i}/{len(binaries)}) ==========") - try: - results.append(train(binary, base_net, tries, dataset=dataset)) - except: - print(f"========== Erreur sur {binary} ==========") - results.append([0.]*tries) +""" +binaries: list of files +tries: number of epochs to train on +dataset: must be "train" or "t10k" +metric: "accuracy" or "loss" +""" +def compare_binaries(binaries, tries=3, dataset="train", metric="accuracy", values=None): + if values is None: + print(f"========== {len(binaries)} Fichiers chargés ==========") + base_net = f".cache/basenet-{abs(hash(''.join(binaries)))}.bin" + create_base_network(binaries[0], base_net) + results = {} + for i in range(len(binaries)): + binary = binaries[i] + print(f"========== Benchmark de {binary} ({1+i}/{len(binaries)}) ==========") + try: + results[binaries[i]] = (train(binary, base_net, tries, dataset=dataset)) + except Exception as e: + print(f"========== Erreur sur {binary} ==========") + print(e) + # Delete value if nothing happened + # results.append([{'train': {'accuracy': 0., 'loss': 0.}, 'test': {'accuracy': 0., 'loss': 0.}}]*tries) + else: + results = values + binaries = [key for key in values.keys()] - x = [i for i in range(tries)] + x = [i+1 for i in range(tries)] fig, ax = plt.subplots() - res = [] - for i in range(len(binaries)): - res.append(ax.plot(x, results[i])[0]) - res[i].set_label(binaries[i]) + for binary in results.keys(): + for key in results[binary][0].keys(): + key_values = [j[key][metric] for j in results[binary]] + + courbe = ax.plot(x, key_values)[0] + courbe.set_label(f"{key}/{binary}") + - ax.set_ylabel("Taux de réussite (%)") - ax.set_xlabel("Nombre de batchs") + ax.set_ylabel(f"{metric}") + ax.set_xlabel("Nombre d'époques") ax.legend() - plt.ylim(0, 100) + # plt.ylim(0, 100) plt.show() - return binaries, results + return results \ No newline at end of file diff --git a/src/scripts/generate_binaries.sh b/src/scripts/generate_binaries.sh index 35cd144..238f0b5 100755 --- a/src/scripts/generate_binaries.sh +++ b/src/scripts/generate_binaries.sh @@ -16,8 +16,8 @@ values="0 5 25 50" # Example values for val in $values; do # For a variable # sed -i 's/'"$VARIABLE_TO_MODIFY"'=.*/'"$VARIABLE_TO_MODIFY'='$val"';/g' "$FILE_TO_MODIFY" - # For a define + # For a #define sed -i 's/#define '"$VARIABLE_TO_MODIFY"' .*$/#define '"$VARIABLE_TO_MODIFY"' '"$val"'/g' "$FILE_TO_MODIFY" - make all + make -j$(nproc) all cp build/cnn-main "$BIN_OUT"/"$VARIABLE_TO_MODIFY=$val" done