mirror of
https://github.com/augustin64/projet-tipe
synced 2025-02-03 10:48:01 +01:00
Add parallel
This commit is contained in:
parent
0d2f48c192
commit
db06dd4f73
72
src/parallel/app.py
Normal file
72
src/parallel/app.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
#!/usr/bin/python3
|
||||||
|
"""
|
||||||
|
Serveur Flask pour entraîner le réseau sur plusieurs machines en parallèle.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from secrets import token_urlsafe
|
||||||
|
|
||||||
|
from flask import Flask, request, send_file
|
||||||
|
|
||||||
|
from clients import Client, clients
|
||||||
|
|
||||||
|
DATASET = "mnist-train"
|
||||||
|
SECRET = str(random.randint(1000, 10000))
|
||||||
|
CACHE = ".cache"
|
||||||
|
|
||||||
|
os.makedirs(CACHE, exist_ok=True)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
print(f" * Secret: {SECRET}")
|
||||||
|
|
||||||
|
@app.route("/authenticate", methods = ['POST'])
|
||||||
|
def authenticate():
|
||||||
|
"""
|
||||||
|
Authentification d'un nouvel utilisateur
|
||||||
|
"""
|
||||||
|
if not request.is_json:
|
||||||
|
return {
|
||||||
|
"status":
|
||||||
|
"request format is not json"
|
||||||
|
}
|
||||||
|
content = request.get_json()
|
||||||
|
if content["secret"] != SECRET:
|
||||||
|
return {"status": "invalid secret"}
|
||||||
|
|
||||||
|
performance = content["performance"]
|
||||||
|
token = token_urlsafe(30)
|
||||||
|
|
||||||
|
while token in [client.token for client in clients]:
|
||||||
|
token = token_urlsafe(30)
|
||||||
|
|
||||||
|
clients.append(Client(performance, token))
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
data["token"] = token
|
||||||
|
data["nb_elem"] = performance
|
||||||
|
data["start"] = 0
|
||||||
|
data["dataset"] = DATASET
|
||||||
|
data["status"] = "ok"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/get_network")
|
||||||
|
def get_network():
|
||||||
|
"""
|
||||||
|
Renvoie le réseau neuronal
|
||||||
|
"""
|
||||||
|
if not request.is_json:
|
||||||
|
return {
|
||||||
|
"status":
|
||||||
|
"request format is not json"
|
||||||
|
}
|
||||||
|
token = request.get_json()["token"]
|
||||||
|
if token not in [client.token for client in clients]:
|
||||||
|
return {
|
||||||
|
"status":
|
||||||
|
"token invalide"
|
||||||
|
}
|
||||||
|
return send_file(
|
||||||
|
os.path.join(CACHE, "reseau.bin")
|
||||||
|
)
|
86
src/parallel/client.py
Normal file
86
src/parallel/client.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/python3
|
||||||
|
"""
|
||||||
|
Client se connectant au serveur Flask afin de fournir de la puissance de calcul.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
CACHE = ".cache"
|
||||||
|
DELTA = os.path.join(CACHE, "delta_shared.bin")
|
||||||
|
RESEAU = os.path.join(CACHE, "reseau_shared.bin")
|
||||||
|
SECRET = input("SECRET : ")
|
||||||
|
HOST = input("HOST : ")
|
||||||
|
os.makedirs(CACHE, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_performance():
|
||||||
|
"""
|
||||||
|
Renvoie un indice de performance du client afin de savoir quelle quantité de données lui fournir
|
||||||
|
"""
|
||||||
|
cores = os.cpu_count()
|
||||||
|
max_freq = psutil.cpu_freq()[2]
|
||||||
|
return int(cores * max_freq * 0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate():
|
||||||
|
"""
|
||||||
|
S'inscrit en tant que client auprès du serveur
|
||||||
|
"""
|
||||||
|
performance = get_performance()
|
||||||
|
data = {
|
||||||
|
"performance":performance,
|
||||||
|
"secret": SECRET
|
||||||
|
}
|
||||||
|
req = requests.post(
|
||||||
|
f"http://{HOST}/authenticate",
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
data = json.loads(req.text)
|
||||||
|
if data["status"] != "ok":
|
||||||
|
print("authentication error:", data["status"])
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def download_network(token):
|
||||||
|
"""
|
||||||
|
Récupère le réseau depuis le serveur
|
||||||
|
"""
|
||||||
|
data = {"token": token}
|
||||||
|
with requests.get(f"http://{HOST}/get_network", stream=True, json=data) as req:
|
||||||
|
req.raise_for_status()
|
||||||
|
with open(os.path.join(CACHE, RESEAU), "wb") as file:
|
||||||
|
for chunk in req.iter_content(chunk_size=8192):
|
||||||
|
file.write(chunk)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA):
|
||||||
|
"""
|
||||||
|
Entraînement du réseau
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def __main__():
|
||||||
|
data = authenticate()
|
||||||
|
|
||||||
|
token = data["token"]
|
||||||
|
dataset = data["dataset"]
|
||||||
|
|
||||||
|
start = data["start"]
|
||||||
|
nb_elem = data["nb_elem"]
|
||||||
|
|
||||||
|
download_network(token)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
train_shared(dataset, start, nb_elem, epochs=1, out=DELTA)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
__main__()
|
15
src/parallel/clients.py
Normal file
15
src/parallel/clients.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#!/usr/bin/python3
|
||||||
|
"""
|
||||||
|
Description des clients se connectant au serveur.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Client():
|
||||||
|
"""
|
||||||
|
Classe client
|
||||||
|
"""
|
||||||
|
def __init__(self, performance, token):
|
||||||
|
self.performance = performance
|
||||||
|
self.token = token
|
||||||
|
|
||||||
|
|
||||||
|
clients = []
|
Loading…
Reference in New Issue
Block a user