diff --git a/src/cnn/train.c b/src/cnn/train.c index c51ed7a..be06527 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -185,8 +185,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di //* Création du réseau Network* network; if (!recover) { - network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth); - //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth); + if (dataset_type == 0) { + network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth); + //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth); + } else { + network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories); + } } else { network = read_network(recover); network->learning_rate = LEARNING_RATE;