diff --git a/src/parallel/app.py b/src/parallel/app.py new file mode 100644 index 0000000..9755ada --- /dev/null +++ b/src/parallel/app.py @@ -0,0 +1,72 @@ +#!/usr/bin/python3 +""" +Serveur Flask pour entraîner le réseau sur plusieurs machines en parallèle. +""" +import os +import random +from secrets import token_urlsafe + +from flask import Flask, request, send_file + +from clients import Client, clients + +DATASET = "mnist-train" +SECRET = str(random.randint(1000, 10000)) +CACHE = ".cache" + +os.makedirs(CACHE, exist_ok=True) + +app = Flask(__name__) +print(f" * Secret: {SECRET}") + +@app.route("/authenticate", methods = ['POST']) +def authenticate(): + """ + Authentification d'un nouvel utilisateur + """ + if not request.is_json: + return { + "status": + "request format is not json" + } + content = request.get_json() + if content["secret"] != SECRET: + return {"status": "invalid secret"} + + performance = content["performance"] + token = token_urlsafe(30) + + while token in [client.token for client in clients]: + token = token_urlsafe(30) + + clients.append(Client(performance, token)) + + data = {} + data["token"] = token + data["nb_elem"] = performance + data["start"] = 0 + data["dataset"] = DATASET + data["status"] = "ok" + + return data + + +@app.route("/get_network") +def get_network(): + """ + Renvoie le réseau neuronal + """ + if not request.is_json: + return { + "status": + "request format is not json" + } + token = request.get_json()["token"] + if token not in [client.token for client in clients]: + return { + "status": + "token invalide" + } + return send_file( + os.path.join(CACHE, "reseau.bin") + ) diff --git a/src/parallel/client.py b/src/parallel/client.py new file mode 100644 index 0000000..e74adb3 --- /dev/null +++ b/src/parallel/client.py @@ -0,0 +1,86 @@ +#!/usr/bin/python3 +""" +Client se connectant au serveur Flask afin de fournir de la puissance de calcul. +""" +import json +import os +import sys + +import psutil +import requests + +CACHE = ".cache" +DELTA = os.path.join(CACHE, "delta_shared.bin") +RESEAU = os.path.join(CACHE, "reseau_shared.bin") +SECRET = input("SECRET : ") +HOST = input("HOST : ") +os.makedirs(CACHE, exist_ok=True) + + +def get_performance(): + """ + Renvoie un indice de performance du client afin de savoir quelle quantité de données lui fournir + """ + cores = os.cpu_count() + max_freq = psutil.cpu_freq()[2] + return int(cores * max_freq * 0.01) + + +def authenticate(): + """ + S'inscrit en tant que client auprès du serveur + """ + performance = get_performance() + data = { + "performance":performance, + "secret": SECRET + } + req = requests.post( + f"http://{HOST}/authenticate", + json=data + ) + data = json.loads(req.text) + if data["status"] != "ok": + print("authentication error:", data["status"]) + sys.exit(1) + else: + return data + + +def download_network(token): + """ + Récupère le réseau depuis le serveur + """ + data = {"token": token} + with requests.get(f"http://{HOST}/get_network", stream=True, json=data) as req: + req.raise_for_status() + with open(os.path.join(CACHE, RESEAU), "wb") as file: + for chunk in req.iter_content(chunk_size=8192): + file.write(chunk) + + + +def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA): + """ + Entraînement du réseau + """ + raise NotImplementedError + + +def __main__(): + data = authenticate() + + token = data["token"] + dataset = data["dataset"] + + start = data["start"] + nb_elem = data["nb_elem"] + + download_network(token) + + while True: + train_shared(dataset, start, nb_elem, epochs=1, out=DELTA) + + +if __name__ == "__main__": + __main__() diff --git a/src/parallel/clients.py b/src/parallel/clients.py new file mode 100644 index 0000000..0ead573 --- /dev/null +++ b/src/parallel/clients.py @@ -0,0 +1,15 @@ +#!/usr/bin/python3 +""" +Description des clients se connectant au serveur. +""" + +class Client(): + """ + Classe client + """ + def __init__(self, performance, token): + self.performance = performance + self.token = token + + +clients = []