tipe/src/mnist/mnist.c

151 lines
3.6 KiB
C
Raw Normal View History

2022-03-21 17:06:05 +01:00
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <inttypes.h>
/*
 * Reverse the byte order of a 32-bit word (0xAABBCCDD -> 0xDDCCBBAA).
 * The readers below apply it to header words read raw from the file before
 * comparing them with the MNIST magic numbers; that use presumes a
 * little-endian host (NOTE(review): confirm target platforms).
 */
uint32_t swap_endian(uint32_t val) {
    uint32_t byte0 = val & 0xFFu;
    uint32_t byte1 = (val >> 8) & 0xFFu;
    uint32_t byte2 = (val >> 16) & 0xFFu;
    uint32_t byte3 = (val >> 24) & 0xFFu;
    return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}
/*
 * Read one image of width*height unsigned bytes (row-major) from `ptr`,
 * starting at the current stream position, and return it as a freshly
 * allocated array of `height` rows of `width` ints (each 0..255).
 *
 * Ownership: the caller frees each image[i] and then image itself.
 * Exits with status 1 if memory cannot be allocated or if the stream ends
 * before width*height bytes were read (the pixels would be garbage).
 */
int** read_image(unsigned int width, unsigned int height, FILE* ptr) {
    size_t n_pixels = (size_t)width * height;
    if (width != 0 && n_pixels / width != height) {
        fprintf(stderr, "read_image: image dimensions too large\n");
        exit(1);
    }

    /* Heap buffer rather than a VLA: the dimensions come from the file
     * header and must not be trusted to fit on the stack. */
    unsigned char* buffer = malloc(n_pixels > 0 ? n_pixels : 1);
    int** image = malloc(sizeof *image * (height > 0 ? height : 1));
    if (buffer == NULL || image == NULL) {
        fprintf(stderr, "read_image: out of memory\n");
        exit(1);
    }

    /* A short read would leave part of the buffer uninitialized. */
    if (fread(buffer, 1, n_pixels, ptr) != n_pixels) {
        fprintf(stderr, "read_image: unexpected end of file\n");
        exit(1);
    }

    for (size_t i = 0; i < height; i++) {
        int* line = malloc(sizeof *line * (width > 0 ? width : 1));
        if (line == NULL) {
            fprintf(stderr, "read_image: out of memory\n");
            exit(1);
        }
        for (size_t j = 0; j < width; j++) {
            line[j] = (int)buffer[i * width + j];
        }
        image[i] = line;
    }

    free(buffer);
    return image;
}
2022-06-03 15:47:02 +02:00
2022-03-21 17:06:05 +01:00
/*
 * Read only the header of an MNIST image file (magic number 2051) and
 * return a freshly allocated array of three ints:
 *   tab[0] = number of images, tab[1] = height, tab[2] = width,
 * in the order the header stores them.
 *
 * Ownership: the caller must free() the returned array.
 * Exits with status 1 if the file cannot be opened, the header is
 * truncated, the magic number is wrong, or memory runs out.
 */
int* read_mnist_images_parameters(char* filename) {
    FILE* ptr = fopen(filename, "rb");
    if (ptr == NULL) {
        perror(filename);
        exit(1);
    }

    uint32_t magic_number;
    if (fread(&magic_number, sizeof magic_number, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    if (swap_endian(magic_number) != 2051) {
        printf("Incorrect magic number !\n");
        exit(1);
    }

    /* Remaining header words: number of images, height, width. */
    uint32_t dims[3];
    if (fread(dims, sizeof(uint32_t), 3, ptr) != 3) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    fclose(ptr);

    int* tab = malloc(sizeof *tab * 3);
    if (tab == NULL) {
        fprintf(stderr, "read_mnist_images_parameters: out of memory\n");
        exit(1);
    }
    tab[0] = (int)swap_endian(dims[0]);
    tab[1] = (int)swap_endian(dims[1]);
    tab[2] = (int)swap_endian(dims[2]);
    return tab;
}
2022-06-03 15:47:02 +02:00
2022-04-22 15:03:21 +02:00
/*
 * Return the number of labels declared in the header of an MNIST label
 * file (magic number 2049). Only the 8-byte header is read.
 *
 * Exits with status 1 if the file cannot be opened, the header is
 * truncated, or the magic number is wrong.
 */
uint32_t read_mnist_labels_nb_images(char* filename) {
    FILE* ptr = fopen(filename, "rb");
    if (ptr == NULL) {
        perror(filename);
        exit(1);
    }

    uint32_t magic_number;
    if (fread(&magic_number, sizeof magic_number, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    if (swap_endian(magic_number) != 2049) {
        printf("Incorrect magic number !\n");
        exit(1);
    }

    uint32_t number_of_images;
    if (fread(&number_of_images, sizeof number_of_images, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    fclose(ptr);
    return swap_endian(number_of_images);
}
2022-03-21 17:06:05 +01:00
2022-06-03 15:47:02 +02:00
2022-03-21 17:06:05 +01:00
/*
 * Load every image of an MNIST image file (magic number 2051).
 * Returns tab where tab[k] is image k as built by read_image:
 * tab[k][i][j] is the pixel at row i, column j, an int in 0..255.
 * The image count and dimensions are not returned; use
 * read_mnist_images_parameters to obtain them.
 *
 * Ownership: the caller frees tab[k][i], then tab[k], then tab.
 * Exits with status 1 if the file cannot be opened, is truncated, has the
 * wrong magic number, or memory runs out.
 */
int*** read_mnist_images(char* filename) {
    FILE* ptr = fopen(filename, "rb");
    if (ptr == NULL) {
        perror(filename);
        exit(1);
    }

    uint32_t magic_number;
    if (fread(&magic_number, sizeof magic_number, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    if (swap_endian(magic_number) != 2051) {
        printf("Incorrect magic number !\n");
        exit(1);
    }

    /* Remaining header words: number of images, height, width. */
    uint32_t dims[3];
    if (fread(dims, sizeof(uint32_t), 3, ptr) != 3) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    uint32_t number_of_images = swap_endian(dims[0]);
    unsigned int height = swap_endian(dims[1]);
    unsigned int width = swap_endian(dims[2]);

    int*** tab = malloc(sizeof *tab * (number_of_images > 0 ? number_of_images : 1));
    if (tab == NULL) {
        fprintf(stderr, "read_mnist_images: out of memory\n");
        exit(1);
    }

    /* Images are stored back to back right after the header. */
    for (uint32_t i = 0; i < number_of_images; i++) {
        tab[i] = read_image(width, height, ptr);
    }
    fclose(ptr);
    return tab;
}
2022-06-03 15:47:02 +02:00
/*
 * Return the labels stored in an MNIST-format label file (magic number
 * 2049), one unsigned int per item, in file order.
 * The number of labels is not returned; read_mnist_labels_nb_images
 * provides it.
 *
 * Ownership: the caller must free() the returned array.
 * Exits with status 1 if the file cannot be opened, is truncated, has the
 * wrong magic number, or memory runs out.
 */
unsigned int* read_mnist_labels(char* filename) {
    FILE* ptr = fopen(filename, "rb");
    if (ptr == NULL) {
        perror(filename);
        exit(1);
    }

    uint32_t magic_number;
    if (fread(&magic_number, sizeof magic_number, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    if (swap_endian(magic_number) != 2049) {
        printf("Incorrect magic number !\n");
        exit(1);
    }

    uint32_t number_of_items;
    if (fread(&number_of_items, sizeof number_of_items, 1, ptr) != 1) {
        fprintf(stderr, "%s: truncated MNIST header\n", filename);
        exit(1);
    }
    number_of_items = swap_endian(number_of_items);

    /* Heap buffer rather than a VLA: the count comes from the file and
     * could be far larger than the stack allows. */
    size_t alloc_items = number_of_items > 0 ? number_of_items : 1;
    unsigned char* buffer = malloc(alloc_items);
    unsigned int* labels = malloc(sizeof *labels * alloc_items);
    if (buffer == NULL || labels == NULL) {
        fprintf(stderr, "read_mnist_labels: out of memory\n");
        exit(1);
    }

    /* A short read would leave part of the buffer uninitialized. */
    if (fread(buffer, 1, number_of_items, ptr) != number_of_items) {
        fprintf(stderr, "%s: unexpected end of file\n", filename);
        exit(1);
    }
    fclose(ptr);

    for (uint32_t i = 0; i < number_of_items; i++) {
        labels[i] = (unsigned int)buffer[i];
    }
    free(buffer);
    return labels;
}