diff --git a/src/parallel/app.py b/src/parallel/app.py index 9755ada..0d347d1 100644 --- a/src/parallel/app.py +++ b/src/parallel/app.py @@ -3,70 +3,149 @@ Serveur Flask pour entraîner le réseau sur plusieurs machines en parallèle. """ import os +import time import random +import subprocess from secrets import token_urlsafe +from threading import Thread -from flask import Flask, request, send_file +from flask import Flask, request, send_from_directory, session -from clients import Client, clients +from structures import clients, Client, NoMoreJobAvailableError, TryLaterError, Training +# Définitions de variables DATASET = "mnist-train" +TEST_SET = "mnist-t10k" SECRET = str(random.randint(1000, 10000)) -CACHE = ".cache" +CACHE = ".cache" # À remplacer avec un chemin absolu +BATCHS = 10 +RESEAU = os.path.join(CACHE, "reseau.bin") + +training = Training(BATCHS, DATASET, TEST_SET, CACHE) os.makedirs(CACHE, exist_ok=True) +# On crée un réseau aléatoire si il n'existe pas encore +if not os.path.isfile(RESEAU): + if not os.path.isfile("out/main"): + subprocess.call(["make.sh", "main"]) + subprocess.call + ([ + "out/main", "train", + "--epochs", "0", + "--images", "data/mnist/train-images-idx3-ubyte", + "--labels", "data/mnist/train-labels-idx1-ubyte", + "--out", RESEAU, + ]) + print(f" * Created {RESEAU}") +else: + print(f" * {RESEAU} already exists") + app = Flask(__name__) +# On définit une clé secrète pour pouvoir utiliser des cookies de session +app.config["SECRET_KEY"] = token_urlsafe(40) print(f" * Secret: {SECRET}") -@app.route("/authenticate", methods = ['POST']) + +@app.route("/authenticate", methods=["POST"]) def authenticate(): """ Authentification d'un nouvel utilisateur """ if not request.is_json: - return { - "status": - "request format is not json" - } + return {"status": "request format is not json"} content = request.get_json() if content["secret"] != SECRET: return {"status": "invalid secret"} - performance = content["performance"] token = token_urlsafe(30) - - while token in [client.token for client in clients]: + while token in clients.keys(): token = token_urlsafe(30) - clients.append(Client(performance, token)) + clients[token] = Client(content["performance"], token) + # On prépare la réponse du serveur data = {} - data["token"] = token - data["nb_elem"] = performance - data["start"] = 0 - data["dataset"] = DATASET data["status"] = "ok" + data["dataset"] = training.dataset + session["token"] = token + + try: + clients[token].get_job(training) + data["nb_elem"] = clients[token].performance + data["start"] = clients[token].start + data["instruction"] = "train" + + except NoMoreJobAvailableError: + data["status"] = "Training already ended" + data["nb_elem"] = 0 + data["start"] = 0 + data["instruction"] = "stop" + + except TryLaterError: + data["status"] = "Wait for next batch" + data["nb_elem"] = 0 + data["start"] = 0 + data["instruction"] = "sleep" + data["sleep_time"] = 1 return data -@app.route("/get_network") +@app.route("/post_network", methods=["POST"]) +def post_network(): + """ + Applique le patch renvoyé dans le nouveau réseau + """ + token = session.get("token") + if not token in clients.keys(): + return {"status": "token invalide"} + + while training.is_patch_locked(): + time.sleep(0.1) + + request.files["file"].save(training.delta) + training.patch() + + # Préparation de la réponse + data = {} + data["status"] = "ok" + data["dataset"] = training.dataset + + try: + clients[token].get_job(training) + data["dataset"] = training.dataset + data["nb_elem"] = clients[token].performance + data["start"] = clients[token].start + data["instruction"] = "train" + + except NoMoreJobAvailableError: + data["status"] = "Training already ended" + data["nb_elem"] = 0 + data["start"] = 0 + data["instruction"] = "stop" + + except TryLaterError: + Thread(target=training.test_network()).start() + data["status"] = "Wait for next batch" + data["nb_elem"] = 0 + data["start"] = 0 + data["instruction"] = "sleep" + data["sleep_time"] = 1 + + return data + + +@app.route("/get_network", methods=["GET", "POST"]) def get_network(): """ Renvoie le réseau neuronal """ - if not request.is_json: - return { - "status": - "request format is not json" - } - token = request.get_json()["token"] - if token not in [client.token for client in clients]: - return { - "status": - "token invalide" - } - return send_file( - os.path.join(CACHE, "reseau.bin") - ) + token = session.get("token") + if not token in clients.keys(): + return {"status": "token invalide"} + + if token not in clients.keys(): + return {"status": "token invalide"} + + return send_from_directory(directory=CACHE, path="reseau.bin") diff --git a/src/parallel/client.py b/src/parallel/client.py index e74adb3..e7ed22b 100644 --- a/src/parallel/client.py +++ b/src/parallel/client.py @@ -5,15 +5,20 @@ Client se connectant au serveur Flask afin de fournir de la puissance de calcul. import json import os import sys +import time +import subprocess import psutil import requests -CACHE = ".cache" +# Définition de constantes +CACHE = ".cache" # Replace with an absolute path DELTA = os.path.join(CACHE, "delta_shared.bin") RESEAU = os.path.join(CACHE, "reseau_shared.bin") SECRET = input("SECRET : ") HOST = input("HOST : ") + +session = requests.Session() os.makedirs(CACHE, exist_ok=True) @@ -23,7 +28,7 @@ def get_performance(): """ cores = os.cpu_count() max_freq = psutil.cpu_freq()[2] - return int(cores * max_freq * 0.01) + return int(cores * max_freq * 0.5) def authenticate(): @@ -31,55 +36,107 @@ def authenticate(): S'inscrit en tant que client auprès du serveur """ performance = get_performance() - data = { - "performance":performance, - "secret": SECRET - } - req = requests.post( - f"http://{HOST}/authenticate", - json=data - ) + data = {"performance": performance, "secret": SECRET} + # Les données d'identification seront ensuite stockées dans un cookie de l'objet session + req = session.post(f"http://{HOST}/authenticate", json=data) + data = json.loads(req.text) if data["status"] != "ok": - print("authentication error:", data["status"]) + print("error in authenticate():", data["status"]) sys.exit(1) else: return data -def download_network(token): +def download_network(): """ Récupère le réseau depuis le serveur """ - data = {"token": token} - with requests.get(f"http://{HOST}/get_network", stream=True, json=data) as req: + with session.get(f"http://{HOST}/get_network", stream=True) as req: req.raise_for_status() - with open(os.path.join(CACHE, RESEAU), "wb") as file: + with open(RESEAU, "wb") as file: for chunk in req.iter_content(chunk_size=8192): file.write(chunk) +def send_delta_network(continue_=False): + """ + Envoie le réseau différentiel et obéit aux instructions suivantes + """ + with open(DELTA, "rb") as file: + files = {"file": file} + req = session.post(f"http://{HOST}/post_network", files=files) + req_data = json.loads(req.text) + + # Actions à effectuer en fonction de la réponse + if "instruction" not in req_data.keys(): + print(req_data["status"]) + raise NotImplementedError + + if req_data["instruction"] == "sleep": + print(f"Sleeping {req_data['sleep_time']}s.") + time.sleep(req_data["sleep_time"]) + send_delta_network(continue_=continue_) + + elif req_data["instruction"] == "stop": + print(req_data["status"]) + print("Shutting down.") + + elif req_data["instruction"] == "train": + download_network() + train_shared(req_data["dataset"], req_data["start"], req_data["nb_elem"]) + + else: + json.dumps(req_data) + raise NotImplementedError + def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA): """ Entraînement du réseau """ - raise NotImplementedError + # Utiliser un dictionnaire serait plus efficace et plus propre + if dataset == "mnist-train": + images = "data/mnist/train-images-idx3-ubyte" + labels = "data/mnist/train-labels-idx1-ubyte" + elif dataset == "mnist-t10k": + images = "data/mnist/t10k-images-idx3-ubyte" + labels = "data/mnist/t10k-labels-idx1-ubyte" + else: + print(f"Dataset {dataset} not implemented yet") + raise NotImplementedError + + # On compile out/main si il n'existe pas encore + if not os.path.isfile("out/main"): + subprocess.call(["make.sh", "main"]) + + # Entraînement du réseau + subprocess.call( + [ + "out/main", "train", + "--epochs", str(epochs), + "--images", images, + "--labels", labels, + "--recover", RESEAU, + "--delta", out, + "--nb-images", str(nb_elem), + "--start", str(start), + ], + stdout=subprocess.DEVNULL, + ) + return send_delta_network(continue_=True) def __main__(): data = authenticate() - token = data["token"] dataset = data["dataset"] - start = data["start"] nb_elem = data["nb_elem"] - download_network(token) - - while True: - train_shared(dataset, start, nb_elem, epochs=1, out=DELTA) + download_network() + # train_shared s'appelle récursivement sur lui même jusqu'à la fin du programme + train_shared(dataset, start, nb_elem, epochs=1, out=DELTA) if __name__ == "__main__": diff --git a/src/parallel/clients.py b/src/parallel/clients.py deleted file mode 100644 index 0ead573..0000000 --- a/src/parallel/clients.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/python3 -""" -Description des clients se connectant au serveur. -""" - -class Client(): - """ - Classe client - """ - def __init__(self, performance, token): - self.performance = performance - self.token = token - - -clients = [] diff --git a/src/parallel/structures.py b/src/parallel/structures.py new file mode 100644 index 0000000..a19083f --- /dev/null +++ b/src/parallel/structures.py @@ -0,0 +1,133 @@ +#!/usr/bin/python3 +""" +Description des structures. +""" +import os +import time +import subprocess + + +class NoMoreJobAvailableError(Exception): + """Entraînement du réseau fini""" + pass + + +class TryLaterError(Exception): + """Batch fini, réessayer plus tard""" + pass + + +class Client: + """ + Description d'un client se connectant au serveur + """ + def __init__(self, performance, token): + self.performance = performance + self.token = token + self.start = 0 + self.nb_images = 0 + + + def get_job(self, training): + """ + Donne un travail au client + """ + if training.nb_images == training.cur_image: + if training.batchs == training.cur_batch: + raise NoMoreJobAvailableError + raise TryLaterError + + self.start = training.cur_image + self.nb_images = min(training.nb_images - training.cur_image, self.performance) + training.cur_image += self.nb_images + + +clients = {} + + +class Training: + """ + Classe training + """ + def __init__(self, batchs, dataset, test_set, cache): + # Définition de variables + self.batchs = batchs + self.cur_batch = 1 + self.cur_image = 0 + self.dataset = dataset + self.test_set = test_set + self.cache = cache + self.reseau = os.path.join(self.cache, "reseau.bin") + self.delta = os.path.join(self.cache, "delta.bin") + + # Définition des chemins et données relatives à chaque set de données + # TODO: implémenter plus proprement avec un dictionnaire ou même un fichier datasets.json + if self.dataset == "mnist-train": + self.nb_images = 60000 + elif self.dataset == "mnist-t10k": + self.nb_images = 10000 + else: + raise NotImplementedError + + if self.test_set == "mnist-train": + self.test_images = "data/mnist/train-images-idx3-ubyte" + self.test_labels = "data/mnist/train-labels-idx1-ubyte" + elif self.test_set == "mnist-t10k": + self.test_images = "data/mnist/t10k-images-idx3-ubyte" + self.test_labels = "data/mnist/t10k-labels-idx1-ubyte" + else: + print(f"{self.test_set} test dataset unknown.") + raise NotImplementedError + + # On supprime le fichier de lock qui permet de + # ne pas écrire en même temps plusieurs fois sur le fichier reseau.bin + if os.path.isfile(self.reseau + ".lock"): + os.remove(self.reseau + ".lock") + + + def test_network(self): + """ + Teste les performances du réseau avant le batch suivant + """ + if not os.path.isfile("out/main"): + subprocess.call(["make.sh", "main"]) + + subprocess.call( + [ + "out/main", "test", + "--images", self.test_images, + "--labels", self.test_labels, + "--modele", self.reseau + ]) + self.cur_batch += 1 + self.cur_image = 0 + + + def patch(self): + """ + Applique un patch au réseau + """ + # On attend que le lock se libère puis on patch le réseau + while self.is_patch_locked(): + time.sleep(0.1) + + with open(self.reseau + ".lock", "w", encoding="utf8") as file: + file.write("") + + if not os.path.isfile("out/main"): + subprocess.call(["make.sh", "utils"]) + subprocess.call + ([ + "out/utils", "patch-network", + "--network", self.reseau, + "--delta", self.delta, + ]) + + os.remove(self.reseau + ".lock") + + + def is_patch_locked(self): + """ + Petit raccourci pour vérifier si le lock est présent + """ + return os.path.isfile(self.reseau + ".lock")