From 33f85baa1e7eeff4d0ee3ec1d308da95bd36d9bb Mon Sep 17 00:00:00 2001 From: julienChemillier Date: Fri, 3 Feb 2023 15:12:59 +0100 Subject: [PATCH] Fix issues due to pooling --- src/cnn/include/print.h | 2 +- src/cnn/neuron_io.c | 13 ++++++++----- src/cnn/print.c | 10 +++++++--- src/cnn/utils.c | 1 + test/cnn_structure.c | 6 +++++- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/cnn/include/print.h b/src/cnn/include/print.h index 1b9fe50..6a65aa2 100644 --- a/src/cnn/include/print.h +++ b/src/cnn/include/print.h @@ -11,7 +11,7 @@ void print_kernel_cnn(Kernel_cnn* k, int depth_input, int dim_input, int depth_o /* * Affiche une couche de pooling */ -void print_pooling(int size); +void print_pooling(int size, int pooling); /* * Affiche le kernel d'une couche de fully connected diff --git a/src/cnn/neuron_io.c b/src/cnn/neuron_io.c index 3d42ad4..8083942 100644 --- a/src/cnn/neuron_io.c +++ b/src/cnn/neuron_io.c @@ -5,9 +5,11 @@ #include "../include/colors.h" #include "../include/utils.h" -#include "include/neuron_io.h" +#include "include/function.h" #include "include/struct.h" +#include "include/neuron_io.h" + #define MAGIC_NUMBER 1012 #define bufferAdd(val) {buffer[indice_buffer] = val; indice_buffer++;} @@ -122,8 +124,8 @@ void write_couche(Network* network, int indice_couche, int type_couche, FILE* pt fwrite(buffer, sizeof(buffer), 1, ptr); } else if (type_couche == 2) { // Cas du Pooling Layer uint32_t pre_buffer[2]; - pre_buffer[0] = kernel->activation; // Variable du pooling - pre_buffer[1] = kernel->linearisation; + pre_buffer[0] = kernel->linearisation; + pre_buffer[1] = kernel->pooling; fwrite(pre_buffer, sizeof(pre_buffer), 1, ptr); } } @@ -305,12 +307,13 @@ Kernel* read_kernel(int type_couche, int output_dim, FILE* ptr) { } } else if (type_couche == 2) { // Cas du Pooling Layer uint32_t pooling, linearisation; - fread(&pooling, sizeof(pooling), 1, ptr); fread(&linearisation, sizeof(linearisation), 1, ptr); + fread(&pooling, sizeof(pooling), 1, ptr); kernel->cnn = NULL; kernel->nn = NULL; - kernel->activation = pooling; + kernel->activation = IDENTITY; + kernel->pooling = pooling; kernel->linearisation = linearisation; } return kernel; diff --git a/src/cnn/print.c b/src/cnn/print.c index 42e0058..222fc19 100644 --- a/src/cnn/print.c +++ b/src/cnn/print.c @@ -47,10 +47,14 @@ void print_kernel_cnn(Kernel_cnn* ker, int depth_input, int dim_input, int depth } -void print_pooling(int size) { +void print_pooling(int size, int pooling) { print_bar; purple; - printf("-------Pooling %dx%d-------\n", size ,size); + if (pooling == 1) { + printf("-------Average Pooling %dx%d-------\n", size ,size); + } else { + printf("-------Max Pooling %dx%d-------\n", size ,size); + } reset_color; print_bar; print_dspace; @@ -117,7 +121,7 @@ void print_cnn(Network* network) { print_kernel_nn(k_i->nn, input_width, output_width); } else { // Pooling - print_pooling(input_width - output_width +1); + print_pooling(input_width - output_width +1, k_i->pooling); } } } \ No newline at end of file diff --git a/src/cnn/utils.c b/src/cnn/utils.c index bf79217..3815084 100644 --- a/src/cnn/utils.c +++ b/src/cnn/utils.c @@ -52,6 +52,7 @@ bool equals_networks(Network* network1, Network* network2) { if (!network1->kernel[i]->cnn && !network1->kernel[i]->nn) { // Type Pooling checkEquals(kernel[i]->activation, "kernel[i]->activation pour un pooling", i); + checkEquals(kernel[i]->pooling, "kernel[i]->pooling pour un pooling", i); } else if (!network1->kernel[i]->cnn) { // Type NN checkEquals(kernel[i]->nn->input_units, "kernel[i]->nn->input_units", i); diff --git a/test/cnn_structure.c b/test/cnn_structure.c index dc0bc07..2375dd0 100644 --- a/test/cnn_structure.c +++ b/test/cnn_structure.c @@ -20,7 +20,11 @@ int main() { for (int i=0; i < network->size-1; i++) { kernel = network->kernel[i]; if ((!kernel->cnn)&&(!kernel->nn)) { - printf("\n==== Couche %d de type "YELLOW"Pooling"RESET" ====\n", i); + if (kernel->pooling == 1) { + printf("\n==== Couche %d de type "YELLOW"Average Pooling"RESET" ====\n", i); + } else { + printf("\n==== Couche %d de type "YELLOW"Max Pooling"RESET" ====\n", i); + } } else if (!kernel->cnn) { printf("\n==== Couche %d de type "GREEN"NN"RESET" ====\n", i); printf("input: %d\n", kernel->nn->input_units);