diff --git a/src/preview_mnist.c b/src/preview_mnist.c index 9ddbed6..d677eae 100644 --- a/src/preview_mnist.c +++ b/src/preview_mnist.c @@ -10,6 +10,10 @@ uint32_t swap_endian(uint32_t val) { } +// Prévisualise un chiffre écrit à la main +// de taille width x height +// commencant à l'adresse mémoire start +// dans le fichier pointé par ptr void print_image(unsigned int width, unsigned int height, FILE* ptr, int start) { unsigned char buffer[width*height+start]; @@ -27,9 +31,10 @@ void print_image(unsigned int width, unsigned int height, FILE* ptr, int start) } } - -void read_mnist_images(char* filename) { - unsigned char buffer[4]; +// Lit un set de données images de la base de données MNIST +// dans le fichier situé à filename, les +// images comportant comme labels 'labels' +void read_mnist_images(char* filename, unsigned char* labels) { FILE *ptr; ptr = fopen(filename, "rb"); @@ -42,7 +47,7 @@ void read_mnist_images(char* filename) { fread(&magic_number, sizeof(uint32_t), 1, ptr); magic_number = swap_endian(magic_number); - if (magic_number != 2051){ + if (magic_number != 2051) { printf("Incorrect magic number !\n"); exit(1); } @@ -61,17 +66,49 @@ void read_mnist_images(char* filename) { //printf("%u x %u\n\n", width, height); for (int i=0; i < number_of_images; i++) { - printf("--- Number %d ---\n", i); + printf("--- Number %d : %u ---\n", i, labels[i]); print_image(width, height, ptr, (i*width*height)); } } +unsigned char* read_mnist_labels(char* filename) { + FILE* ptr; + + ptr = fopen(filename, "rb"); + + uint32_t magic_number; + uint32_t number_of_items; + + fread(&magic_number, sizeof(uint32_t), 1, ptr); + magic_number = swap_endian(magic_number); + + if (magic_number != 2049) { + printf("Incorrect magic number !\n"); + exit(1); + } + + fread(&number_of_items, sizeof(uint32_t), 1, ptr); + number_of_items = swap_endian(number_of_items); + + printf("number of items: %" PRIu32 "\n", number_of_items); + + unsigned char* labels = malloc(sizeof(unsigned char)*number_of_items); + unsigned char tmp; + for (int i=0; i< number_of_items; i++) { + fread(&tmp, sizeof(unsigned char), 1, ptr); + labels[i] = tmp; + } + return labels; +} + int main(int argc, char *argv[]) { - if (argc == 1) { - printf("Utilisation: %s [FILE]\n", argv[0]); + if (argc < 3) { + printf("Utilisation: %s [IMAGES FILE] [LABELS FILE]\n", argv[0]); return 1; } - read_mnist_images(argv[1]); + unsigned char* labels = read_mnist_labels(argv[2]); + read_mnist_images(argv[1], labels); + free(labels); return 0; }