train.c: pick architecture based on dataset type

This commit is contained in:
augustin64 2023-05-15 10:45:14 +02:00
parent 19005366d3
commit 06abf0bc6b

View File

@ -185,8 +185,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
//* Création du réseau
Network* network;
if (!recover) {
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
if (dataset_type == 0) {
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
} else {
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
}
} else {
network = read_network(recover);
network->learning_rate = LEARNING_RATE;