mirror of
https://github.com/augustin64/projet-tipe
synced 2025-02-02 19:39:39 +01:00
Add parallel
This commit is contained in:
parent
0d2f48c192
commit
db06dd4f73
72
src/parallel/app.py
Normal file
72
src/parallel/app.py
Normal file
@ -0,0 +1,72 @@
|
||||
#!/usr/bin/python3
|
||||
"""
|
||||
Serveur Flask pour entraîner le réseau sur plusieurs machines en parallèle.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
from secrets import token_urlsafe
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
|
||||
from clients import Client, clients
|
||||
|
||||
DATASET = "mnist-train"
|
||||
SECRET = str(random.randint(1000, 10000))
|
||||
CACHE = ".cache"
|
||||
|
||||
os.makedirs(CACHE, exist_ok=True)
|
||||
|
||||
app = Flask(__name__)
|
||||
print(f" * Secret: {SECRET}")
|
||||
|
||||
@app.route("/authenticate", methods = ['POST'])
|
||||
def authenticate():
|
||||
"""
|
||||
Authentification d'un nouvel utilisateur
|
||||
"""
|
||||
if not request.is_json:
|
||||
return {
|
||||
"status":
|
||||
"request format is not json"
|
||||
}
|
||||
content = request.get_json()
|
||||
if content["secret"] != SECRET:
|
||||
return {"status": "invalid secret"}
|
||||
|
||||
performance = content["performance"]
|
||||
token = token_urlsafe(30)
|
||||
|
||||
while token in [client.token for client in clients]:
|
||||
token = token_urlsafe(30)
|
||||
|
||||
clients.append(Client(performance, token))
|
||||
|
||||
data = {}
|
||||
data["token"] = token
|
||||
data["nb_elem"] = performance
|
||||
data["start"] = 0
|
||||
data["dataset"] = DATASET
|
||||
data["status"] = "ok"
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@app.route("/get_network")
|
||||
def get_network():
|
||||
"""
|
||||
Renvoie le réseau neuronal
|
||||
"""
|
||||
if not request.is_json:
|
||||
return {
|
||||
"status":
|
||||
"request format is not json"
|
||||
}
|
||||
token = request.get_json()["token"]
|
||||
if token not in [client.token for client in clients]:
|
||||
return {
|
||||
"status":
|
||||
"token invalide"
|
||||
}
|
||||
return send_file(
|
||||
os.path.join(CACHE, "reseau.bin")
|
||||
)
|
86
src/parallel/client.py
Normal file
86
src/parallel/client.py
Normal file
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/python3
|
||||
"""
|
||||
Client se connectant au serveur Flask afin de fournir de la puissance de calcul.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
|
||||
CACHE = ".cache"
|
||||
DELTA = os.path.join(CACHE, "delta_shared.bin")
|
||||
RESEAU = os.path.join(CACHE, "reseau_shared.bin")
|
||||
SECRET = input("SECRET : ")
|
||||
HOST = input("HOST : ")
|
||||
os.makedirs(CACHE, exist_ok=True)
|
||||
|
||||
|
||||
def get_performance():
|
||||
"""
|
||||
Renvoie un indice de performance du client afin de savoir quelle quantité de données lui fournir
|
||||
"""
|
||||
cores = os.cpu_count()
|
||||
max_freq = psutil.cpu_freq()[2]
|
||||
return int(cores * max_freq * 0.01)
|
||||
|
||||
|
||||
def authenticate():
|
||||
"""
|
||||
S'inscrit en tant que client auprès du serveur
|
||||
"""
|
||||
performance = get_performance()
|
||||
data = {
|
||||
"performance":performance,
|
||||
"secret": SECRET
|
||||
}
|
||||
req = requests.post(
|
||||
f"http://{HOST}/authenticate",
|
||||
json=data
|
||||
)
|
||||
data = json.loads(req.text)
|
||||
if data["status"] != "ok":
|
||||
print("authentication error:", data["status"])
|
||||
sys.exit(1)
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def download_network(token):
|
||||
"""
|
||||
Récupère le réseau depuis le serveur
|
||||
"""
|
||||
data = {"token": token}
|
||||
with requests.get(f"http://{HOST}/get_network", stream=True, json=data) as req:
|
||||
req.raise_for_status()
|
||||
with open(os.path.join(CACHE, RESEAU), "wb") as file:
|
||||
for chunk in req.iter_content(chunk_size=8192):
|
||||
file.write(chunk)
|
||||
|
||||
|
||||
|
||||
def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA):
|
||||
"""
|
||||
Entraînement du réseau
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def __main__():
|
||||
data = authenticate()
|
||||
|
||||
token = data["token"]
|
||||
dataset = data["dataset"]
|
||||
|
||||
start = data["start"]
|
||||
nb_elem = data["nb_elem"]
|
||||
|
||||
download_network(token)
|
||||
|
||||
while True:
|
||||
train_shared(dataset, start, nb_elem, epochs=1, out=DELTA)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
__main__()
|
15
src/parallel/clients.py
Normal file
15
src/parallel/clients.py
Normal file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/python3
|
||||
"""
|
||||
Description des clients se connectant au serveur.
|
||||
"""
|
||||
|
||||
class Client():
|
||||
"""
|
||||
Classe client
|
||||
"""
|
||||
def __init__(self, performance, token):
|
||||
self.performance = performance
|
||||
self.token = token
|
||||
|
||||
|
||||
clients = []
|
Loading…
Reference in New Issue
Block a user