Add utils/patch-network

This commit is contained in:
augustin64 2022-05-21 18:06:39 +02:00
parent 986707af2b
commit 511a522a34
2 changed files with 46 additions and 3 deletions

View File

@ -217,8 +217,10 @@ void train(int epochs, int layers, int neurons, char* recovery, char* image_file
if (delta != NULL) if (delta != NULL)
write_delta_network(delta, delta_network); write_delta_network(delta, delta_network);
} }
if (delta != NULL) write_network(out, network);
if (delta != NULL) {
deletion_of_network(delta_network); deletion_of_network(delta_network);
}
deletion_of_network(network); deletion_of_network(network);
free(train_parameters); free(train_parameters);
free(tid); free(tid);

View File

@ -11,7 +11,7 @@
Contient un ensemble de fonctions utiles pour le débogage Contient un ensemble de fonctions utiles pour le débogage
*/ */
void help(char* call) { void help(char* call) {
printf("Usage: %s ( print-poids | print-biais | creer-reseau ) [OPTIONS]\n\n", call); printf("Usage: %s ( print-poids | print-biais | creer-reseau | patch-network ) [OPTIONS]\n\n", call);
printf("OPTIONS:\n"); printf("OPTIONS:\n");
printf("\tprint-poids:\n"); printf("\tprint-poids:\n");
printf("\t\t--reseau | -r [FILENAME]\tFichier contenant le réseau de neurones.\n"); printf("\t\t--reseau | -r [FILENAME]\tFichier contenant le réseau de neurones.\n");
@ -21,7 +21,10 @@ void help(char* call) {
printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n");
printf("\tcreer-reseau:\n"); printf("\tcreer-reseau:\n");
printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n"); printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n");
printf("\t\t--number | -n [int]\tNuméro à privilégier\n"); printf("\t\t--number | -n [int]\tNuméro à privilégier.\n");
printf("\tpatch-network:\n");
printf("\t\t--network | -n [FILENAME]\tFichier contenant le réseau de neurones.\n");
printf("\t\t--delta | -d [FILENAME]\tFichier de patch à utiliser.\n");
} }
@ -115,6 +118,18 @@ void create_network(char* filename, int sortie) {
} }
void patch_stored_network(char* network_filename, char* delta_filename) {
// Apply patch to a network stored in a file
Network* network = read_network(network_filename);
Network* delta = read_delta_network(delta_filename);
patch_network(network, delta, 1);
write_network(network_filename, network);
deletion_of_network(network);
deletion_of_network(delta);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc < 2) { if (argc < 2) {
@ -194,6 +209,32 @@ int main(int argc, char* argv[]) {
} }
count_labels(labels); count_labels(labels);
exit(0); exit(0);
} else if (! strcmp(argv[1], "patch-network")) {
char* network = NULL;
char* delta = NULL;
int i = 2;
while (i < argc) {
if ((! strcmp(argv[i], "--network"))||(! strcmp(argv[i], "-n"))) {
network = argv[i+1];
i += 2;
} else if ((! strcmp(argv[i], "--delta"))||(! strcmp(argv[i], "-d"))) {
delta = argv[i+1];
i += 2;
} else {
printf("%s : Argument non reconnu\n", argv[i]);
i++;
}
}
if (!network) {
printf("--network: Argument obligatoire.\n");
exit(1);
}
if (!delta) {
printf("--delta: Argument obligatoire.\n");
exit(1);
}
patch_stored_network(network, delta);
exit(0);
} }
printf("Option choisie non reconnue: %s\n", argv[1]); printf("Option choisie non reconnue: %s\n", argv[1]);
help(argv[0]); help(argv[0]);