From 511a522a34a0f35978fa1e173e50f28578a1148c Mon Sep 17 00:00:00 2001 From: augustin64 Date: Sat, 21 May 2022 18:06:39 +0200 Subject: [PATCH] Add utils/patch-network --- src/mnist/main.c | 4 +++- src/mnist/utils.c | 45 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/mnist/main.c b/src/mnist/main.c index 06ec50a..3ca60fc 100644 --- a/src/mnist/main.c +++ b/src/mnist/main.c @@ -217,8 +217,10 @@ void train(int epochs, int layers, int neurons, char* recovery, char* image_file if (delta != NULL) write_delta_network(delta, delta_network); } - if (delta != NULL) + write_network(out, network); + if (delta != NULL) { deletion_of_network(delta_network); + } deletion_of_network(network); free(train_parameters); free(tid); diff --git a/src/mnist/utils.c b/src/mnist/utils.c index 3965081..ca68de4 100644 --- a/src/mnist/utils.c +++ b/src/mnist/utils.c @@ -11,7 +11,7 @@ Contient un ensemble de fonctions utiles pour le débogage */ void help(char* call) { - printf("Usage: %s ( print-poids | print-biais | creer-reseau ) [OPTIONS]\n\n", call); + printf("Usage: %s ( print-poids | print-biais | creer-reseau | patch-network ) [OPTIONS]\n\n", call); printf("OPTIONS:\n"); printf("\tprint-poids:\n"); printf("\t\t--reseau | -r [FILENAME]\tFichier contenant le réseau de neurones.\n"); @@ -21,7 +21,10 @@ void help(char* call) { printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); printf("\tcreer-reseau:\n"); printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n"); - printf("\t\t--number | -n [int]\tNuméro à privilégier\n"); + printf("\t\t--number | -n [int]\tNuméro à privilégier.\n"); + printf("\tpatch-network:\n"); + printf("\t\t--network | -n [FILENAME]\tFichier contenant le réseau de neurones.\n"); + printf("\t\t--delta | -d [FILENAME]\tFichier de patch à utiliser.\n"); } @@ -115,6 +118,18 @@ void create_network(char* filename, int sortie) { } +void patch_stored_network(char* network_filename, char* delta_filename) { + // Apply patch to a network stored in a file + Network* network = read_network(network_filename); + Network* delta = read_delta_network(delta_filename); + + patch_network(network, delta, 1); + + write_network(network_filename, network); + deletion_of_network(network); + deletion_of_network(delta); +} + int main(int argc, char* argv[]) { if (argc < 2) { @@ -194,6 +209,32 @@ int main(int argc, char* argv[]) { } count_labels(labels); exit(0); + } else if (! strcmp(argv[1], "patch-network")) { + char* network = NULL; + char* delta = NULL; + int i = 2; + while (i < argc) { + if ((! strcmp(argv[i], "--network"))||(! strcmp(argv[i], "-n"))) { + network = argv[i+1]; + i += 2; + } else if ((! strcmp(argv[i], "--delta"))||(! strcmp(argv[i], "-d"))) { + delta = argv[i+1]; + i += 2; + } else { + printf("%s : Argument non reconnu\n", argv[i]); + i++; + } + } + if (!network) { + printf("--network: Argument obligatoire.\n"); + exit(1); + } + if (!delta) { + printf("--delta: Argument obligatoire.\n"); + exit(1); + } + patch_stored_network(network, delta); + exit(0); } printf("Option choisie non reconnue: %s\n", argv[1]); help(argv[0]);