Add CNN to webserver

This commit is contained in:
augustin64 2023-01-23 21:16:36 +01:00
parent e11d1f552a
commit b75388f463
8 changed files with 146 additions and 44 deletions

View File

@ -136,12 +136,22 @@ endif
webserver: $(CACHE_DIR)/mnist-reseau.bin
FLASK_APP="src/webserver/app.py" flask run
$(CACHE_DIR)/mnist-reseau.bin: $(BUILDDIR)/mnist-main
$(CACHE_DIR)/mnist-reseau-fully-connected.bin: $(BUILDDIR)/mnist-main
@mkdir -p $(CACHE_DIR)
$(BUILDDIR)/mnist-main train \
--images "data/mnist/train-images-idx3-ubyte" \
--labels "data/mnist/train-labels-idx1-ubyte" \
--out "$(CACHE_DIR)/mnist-reseau.bin"
--out "$(CACHE_DIR)/mnist-reseau-fully-connected.bin"
$(CACHE_DIR)/mnist-reseau-cnn.bin: $(BUILDDIR)/cnn-main
@mkdir -p $(CACHE_DIR)
$(BUILDDIR)/cnn-main train \
--dataset mnist \
--images data/mnist/train-images-idx3-ubyte \
--labels data/mnist/train-labels-idx1-ubyte \
--epochs 10 \
--out $(CACHE_DIR)/mnist-reseau-cnn.bin
#

View File

@ -11,15 +11,15 @@ void test_network(int dataset_type, char* modele, char* images_file, char* label
/*
* Classifie un fichier d'images sous le format MNIST à partir d'un réseau préalablement entraîné
*/
void recognize_mnist(Network* network, char* input_file);
void recognize_mnist(Network* network, char* input_file, char* out);
/*
* Classifie une image jpg à partir d'un réseau préalablement entraîné
*/
void recognize_jpg(Network* network, char* input_file);
void recognize_jpg(Network* network, char* input_file,char* out);
/*
* Classifie une image à partir d'un réseau préalablement entraîné
*/
void recognize(int dataset_type, char* modele, char* input_file);
void recognize(int dataset_type, char* modele, char* input_file, char* out);
#endif

View File

@ -31,6 +31,7 @@ void help(char* call) {
printf("\t\t--dataset | -d (mnist|jpg)\tFormat de l'image à reconnaître.\n");
printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné.\n");
printf("\t\t--input | -i [FILENAME]\tImage jpeg ou fichier binaire à reconnaître.\n");
printf("\t\t--out | -o (text|json)\tFormat de sortie.\n");
printf("\ttest:\n");
printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné.\n");
printf("\t\t--dataset | -d (mnist|jpg)\tFormat du set de données.\n");
@ -192,6 +193,7 @@ int main(int argc, char* argv[]) {
char* dataset = NULL; // mnist ou jpg
char* modele = NULL; // Fichier contenant le modèle
char* input_file = NULL; // Image à reconnaître
char* out = NULL;
int dataset_type;
int i = 2;
while (i < argc) {
@ -203,6 +205,10 @@ int main(int argc, char* argv[]) {
modele = argv[i+1];
i += 2;
}
else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) {
out = argv[i+1];
i += 2;
}
else if ((! strcmp(argv[i], "--input"))||(! strcmp(argv[i], "-i"))) {
input_file = argv[i+1];
i += 2;
@ -224,11 +230,14 @@ int main(int argc, char* argv[]) {
printf("Pas de fichier d'entrée spécifié, rien à faire.\n");
return 1;
}
if (!out) {
out = "text";
}
if (!modele) {
printf("Pas de modèle à utiliser spécifié.\n");
return 1;
}
recognize(dataset_type, modele, input_file);
recognize(dataset_type, modele, input_file, out);
return 0;
}
printf("Option choisie non reconnue: %s\n", argv[1]);

View File

@ -1,6 +1,7 @@
#include <stdlib.h>
#include <stdio.h>
#include <stdbool.h>
#include <string.h>
#include "../mnist/include/mnist.h"
#include "include/neuron_io.h"
@ -97,10 +98,9 @@ void test_network(int dataset_type, char* modele, char* images_file, char* label
}
void recognize_mnist(Network* network, char* input_file) {
void recognize_mnist(Network* network, char* input_file, char* out) {
int width, height; // Dimensions de l'image
int nb_elem; // Nombre d'éléments
int maxi; // Catégorie reconnue
// Load image
int* mnist_parameters = read_mnist_images_parameters(input_file);
@ -111,24 +111,54 @@ void recognize_mnist(Network* network, char* input_file) {
height = mnist_parameters[2];
free(mnist_parameters);
printf("Image\tCatégorie détectée\n");
if (! strcmp(out, "json")) {
printf("{\n");
} else {
printf("Image\tCatégorie détectée\n");
}
// Load image in the first layer of the Network
for (int i=0; i < nb_elem; i++) {
if (! strcmp(out, "json")) {
printf("\"%d\" : [", i);
}
write_image_in_network_32(images[i], height, width, network->input[0][0]);
forward_propagation(network);
maxi = indice_max(network->input[network->size-1][0][0], 10);
printf("%d\t%d\n", i, maxi);
if (! strcmp(out, "json")) {
for (int j=0; j < 10; j++) {
printf("%f", network->input[network->size-1][0][0][j]);
if (j+1 < 10) {
printf(", ");
}
}
} else {
printf("%d\t%d\n", i, indice_max(network->input[network->size-1][0][0], 10));
}
if (! strcmp(out, "json")) {
if (i+1 < nb_elem) {
printf("],\n");
} else {
printf("]\n");
}
}
for (int j=0; j < height; j++) {
free(images[i][j]);
}
free(images[i]);
}
if (! strcmp(out, "json")) {
printf("}\n");
}
free(images);
}
void recognize_jpg(Network* network, char* input_file) {
void recognize_jpg(Network* network, char* input_file, char* out) {
int width, height; // Dimensions de l'image
int maxi;
@ -136,22 +166,45 @@ void recognize_jpg(Network* network, char* input_file) {
width = image->width;
height = image->height;
if (! strcmp(out, "json")) {
printf("{\n");
printf("\"0\" : [");
}
// Load image in the first layer of the Network
write_image_in_network_260(image->lpData, height, width, network->input[0]);
forward_propagation(network);
maxi = indice_max(network->input[network->size-1][0][0], 50);
printf("Catégorie reconnue: %d\n", maxi);
if (! strcmp(out, "json")) {
for (int j=0; j < 50; j++) {
printf("%f", network->input[network->size-1][0][0][j]);
if (j+1 < 10) {
printf(", ");
}
}
} else {
maxi = indice_max(network->input[network->size-1][0][0], 50);
printf("Catégorie reconnue: %d\n", maxi);
}
if (! strcmp(out, "json")) {
printf("]\n");
printf("}\n");
}
free(image->lpData);
free(image);
}
void recognize(int dataset_type, char* modele, char* input_file) {
void recognize(int dataset_type, char* modele, char* input_file, char* out) {
Network* network = read_network(modele);
if (dataset_type == 0) {
recognize_mnist(network, input_file);
recognize_mnist(network, input_file, out);
} else {
recognize_jpg(network, input_file);
recognize_jpg(network, input_file, out);
}
free_network(network);

View File

@ -40,13 +40,30 @@ def recognize_mnist(image):
output = subprocess.check_output([
'build/mnist-main',
'recognize',
'--modele', '.cache/mnist-reseau.bin',
'--modele', '.cache/mnist-reseau-fully-connected.bin',
'--in', '.cache/image-idx3-ubyte',
'--out', 'json'
]).decode("utf-8")
json_data = json.loads(output.replace("nan", "0.0"))["0"]
return {"status": 200, "data": json_data}
except subprocess.CalledProcessError:
json_data_fc = json.loads(output.replace("nan", "0.0"))["0"]
output = subprocess.check_output([
'build/cnn-main',
'recognize',
'--dataset', 'mnist',
'--modele', '.cache/mnist-reseau-cnn.bin',
'--input', '.cache/image-idx3-ubyte',
'--out', 'json'
]).decode("utf-8")
json_data_cnn = json.loads(output.replace("nan", "0.0"))["0"]
return {
"status": 200,
"data": {
"fully_connected": json_data_fc,
"cnn": json_data_cnn
}
}
except Error:
return {
"status": 500,
"data": "Internal Server Error"

View File

@ -14,6 +14,8 @@ var mouseY;
var mouseDown = 0;
const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
const RES_STANDARD = "Résultats réseau de neurones standard";
const RES_CONV = "Résultats réseau de neurones convolutif";
var canvasSize;
function init() {
@ -44,6 +46,7 @@ function init() {
canvas.addEventListener('touchmove', s_touchMove, false);
canvas.addEventListener('touchend', s_mouseUp, false);
}
clear();
}
@ -107,8 +110,9 @@ function s_touchMove(e) {
function clear() {
document.getElementById("result").innerHTML = "";
ctx.clearRect(0, 0, canvas.width, canvas.height);
document.getElementById("result_fc").innerHTML = RES_STANDARD;
document.getElementById("result_cnn").innerHTML = RES_CONV;
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.fillStyle = "black";
ctx.fillRect(0, 0, canvas.width, canvas.height);
}
@ -144,6 +148,26 @@ function sendJSON(data,callback){
}
function addResults(elem, data, text) {
elem.innerHTML = text;
var elements = [];
for (let i=0; i < data.length; i++) elements.push([data[i], i]);
elements.sort(function(a, b) {
a = a[1];
b = b[1];
return a < b ? -1 : (a > b ? 1 : 0);
});
let res = elements.sort().reverse();
for (let j=0; j < res.length; j++) {
elem.innerHTML += "<br/>"+res[j][1]+" : "+res[j][0];
}
}
function getPrediction() {
let imageSize = 28;
let totalWidth = canvasSize;
@ -184,23 +208,8 @@ function getPrediction() {
if (data["status"] != 200) {
document.getElementById("result").innerHTML = "500 Internal Server Error";
} else {
let result = document.getElementById("result");
result.innerHTML = "Résultat:";
var elements = [];
for (let i=0; i < data["data"].length; i++) elements.push([data["data"][i], i]);
elements.sort(function(a, b) {
a = a[1];
b = b[1];
return a < b ? -1 : (a > b ? 1 : 0);
});
let res = elements.sort().reverse();
for (let j=0; j < res.length; j++) {
result .innerHTML += "<br/>"+res[j][1]+" : "+res[j][0];
}
addResults(document.getElementById("result_fc"), data["data"]["fully_connected"], RES_STANDARD);
addResults(document.getElementById("result_cnn"), data["data"]["cnn"], RES_CONV);
}
})
}

View File

@ -11,8 +11,11 @@ body {
touch-action:none;
}
#result {
margin: 3%;
.result-elem {
margin: 1%;
border-style: solid;
border-width: 1px;
padding: 2%;
}
/* For mobile devices */
@ -23,7 +26,7 @@ body {
button {
width: 300px;
height: 50px;
-webkit-appearance: none;
--webkit-appearance: none;
font-size: large;
}
#digit-canvas {

View File

@ -12,6 +12,7 @@
<button id="clear">Clear</button>
<button id="predict" onclick="getPrediction()">Prédire</button>
</div><br/>
<div id="result"></div>
<div id="result_fc" class="result-elem"></div>
<div id="result_cnn" class="result-elem"></div>
</body>
</html>