Add parallel

This commit is contained in:
augustin64 2022-05-18 21:46:05 +02:00
parent 0d2f48c192
commit db06dd4f73
3 changed files with 173 additions and 0 deletions

72
src/parallel/app.py Normal file
View File

@ -0,0 +1,72 @@
#!/usr/bin/python3
"""
Serveur Flask pour entraîner le réseau sur plusieurs machines en parallèle.
"""
import os
import random
from secrets import token_urlsafe
from flask import Flask, request, send_file
from clients import Client, clients
# Name of the dataset that authenticated clients are told to train on.
DATASET = "mnist-train"
# Shared secret that clients must present to /authenticate. It is drawn from
# the operating system's CSPRNG (random.SystemRandom) instead of the default
# Mersenne Twister generator, whose output is predictable and therefore not
# suitable for anything used as a credential.
SECRET = str(random.SystemRandom().randint(1000, 10000))
# Directory holding the serialized network file that is served to clients.
CACHE = ".cache"
os.makedirs(CACHE, exist_ok=True)
app = Flask(__name__)
# Printed once at startup so the operator can pass the secret on to clients.
print(f" * Secret: {SECRET}")
@app.route("/authenticate", methods=["POST"])
def authenticate():
    """
    Register a new client and hand it its work assignment.

    Expects a JSON object with two keys: ``secret``, which must equal the
    server's SECRET, and ``performance``, a non-negative integer that is
    stored on the client record and echoed back as ``nb_elem``.

    Returns a JSON object. On success ``status`` is ``"ok"`` and the object
    also carries ``token``, ``nb_elem``, ``start`` and ``dataset``. On failure
    only ``status`` is present and describes the problem.
    """
    # Local import: constant-time comparison, so the response time does not
    # reveal how much of the submitted secret matched.
    from secrets import compare_digest

    if not request.is_json:
        return {"status": "request format is not json"}

    content = request.get_json()
    if not isinstance(content, dict):
        # Valid JSON, but not an object: there are no keys to read.
        return {"status": "request body is not a json object"}

    # A missing or non-string secret is treated like a wrong one instead of
    # raising KeyError/TypeError (which would surface as an HTTP 500).
    secret = content.get("secret")
    if not isinstance(secret, str) or not compare_digest(
        secret.encode("utf-8"), SECRET.encode("utf-8")
    ):
        return {"status": "invalid secret"}

    performance = content.get("performance")
    # bool is a subclass of int, so it has to be excluded explicitly.
    if (
        isinstance(performance, bool)
        or not isinstance(performance, int)
        or performance < 0
    ):
        return {"status": "invalid performance"}

    # Draw tokens until one is not already assigned to a registered client.
    known_tokens = {client.token for client in clients}
    token = token_urlsafe(30)
    while token in known_tokens:
        token = token_urlsafe(30)

    clients.append(Client(performance, token))

    return {
        "token": token,
        "nb_elem": performance,
        "start": 0,
        "dataset": DATASET,
        "status": "ok",
    }
@app.route("/get_network")
def get_network():
    """
    Send the serialized neural network to an authenticated client.

    Expects a JSON object containing ``token``, which must match the token of
    a registered client. On success the file ``reseau.bin`` from the CACHE
    directory is returned; otherwise a JSON object whose ``status`` describes
    the problem is returned.
    """
    if not request.is_json:
        return {"status": "request format is not json"}

    content = request.get_json()
    # A missing key or a non-object body is handled as an invalid token
    # instead of raising KeyError/TypeError (which would be an HTTP 500).
    token = content.get("token") if isinstance(content, dict) else None

    if not any(client.token == token for client in clients):
        return {"status": "token invalide"}

    return send_file(os.path.join(CACHE, "reseau.bin"))

86
src/parallel/client.py Normal file
View File

@ -0,0 +1,86 @@
#!/usr/bin/python3
"""
Client se connectant au serveur Flask afin de fournir de la puissance de calcul.
"""
import json
import os
import sys
import psutil
import requests
# Directory in which the client keeps its working files.
CACHE = ".cache"
# Default output path handed to train_shared; located inside CACHE.
DELTA = os.path.join(CACHE, "delta_shared.bin")
# Path of the local network file; note that it already includes CACHE.
RESEAU = os.path.join(CACHE, "reseau_shared.bin")
# Both values below are prompted for interactively when the module is loaded.
SECRET = input("SECRET : ")  # sent to the server's /authenticate endpoint
HOST = input("HOST : ")  # interpolated into the http://{HOST}/... URLs
# Make sure the working directory exists before anything is written to it.
os.makedirs(CACHE, exist_ok=True)
def get_performance(fallback_freq_mhz=1000.0):
    """
    Return a performance index for this machine.

    The index is ``int(cores * frequency * 0.01)`` and is sent to the server
    during authentication so it can decide how much data to give this client.

    Args:
        fallback_freq_mhz: Frequency used when psutil cannot report any CPU
            frequency at all (assumed to be in MHz, the unit psutil reports).

    Returns:
        The performance index as an int.
    """
    # os.cpu_count() may return None when the count cannot be determined.
    cores = os.cpu_count() or 1

    # psutil.cpu_freq() may return None on platforms where the frequency is
    # unavailable, and reports max as 0.0 when it cannot be determined.
    freq = psutil.cpu_freq()
    frequency = 0.0
    if freq is not None:
        frequency = freq.max or freq.current
    if not frequency:
        frequency = fallback_freq_mhz

    return int(cores * frequency * 0.01)
def authenticate(timeout=10):
    """
    Register this machine as a client with the server.

    Posts the local performance index and the shared secret to the server's
    ``/authenticate`` endpoint.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests would wait indefinitely on an unresponsive host.

    Returns:
        The decoded JSON response (a dict) when its ``status`` is ``"ok"``.

    Raises:
        requests.RequestException: On connection failure, timeout, or an HTTP
            error status.
        SystemExit: With code 1 when the server reports a non-"ok" status.
    """
    payload = {
        "performance": get_performance(),
        "secret": SECRET,
    }
    req = requests.post(
        f"http://{HOST}/authenticate",
        json=payload,
        timeout=timeout,
    )
    # Fail with a clear HTTP error rather than a JSON decoding error when the
    # server answers with an error page.
    req.raise_for_status()
    data = req.json()

    status = data.get("status")
    if status != "ok":
        print("authentication error:", status)
        sys.exit(1)
    return data
def download_network(token, timeout=10):
    """
    Download the network from the server and store it at RESEAU.

    Args:
        token: Token obtained from the server during authentication.
        timeout: Seconds to wait on the connection and on each read.

    Raises:
        requests.RequestException: On connection failure, timeout, or an HTTP
            error status.
        SystemExit: With code 1 when the server answers with a JSON error
            message instead of the network file.
    """
    payload = {"token": token}
    with requests.get(
        f"http://{HOST}/get_network",
        stream=True,
        json=payload,
        timeout=timeout,
    ) as req:
        req.raise_for_status()

        # The server reports failures (e.g. an invalid token) as a JSON body
        # with a 200 status; that must not be saved as the network file.
        content_type = req.headers.get("Content-Type", "")
        if content_type.startswith("application/json"):
            print("download error:", req.json().get("status"))
            sys.exit(1)

        # RESEAU already contains the CACHE directory; joining it with CACHE
        # again would point into a non-existent nested directory.
        with open(RESEAU, "wb") as file:
            for chunk in req.iter_content(chunk_size=8192):
                file.write(chunk)
def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA):
    """
    Train the network (placeholder, not implemented yet).

    The caller passes the ``dataset``, ``start`` and ``nb_elem`` values it
    received from the server at authentication; ``out`` defaults to DELTA.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def __main__():
    """
    Client entry point.

    Authenticates with the server, downloads the network, then calls
    train_shared in an endless loop with the assignment received at
    authentication.
    """
    session = authenticate()
    # Read every field before downloading, in the same order as before, so a
    # missing key is reported before any network transfer starts.
    token, dataset, start, nb_elem = [
        session[key] for key in ("token", "dataset", "start", "nb_elem")
    ]

    download_network(token)

    while True:
        train_shared(dataset, start, nb_elem, epochs=1, out=DELTA)


if __name__ == "__main__":
    __main__()

15
src/parallel/clients.py Normal file
View File

@ -0,0 +1,15 @@
#!/usr/bin/python3
"""
Description des clients se connectant au serveur.
"""
class Client:
    """
    Record describing one client registered with the server.

    Attributes:
        performance: Performance value supplied when the client was created.
        token: Token supplied when the client was created.
    """

    def __init__(self, performance, token):
        self.performance, self.token = performance, token


# Module-level list of Client instances; empty when the module is loaded.
clients = []