From 06abf0bc6b22b8e3a535d30310193e4fc98bc9a0 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Mon, 15 May 2023 10:45:14 +0200 Subject: [PATCH] train.c: pick architecture based on dataset type --- src/cnn/train.c | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cnn/train.c b/src/cnn/train.c index c51ed7a..be06527 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -185,8 +185,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di //* Création du réseau Network* network; if (!recover) { - network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth); - //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth); + if (dataset_type == 0) { + network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth); + //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth); + } else { + network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories); + } } else { network = read_network(recover); network->learning_rate = LEARNING_RATE;