From b75388f4639f99e8900ea072580dcf6aaca99cdd Mon Sep 17 00:00:00 2001 From: augustin64 Date: Mon, 23 Jan 2023 21:16:36 +0100 Subject: [PATCH] Add CNN to webserver --- Makefile | 14 +++++- src/cnn/include/test_network.h | 6 +-- src/cnn/main.c | 11 ++++- src/cnn/test_network.c | 75 +++++++++++++++++++++++++----- src/webserver/app.py | 25 ++++++++-- src/webserver/static/script.js | 47 +++++++++++-------- src/webserver/static/style.css | 9 ++-- src/webserver/templates/mnist.html | 3 +- 8 files changed, 146 insertions(+), 44 deletions(-) diff --git a/Makefile b/Makefile index 0df95ac..5bd174b 100644 --- a/Makefile +++ b/Makefile @@ -136,12 +136,22 @@ endif webserver: $(CACHE_DIR)/mnist-reseau.bin FLASK_APP="src/webserver/app.py" flask run -$(CACHE_DIR)/mnist-reseau.bin: $(BUILDDIR)/mnist-main +$(CACHE_DIR)/mnist-reseau-fully-connected.bin: $(BUILDDIR)/mnist-main @mkdir -p $(CACHE_DIR) $(BUILDDIR)/mnist-main train \ --images "data/mnist/train-images-idx3-ubyte" \ --labels "data/mnist/train-labels-idx1-ubyte" \ - --out "$(CACHE_DIR)/mnist-reseau.bin" + --out "$(CACHE_DIR)/mnist-reseau-fully-connected.bin" + + +$(CACHE_DIR)/mnist-reseau-cnn.bin: $(BUILDDIR)/cnn-main + @mkdir -p $(CACHE_DIR) + $(BUILDDIR)/cnn-main train \ + --dataset mnist \ + --images data/mnist/train-images-idx3-ubyte \ + --labels data/mnist/train-labels-idx1-ubyte \ + --epochs 10 \ + --out $(CACHE_DIR)/mnist-reseau-cnn.bin # diff --git a/src/cnn/include/test_network.h b/src/cnn/include/test_network.h index 1c924e9..de7f62b 100644 --- a/src/cnn/include/test_network.h +++ b/src/cnn/include/test_network.h @@ -11,15 +11,15 @@ void test_network(int dataset_type, char* modele, char* images_file, char* label /* * Classifie un fichier d'images sous le format MNIST à partir d'un réseau préalablement entraîné */ -void recognize_mnist(Network* network, char* input_file); +void recognize_mnist(Network* network, char* input_file, char* out); /* * Classifie une image jpg à partir d'un réseau préalablement entraîné */ -void recognize_jpg(Network* network, char* input_file); +void recognize_jpg(Network* network, char* input_file,char* out); /* * Classifie une image à partir d'un réseau préalablement entraîné */ -void recognize(int dataset_type, char* modele, char* input_file); +void recognize(int dataset_type, char* modele, char* input_file, char* out); #endif \ No newline at end of file diff --git a/src/cnn/main.c b/src/cnn/main.c index ba90b6f..ccec369 100644 --- a/src/cnn/main.c +++ b/src/cnn/main.c @@ -31,6 +31,7 @@ void help(char* call) { printf("\t\t--dataset | -d (mnist|jpg)\tFormat de l'image à reconnaître.\n"); printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné.\n"); printf("\t\t--input | -i [FILENAME]\tImage jpeg ou fichier binaire à reconnaître.\n"); + printf("\t\t--out | -o (text|json)\tFormat de sortie.\n"); printf("\ttest:\n"); printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné.\n"); printf("\t\t--dataset | -d (mnist|jpg)\tFormat du set de données.\n"); @@ -192,6 +193,7 @@ int main(int argc, char* argv[]) { char* dataset = NULL; // mnist ou jpg char* modele = NULL; // Fichier contenant le modèle char* input_file = NULL; // Image à reconnaître + char* out = NULL; int dataset_type; int i = 2; while (i < argc) { @@ -203,6 +205,10 @@ int main(int argc, char* argv[]) { modele = argv[i+1]; i += 2; } + else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) { + out = argv[i+1]; + i += 2; + } else if ((! strcmp(argv[i], "--input"))||(! strcmp(argv[i], "-i"))) { input_file = argv[i+1]; i += 2; @@ -224,11 +230,14 @@ int main(int argc, char* argv[]) { printf("Pas de fichier d'entrée spécifié, rien à faire.\n"); return 1; } + if (!out) { + out = "text"; + } if (!modele) { printf("Pas de modèle à utiliser spécifié.\n"); return 1; } - recognize(dataset_type, modele, input_file); + recognize(dataset_type, modele, input_file, out); return 0; } printf("Option choisie non reconnue: %s\n", argv[1]); diff --git a/src/cnn/test_network.c b/src/cnn/test_network.c index 848789e..bd53641 100644 --- a/src/cnn/test_network.c +++ b/src/cnn/test_network.c @@ -1,6 +1,7 @@ #include #include #include +#include #include "../mnist/include/mnist.h" #include "include/neuron_io.h" @@ -97,10 +98,9 @@ void test_network(int dataset_type, char* modele, char* images_file, char* label } -void recognize_mnist(Network* network, char* input_file) { +void recognize_mnist(Network* network, char* input_file, char* out) { int width, height; // Dimensions de l'image int nb_elem; // Nombre d'éléments - int maxi; // Catégorie reconnue // Load image int* mnist_parameters = read_mnist_images_parameters(input_file); @@ -111,24 +111,54 @@ void recognize_mnist(Network* network, char* input_file) { height = mnist_parameters[2]; free(mnist_parameters); - printf("Image\tCatégorie détectée\n"); + if (! strcmp(out, "json")) { + printf("{\n"); + } else { + printf("Image\tCatégorie détectée\n"); + } // Load image in the first layer of the Network for (int i=0; i < nb_elem; i++) { + if (! strcmp(out, "json")) { + printf("\"%d\" : [", i); + } + write_image_in_network_32(images[i], height, width, network->input[0][0]); forward_propagation(network); - maxi = indice_max(network->input[network->size-1][0][0], 10); - printf("%d\t%d\n", i, maxi); + + if (! strcmp(out, "json")) { + for (int j=0; j < 10; j++) { + printf("%f", network->input[network->size-1][0][0][j]); + + if (j+1 < 10) { + printf(", "); + } + } + } else { + printf("%d\t%d\n", i, indice_max(network->input[network->size-1][0][0], 10)); + } + + if (! strcmp(out, "json")) { + if (i+1 < nb_elem) { + printf("],\n"); + } else { + printf("]\n"); + } + } for (int j=0; j < height; j++) { free(images[i][j]); } free(images[i]); } + if (! strcmp(out, "json")) { + printf("}\n"); + } + free(images); } -void recognize_jpg(Network* network, char* input_file) { +void recognize_jpg(Network* network, char* input_file, char* out) { int width, height; // Dimensions de l'image int maxi; @@ -136,22 +166,45 @@ void recognize_jpg(Network* network, char* input_file) { width = image->width; height = image->height; + if (! strcmp(out, "json")) { + printf("{\n"); + printf("\"0\" : ["); + } + + // Load image in the first layer of the Network write_image_in_network_260(image->lpData, height, width, network->input[0]); forward_propagation(network); - maxi = indice_max(network->input[network->size-1][0][0], 50); - printf("Catégorie reconnue: %d\n", maxi); + + if (! strcmp(out, "json")) { + for (int j=0; j < 50; j++) { + printf("%f", network->input[network->size-1][0][0][j]); + + if (j+1 < 10) { + printf(", "); + } + } + } else { + maxi = indice_max(network->input[network->size-1][0][0], 50); + printf("Catégorie reconnue: %d\n", maxi); + } + + if (! strcmp(out, "json")) { + printf("]\n"); + printf("}\n"); + } + free(image->lpData); free(image); } -void recognize(int dataset_type, char* modele, char* input_file) { +void recognize(int dataset_type, char* modele, char* input_file, char* out) { Network* network = read_network(modele); if (dataset_type == 0) { - recognize_mnist(network, input_file); + recognize_mnist(network, input_file, out); } else { - recognize_jpg(network, input_file); + recognize_jpg(network, input_file, out); } free_network(network); diff --git a/src/webserver/app.py b/src/webserver/app.py index 8b50735..fbe59fc 100644 --- a/src/webserver/app.py +++ b/src/webserver/app.py @@ -40,13 +40,30 @@ def recognize_mnist(image): output = subprocess.check_output([ 'build/mnist-main', 'recognize', - '--modele', '.cache/mnist-reseau.bin', + '--modele', '.cache/mnist-reseau-fully-connected.bin', '--in', '.cache/image-idx3-ubyte', '--out', 'json' ]).decode("utf-8") - json_data = json.loads(output.replace("nan", "0.0"))["0"] - return {"status": 200, "data": json_data} - except subprocess.CalledProcessError: + json_data_fc = json.loads(output.replace("nan", "0.0"))["0"] + + output = subprocess.check_output([ + 'build/cnn-main', + 'recognize', + '--dataset', 'mnist', + '--modele', '.cache/mnist-reseau-cnn.bin', + '--input', '.cache/image-idx3-ubyte', + '--out', 'json' + ]).decode("utf-8") + json_data_cnn = json.loads(output.replace("nan", "0.0"))["0"] + + return { + "status": 200, + "data": { + "fully_connected": json_data_fc, + "cnn": json_data_cnn + } + } + except Error: return { "status": 500, "data": "Internal Server Error" diff --git a/src/webserver/static/script.js b/src/webserver/static/script.js index f5f781f..380b209 100644 --- a/src/webserver/static/script.js +++ b/src/webserver/static/script.js @@ -14,6 +14,8 @@ var mouseY; var mouseDown = 0; const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0); +const RES_STANDARD = "Résultats réseau de neurones standard"; +const RES_CONV = "Résultats réseau de neurones convolutif"; var canvasSize; function init() { @@ -44,6 +46,7 @@ function init() { canvas.addEventListener('touchmove', s_touchMove, false); canvas.addEventListener('touchend', s_mouseUp, false); } + clear(); } @@ -107,8 +110,9 @@ function s_touchMove(e) { function clear() { - document.getElementById("result").innerHTML = ""; - ctx.clearRect(0, 0, canvas.width, canvas.height); + document.getElementById("result_fc").innerHTML = RES_STANDARD; + document.getElementById("result_cnn").innerHTML = RES_CONV; + ctx.clearRect(0, 0, canvas.width, canvas.height); ctx.fillStyle = "black"; ctx.fillRect(0, 0, canvas.width, canvas.height); } @@ -144,6 +148,26 @@ function sendJSON(data,callback){ } +function addResults(elem, data, text) { + elem.innerHTML = text; + var elements = []; + + for (let i=0; i < data.length; i++) elements.push([data[i], i]); + + elements.sort(function(a, b) { + a = a[1]; + b = b[1]; + + return a < b ? -1 : (a > b ? 1 : 0); + }); + + let res = elements.sort().reverse(); + for (let j=0; j < res.length; j++) { + elem.innerHTML += "
"+res[j][1]+" : "+res[j][0]; + } +} + + function getPrediction() { let imageSize = 28; let totalWidth = canvasSize; @@ -184,23 +208,8 @@ function getPrediction() { if (data["status"] != 200) { document.getElementById("result").innerHTML = "500 Internal Server Error"; } else { - let result = document.getElementById("result"); - result.innerHTML = "Résultat:"; - var elements = []; - - for (let i=0; i < data["data"].length; i++) elements.push([data["data"][i], i]); - - elements.sort(function(a, b) { - a = a[1]; - b = b[1]; - - return a < b ? -1 : (a > b ? 1 : 0); - }); - - let res = elements.sort().reverse(); - for (let j=0; j < res.length; j++) { - result .innerHTML += "
"+res[j][1]+" : "+res[j][0]; - } + addResults(document.getElementById("result_fc"), data["data"]["fully_connected"], RES_STANDARD); + addResults(document.getElementById("result_cnn"), data["data"]["cnn"], RES_CONV); } }) } \ No newline at end of file diff --git a/src/webserver/static/style.css b/src/webserver/static/style.css index 5e22c33..14b404c 100644 --- a/src/webserver/static/style.css +++ b/src/webserver/static/style.css @@ -11,8 +11,11 @@ body { touch-action:none; } -#result { - margin: 3%; +.result-elem { + margin: 1%; + border-style: solid; + border-width: 1px; + padding: 2%; } /* For mobile devices */ @@ -23,7 +26,7 @@ body { button { width: 300px; height: 50px; - -webkit-appearance: none; + --webkit-appearance: none; font-size: large; } #digit-canvas { diff --git a/src/webserver/templates/mnist.html b/src/webserver/templates/mnist.html index 2174c8d..0283193 100644 --- a/src/webserver/templates/mnist.html +++ b/src/webserver/templates/mnist.html @@ -12,6 +12,7 @@
-
+
+
\ No newline at end of file