diff --git a/src/cnn/cnn.c b/src/cnn/cnn.c index f8420c4..639f232 100644 --- a/src/cnn/cnn.c +++ b/src/cnn/cnn.c @@ -130,11 +130,29 @@ void write_image_in_network_32(int** image, int height, int width, float** input } } -void write_256_image_in_network(unsigned char* image, int img_width, int img_depth, int input_width, float*** input) { - assert(img_width <= input_width); - assert((input_width - img_width)%2 == 0); +void write_256_image_in_network(unsigned char* image, int img_width, int img_height, int img_depth, int input_width, float*** input) { + int padding = 0; + int decalage_x = 0; // Si l'input est plus petit que img_height, décalage de l'input par rapport à l'image selon 1e coord + int decalage_y = 0; // Pareil avec width et 2e coord - int padding = (input_width - img_width)/2; + if (img_width < input_width) { // Avec padding, l'image est carrée + assert(img_height == img_width); + assert((input_width - img_width)%2 == 0); + + padding = (input_width - img_width)/2; + } else { // Sans padding, l'image est au minimum de la taille de l'input + assert(img_height >= input_width); + + int decalage_possible_x = input_width - img_height; + if (decalage_possible_x > 0) { + decalage_x = rand() %decalage_possible_x; + } + + int decalage_possible_y = input_width - img_width; + if (decalage_possible_y > 0) { + decalage_y = rand() %decalage_possible_y; + } + } for (int i=0; i < padding; i++) { for (int j=0; j < input_width; j++) { @@ -147,10 +165,14 @@ void write_256_image_in_network(unsigned char* image, int img_width, int img_dep } } - for (int i=0; i < img_width; i++) { - for (int j=0; j < img_width; j++) { + int min_width = min(img_width, input_width); + int min_height = min(img_height, input_width); + for (int i=0; i < min_height; i++) { + for (int j=0; j < min_width; j++) { for (int composante=0; composante < img_depth; composante++) { - input[composante][i+padding][j+padding] = (float)image[(i*img_width+j)*img_depth + composante] / 255.0f; + int x = i + decalage_x; + int y = j + decalage_y; + input[composante][i+padding][j+padding] = (float)image[(x*img_width+y)*img_depth + composante] / 255.0f; } } } diff --git a/src/cnn/export.c b/src/cnn/export.c index 4bfa151..d304ab1 100644 --- a/src/cnn/export.c +++ b/src/cnn/export.c @@ -152,7 +152,7 @@ void visual_propagation(char* modele_file, char* mnist_images_file, char* out_ba } else { imgRawImage* image = loadJpegImageFile(jpeg_file); - write_256_image_in_network(image->lpData, image->width, image->numComponents, network->width[0], network->input[0]); + write_256_image_in_network(image->lpData, image->width, image->height, image->numComponents, network->width[0], network->input[0]); // Free allocated memory from image reading free(image->lpData); diff --git a/src/cnn/include/cnn.h b/src/cnn/include/cnn.h index 5ed7e7c..f50d226 100644 --- a/src/cnn/include/cnn.h +++ b/src/cnn/include/cnn.h @@ -6,8 +6,8 @@ #define DEF_MAIN_H #define EVERYTHING 0 -#define NN_ONLY 1 -#define NN_AND_LINEARISATION 2 +#define NN_AND_LINEARISATION 1 +#define NN_ONLY 2 /* * Renvoie l'indice de l'élément de valeur maximale dans un tableau de flottants @@ -26,12 +26,15 @@ int will_be_drop(int dropout_prob); void write_image_in_network_32(int** image, int height, int width, float** input, bool random_offset); /* -* Écrit une image linéarisée de img_width*img_width*img_depth pixels dans un tableau de taille size_input*size_input*3 +* Écrit une image linéarisée de img_width*img_height*img_depth pixels dans un tableau de taille size_input*size_input*3 * Les conditions suivantes doivent être respectées: -* - l'image est au plus de la même taille que input -* - la différence de taille entre input et l'image doit être un multiple de 2 (pour centrer l'image) + +* Soit l'image est plus petite que l'input, et est carrée, alors +* la différence de taille entre input et l'image doit être un multiple de 2 (pour centrer l'image) + +* Soit l'image est de taille au moins la taille de l'input, et elle sera décalée de manière aléatoire */ -void write_256_image_in_network(unsigned char* image, int img_width, int img_depth, int input_width, float*** input); +void write_256_image_in_network(unsigned char* image, int img_width, int img_height, int img_depth, int input_width, float*** input); /* * Propage en avant le cnn. Le dropout est actif que si le réseau est en phase d'apprentissage. diff --git a/src/cnn/test_network.c b/src/cnn/test_network.c index 9bbd0bd..3212aeb 100644 --- a/src/cnn/test_network.c +++ b/src/cnn/test_network.c @@ -80,7 +80,7 @@ float* test_network_jpg(Network* network, char* data_dir, bool preview_fails, bo printf("Avancement: %.1f%%\r", 1000*i/(float)dataset->numImages); fflush(stdout); } - write_256_image_in_network(dataset->images[i], dataset->height, dataset->numComponents, network->width[0], network->input[0]); + write_256_image_in_network(dataset->images[i], dataset->width, dataset->height, dataset->numComponents, network->width[0], network->input[0]); forward_propagation(network); maxi = indice_max(network->input[network->size-1][0][0], 50); @@ -184,11 +184,11 @@ void recognize_mnist(Network* network, char* input_file, char* out) { } void recognize_jpg(Network* network, char* input_file, char* out) { - int width; // Dimensions de l'image, qui doit être carrée int maxi; imgRawImage* image = loadJpegImageFile(input_file); - width = image->width; + int height = image->height; + int width = image->width; assert(image->width == image->height); @@ -198,7 +198,7 @@ void recognize_jpg(Network* network, char* input_file, char* out) { } // Load image in the first layer of the Network - write_256_image_in_network(image->lpData, width, image->numComponents, network->width[0], network->input[0]); + write_256_image_in_network(image->lpData, width, height, image->numComponents, network->width[0], network->input[0]); forward_propagation(network); diff --git a/src/cnn/train.c b/src/cnn/train.c index 0238a38..5840585 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -128,7 +128,7 @@ void* train_thread(void* parameters) { load_image_param->index = index[i+1]; pthread_create(&tid, NULL, load_image, (void*) load_image_param); } - write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]); + write_256_image_in_network(param->dataset->images[index[i]], width, height, param->dataset->numComponents, network->width[0], network->input[0]); #ifdef DETAILED_TRAIN_TIMINGS start_time = omp_get_wtime();