diff --git a/src/cnn/export.c b/src/cnn/export.c index e95b426..abedf1e 100644 --- a/src/cnn/export.c +++ b/src/cnn/export.c @@ -18,15 +18,13 @@ void help(char* call) { printf("Usage: %s ( print-poids-kernel-cnn | visual-propagation ) [OPTIONS]\n\n", call); printf("OPTIONS:\n"); printf("\tprint-poids-kernel-cnn\n"); - printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné\n"); + printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné\n"); printf("\tvisual-propagation\n"); - printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné\n"); - printf("\t\t--images | -i [FILENAME]\tFichier contenant les images.\n"); - printf("\t\t--numero | -n [numero]\tNuméro de l'image dont la propagation veut être visualisée\n"); - printf("\t\t--out | -o [BASE_FILENAME]\tLes images seront stockées dans ${out}_layer-${numéro de couche}_feature-${kernel_numero}.jpeg\n"); - - printf("\n"); - printf_warning("Seul les datasets de type MNIST sont pris en charge pour le moment\n"); + printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau entraîné\n"); + printf("\t\t--out | -o [BASE_FILENAME]\tLes images seront stockées dans ${out}_layer-${numéro de couche}_feature-${kernel_numero}.jpeg\n"); + printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n"); + printf("\t(mnist)\t--numero | -n [numero]\tNuméro de l'image dont la propagation veut être visualisée\n"); + printf("\t(jpeg)\t--jpeg-image | -j [FILENAME]\tImage jpeg dont la propagation veut être visualisée.\n"); } @@ -82,21 +80,20 @@ void print_poids_ker_cnn(char* modele) { } -void write_image(float** data, int width, char* base_filename, int layer_id, int kernel_id) { - int filename_length = strlen(base_filename) + (int)log10(layer_id+1)+1 + (int)log10(kernel_id+1)+1 + 21; +void write_image(float** data, int width, int height, char* base_filename, int layer_id, int kernel_id) { + int filename_length = strlen(base_filename) + (int)log10(layer_id+1)+1 + (int)log10(kernel_id+1)+1 + 21+12; char* filename = (char*)malloc(sizeof(char)*filename_length); sprintf(filename, "%s_layer-%d_feature-%d.jpeg", base_filename, layer_id, kernel_id); - imgRawImage* image = (imgRawImage*)malloc(sizeof(imgRawImage)); image->numComponents = 3; image->width = width; - image->height = width; - image->lpData = (unsigned char*)malloc(sizeof(unsigned char)*width*width*3); + image->height = height; + image->lpData = (unsigned char*)malloc(sizeof(unsigned char)*width*height*3); - for (int i=0; i < width; i++) { + for (int i=0; i < height; i++) { for (int j=0; j < width; j++) { float color = fmax(fmin(data[i][j], 1.), 0.)*255; @@ -114,52 +111,63 @@ void write_image(float** data, int width, char* base_filename, int layer_id, int } -void visual_propagation(char* modele_file, char* images_file, char* out_base, int numero) { +void visual_propagation(char* modele_file, char* mnist_images_file, char* out_base, int numero, char* jpeg_file) { Network* network = read_network(modele_file); - int* mnist_parameters = read_mnist_images_parameters(images_file); - int*** images = read_mnist_images(images_file); + if (mnist_images_file) { + int* mnist_parameters = read_mnist_images_parameters(mnist_images_file); + int*** images = read_mnist_images(mnist_images_file); - int nb_elem = mnist_parameters[0]; + int nb_elem = mnist_parameters[0]; - int width = mnist_parameters[1]; - int height = mnist_parameters[2]; - free(mnist_parameters); + int width = mnist_parameters[1]; + int height = mnist_parameters[2]; + free(mnist_parameters); - if (numero < 0 || numero >= nb_elem) { - printf_error("Numéro d'image spécifié invalide."); - printf(" Le fichier contient %d images.\n", nb_elem); - exit(1); + if (numero < 0 || numero >= nb_elem) { + printf_error("Numéro d'image spécifié invalide."); + printf(" Le fichier contient %d images.\n", nb_elem); + exit(1); + } + + // Write image to the network + write_image_in_network_32(images[numero], height, width, network->input[0][0], false); + + // Free allocated memory from image reading + for (int i=0; i < nb_elem; i++) { + for (int j=0; j < width; j++) { + free(images[i][j]); + } + free(images[i]); + } + free(images); + } else { + imgRawImage* image = loadJpegImageFile(jpeg_file); + + write_image_in_network_260(image->lpData, image->height, image->width, network->input[0]); + + // Free allocated memory from image reading + free(image->lpData); + free(image); } - - // Forward propagation - write_image_in_network_32(images[numero], height, width, network->input[0][0], false); forward_propagation(network); - for (int i=0; i < network->size-1; i++) { - if (i == 0) { - write_image(network->input[0][0], width, out_base, 0, 0); + // Écriture des résultats + for (int i=0; i < network->depth[0]; i++) { + write_image(network->input[0][i], network->width[0], network->width[0], out_base, 0, i); + } + + for (int i=1; i < network->size; i++) { + if (!(!network->kernel[i-1]->nn)) { + write_image(network->input[i][0], network->kernel[i-1]->nn->size_output, 1, out_base, i, 0); } else { - if ((!network->kernel[i]->cnn)&&(!network->kernel[i]->nn)) { - for (int j=0; j < network->depth[i]; j++) { - write_image(network->input[i][j], network->width[i], out_base, i, j); - } - } else if (!network->kernel[i]->cnn) { - // Couche de type NN, on n'affiche rien - } else { - write_image(network->input[i][0], network->width[i], out_base, i, 0); + for (int j=0; j < network->depth[i]; j++) { + write_image(network->input[i][j], network->width[i], network->width[i], out_base, i, j); } } } free_network(network); - for (int i=0; i < nb_elem; i++) { - for (int j=0; j < width; j++) { - free(images[i][j]); - } - free(images[i]); - } - free(images); } @@ -192,16 +200,28 @@ int main(int argc, char* argv[]) { } if (! strcmp(argv[1], "visual-propagation")) { char* modele = NULL; // Fichier contenant le modèle - char* images = NULL; // Dossier contenant les images + char* images = NULL; // Dossier contenant les images (mnist) char* out_base = NULL; // Préfixe du nom de fichier de sortie - int numero = -1; // Numéro de l'image dans le dataset + char* jpeg_image = NULL; // Image à regarder (jpeg) + int numero = -1; // Numéro de l'image dans le dataset (mnist) int i = 2; while (i < argc) { if ((! strcmp(argv[i], "--modele"))||(! strcmp(argv[i], "-m"))) { modele = argv[i+1]; i += 2; } else if ((! strcmp(argv[i], "--images"))||(! strcmp(argv[i], "-i"))) { - images = argv[i+1]; + if (images) { + printf_warning("Arguments conflictuels. L'image de type jpeg sera favorisée.\n"); + } else { + images = argv[i+1]; + } + i += 2; + } else if ((! strcmp(argv[i], "--jpeg-image"))||(! strcmp(argv[i], "-j"))) { + if (images) { + printf_warning("Arguments conflictuels. L'image de type MNIST sera favorisée.\n"); + } else { + jpeg_image = argv[i+1]; + } i += 2; } else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) { out_base = argv[i+1]; @@ -219,7 +239,7 @@ int main(int argc, char* argv[]) { printf_error("Pas de modèle à utiliser spécifié.\n"); return 1; } - if (!images) { + if (!images && !jpeg_image) { printf_error("Pas de fichier d'images spécifié.\n"); return 1; } @@ -227,11 +247,15 @@ int main(int argc, char* argv[]) { printf_error("Pas de fichier de sortie spécifié.\n"); return 1; } - if (numero == -1) { - printf_error("Pas de numéro d'image spécifié.\n"); - return 1; + if (images) { + if (numero == -1) { + printf_error("Pas de numéro d'image spécifié.\n"); + return 1; + } + visual_propagation(modele, images, out_base, numero, NULL); + return 0; } - visual_propagation(modele, images, out_base, numero); + visual_propagation(modele, NULL, out_base, 0, jpeg_image); return 0; } printf_error("Option choisie non reconnue: ");