mirror of https://github.com/augustin64/projet-tipe
synced 2025-01-23 23:26:25 +01:00

Remove parallel

This commit is contained in:
parent a9e704a7bc
commit 0f5867ebb6
@@ -1,155 +0,0 @@
#!/usr/bin/python3
"""
Flask server used to train the network on several machines in parallel.
"""
import os
import time
import random
import subprocess
from threading import Thread
from secrets import token_urlsafe

from flask import Flask, request, send_from_directory, session

from structures import (Client, NoMoreJobAvailableError, Training,
                        TryLaterError, clients)

# Variable definitions
DATASET = "mnist-train"
TEST_SET = "mnist-t10k"
SECRET = str(random.randint(1000, 10000))
CACHE = "/tmp/parallel/app_cache"  # To be replaced with an absolute path
BATCHS = 10
RESEAU = os.path.join(CACHE, "reseau.bin")

training = Training(BATCHS, DATASET, TEST_SET, CACHE)

os.makedirs(CACHE, exist_ok=True)
# Create a random network if it does not exist yet
if not os.path.isfile(RESEAU):
    if not os.path.isfile("out/mnist_main"):
        subprocess.call(["./make.sh", "build", "mnist-main"])
    subprocess.call(
        [
            "out/mnist_main", "train",
            "--epochs", "0",
            "--images", "data/mnist/train-images-idx3-ubyte",
            "--labels", "data/mnist/train-labels-idx1-ubyte",
            "--out", RESEAU
        ])
    print(f" * Created {RESEAU}")
else:
    print(f" * {RESEAU} already exists")


app = Flask(__name__)
# Set a secret key so that session cookies can be used
app.config["SECRET_KEY"] = token_urlsafe(40)
print(f" * Secret: {SECRET}")

with open("app-secret", "w", encoding="utf8") as file:
    file.write(SECRET)


@app.route("/authenticate", methods=["POST"])
def authenticate():
    """
    Authenticate a new client
    """
    if not request.is_json:
        return {"status": "request format is not json"}
    content = request.get_json()
    if content["secret"] != SECRET:
        return {"status": "invalid secret"}

    token = token_urlsafe(30)
    while token in clients:
        token = token_urlsafe(30)

    clients[token] = Client(content["performance"], token)

    # Prepare the server response
    data = {}
    data["status"] = "ok"
    data["dataset"] = training.dataset
    session["token"] = token

    try:
        clients[token].get_job(training)
        data["nb_elem"] = clients[token].performance
        data["start"] = clients[token].start
        data["instruction"] = "train"

    except NoMoreJobAvailableError:
        data["status"] = "Training already ended"
        data["nb_elem"] = 0
        data["start"] = 0
        data["instruction"] = "stop"

    except TryLaterError:
        data["status"] = "Wait for next batch"
        data["nb_elem"] = 0
        data["start"] = 0
        data["instruction"] = "sleep"
        data["sleep_time"] = 0.2

    return data


@app.route("/post_network", methods=["POST"])
def post_network():
    """
    Apply the returned patch to the current network
    """
    token = session.get("token")
    if token not in clients:
        return {"status": "invalid token"}

    while training.is_patch_locked():
        time.sleep(0.1)

    request.files["file"].save(training.delta)
    training.patch()
    training.computed_images += clients[token].performance
    # Prepare the response
    data = {}
    data["status"] = "ok"
    data["dataset"] = training.dataset

    try:
        clients[token].get_job(training)
        data["dataset"] = training.dataset
        data["nb_elem"] = clients[token].performance
        data["start"] = clients[token].start
        data["instruction"] = "train"

    except NoMoreJobAvailableError:
        data["status"] = "Training already ended"
        data["nb_elem"] = 0
        data["start"] = 0
        data["instruction"] = "stop"

    except TryLaterError:
        # Pass the method itself to Thread; calling it here would run the test in this request
        Thread(target=training.test_network).start()
        data["status"] = "Wait for next batch"
        data["nb_elem"] = 0
        data["start"] = 0
        data["instruction"] = "sleep"
        data["sleep_time"] = 0.02

    return data


@app.route("/get_network", methods=["GET", "POST"])
def get_network():
    """
    Return the neural network file
    """
    token = session.get("token")
    if token not in clients:
        return {"status": "invalid token"}

    return send_from_directory(directory=CACHE, path="reseau.bin")
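The three routes above form the whole protocol: /authenticate opens a session and returns a first job, /get_network serves reseau.bin, and /post_network receives a delta and answers with the next instruction. Below is a minimal sketch of the client side of the first two calls; the host and port, the performance value of 500, the local output filename and reading app-secret from the current directory are assumptions for the illustration, not part of this commit.

import requests

session = requests.Session()
with open("app-secret", encoding="utf8") as f:
    secret = f.read().strip()

# /authenticate stores the token in a session cookie and returns the first job
reply = session.post(
    "http://localhost:5000/authenticate",
    json={"secret": secret, "performance": 500},
).json()
print(reply["status"], reply["instruction"], reply["nb_elem"], reply["start"])

if reply["instruction"] == "train":
    # /get_network streams reseau.bin, the starting point for local training
    with session.get("http://localhost:5000/get_network", stream=True) as req:
        req.raise_for_status()
        with open("reseau_local.bin", "wb") as out:
            for chunk in req.iter_content(chunk_size=8192):
                out.write(chunk)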
@@ -1,160 +0,0 @@
#!/usr/bin/python3
"""
Client connecting to the Flask server to contribute computing power.
"""
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time

import psutil
import requests

# Constant definitions
CACHE = tempfile.mkdtemp()
DELTA = os.path.join(CACHE, "delta_shared.bin")
RESEAU = os.path.join(CACHE, "reseau_shared.bin")
PROTOCOL = "https"

if len(sys.argv) > 1:
    HOST = sys.argv[1]
else:
    HOST = input("HOST : ")

if len(sys.argv) > 2:
    SECRET = sys.argv[2]
else:
    SECRET = input("SECRET : ")

session = requests.Session()
os.makedirs(CACHE, exist_ok=True)


def get_performance():
    """
    Return a performance index for this client, used by the server to decide how much data to send it
    """
    cores = os.cpu_count()
    max_freq = psutil.cpu_freq()[2]  # maximum CPU frequency, in MHz
    return int(cores * max_freq * 0.5)
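
# Illustration only (not part of the original file): on a hypothetical machine
# with 8 cores and a 3600 MHz maximum frequency, the heuristic above yields
#     int(8 * 3600 * 0.5) == 14400
# so the server will hand this client jobs of 14400 images ("nb_elem").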


def authenticate():
    """
    Register as a client with the server
    """
    performance = get_performance()
    data = {"performance": performance, "secret": SECRET}
    # The identification data is then stored in a cookie of the session object
    req = session.post(f"{PROTOCOL}://{HOST}/authenticate", json=data)

    data = json.loads(req.text)
    if data["status"] != "ok":
        print("error in authenticate():", data["status"])
        sys.exit(1)
    else:
        return data


def download_network():
    """
    Fetch the network from the server
    """
    with session.get(f"{PROTOCOL}://{HOST}/get_network", stream=True) as req:
        req.raise_for_status()
        with open(RESEAU, "wb") as file:
            for chunk in req.iter_content(chunk_size=8192):
                file.write(chunk)


def send_delta_network(continue_=False):
    """
    Send the differential network and follow the next instructions
    """
    with open(DELTA, "rb") as file:
        files = {"file": file}
        req = session.post(f"{PROTOCOL}://{HOST}/post_network", files=files)
        req_data = json.loads(req.text)

    # Actions to take depending on the response
    if "instruction" not in req_data.keys():
        print(req_data["status"])
        raise NotImplementedError

    if req_data["instruction"] == "sleep":
        print(f"Sleeping {req_data['sleep_time']}s.")
        time.sleep(req_data["sleep_time"])
        send_delta_network(continue_=continue_)

    elif req_data["instruction"] == "stop":
        print(req_data["status"])
        print("Shutting down.")

    elif req_data["instruction"] == "train":
        download_network()
        train_shared(req_data["dataset"], req_data["start"], req_data["nb_elem"])

    else:
        print(json.dumps(req_data))
        raise NotImplementedError


def train_shared(dataset, start, nb_elem, epochs=1, out=DELTA):
    """
    Train the network
    """
    # Using a dictionary would be more efficient and cleaner
    if dataset == "mnist-train":
        images = "data/mnist/train-images-idx3-ubyte"
        labels = "data/mnist/train-labels-idx1-ubyte"
    elif dataset == "mnist-t10k":
        images = "data/mnist/t10k-images-idx3-ubyte"
        labels = "data/mnist/t10k-labels-idx1-ubyte"
    else:
        print(f"Dataset {dataset} not implemented yet")
        raise NotImplementedError

    # Compile out/mnist_main if it does not exist yet
    if not os.path.isfile("out/mnist_main"):
        subprocess.call(["./make.sh", "build", "mnist-main"])

    # Train the network
    subprocess.call(
        [
            "out/mnist_main", "train",
            "--epochs", str(epochs),
            "--images", images,
            "--labels", labels,
            "--recover", RESEAU,
            "--delta", out,
            "--nb-images", str(nb_elem),
            "--start", str(start),
        ],
        stdout=subprocess.DEVNULL,
    )
    return send_delta_network(continue_=True)


def __main__():
    data = authenticate()

    dataset = data["dataset"]
    start = data["start"]
    nb_elem = data["nb_elem"]

    download_network()
    # train_shared calls itself recursively (through send_delta_network) until the program ends
    try:
        train_shared(dataset, start, nb_elem, epochs=1, out=DELTA)
    except (requests.exceptions.ConnectionError, json.decoder.JSONDecodeError):
        # requests.exceptions.ConnectionError -> Host disconnected
        # json.decoder.JSONDecodeError -> Host disconnected but nginx handles it
        print("Host disconnected")
    shutil.rmtree(CACHE)


if __name__ == "__main__":
    __main__()
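The client drives everything through mutual recursion: train_shared() ends by calling send_delta_network(), which calls train_shared() again until the server answers "stop". The same cycle written as an explicit loop is sketched below; train_once() and post_delta() are hypothetical helpers (one run of the training subprocess, one HTTP POST of the delta file returning the server's JSON reply) and are not defined in this commit.

def run(train_once, post_delta):
    # Sketch only: equivalent control flow without unbounded recursion.
    job = authenticate()
    while True:
        if job["instruction"] == "train":
            download_network()
            train_once(job["dataset"], job["start"], job["nb_elem"])
            job = post_delta()          # send the delta, get the next instruction
        elif job["instruction"] == "sleep":
            time.sleep(job["sleep_time"])
            job = post_delta()          # re-send the delta until a new job arrives
        else:                           # "stop" or anything unexpected
            print(job["status"])
            break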
@@ -1,146 +0,0 @@
#!/usr/bin/python3
"""
Description of the data structures.
"""
import os
import time
import subprocess


class NoMoreJobAvailableError(Exception):
    """Training of the network is finished"""
    pass


class TryLaterError(Exception):
    """Batch finished, try again later"""
    pass


class Client:
    """
    Describes a client connected to the server
    """
    def __init__(self, performance, token):
        self.performance = performance
        self.token = token
        self.start = 0
        self.nb_images = 0

    def get_job(self, training):
        """
        Assign a job to the client
        """
        if training.nb_images <= training.computed_images:
            if training.batchs == training.cur_batch:
                raise NoMoreJobAvailableError
            raise TryLaterError

        self.start = training.cur_image
        self.nb_images = min(training.nb_images - training.cur_image, self.performance)
        training.cur_image += self.nb_images


clients = {}


class Training:
    """
    State of a training run
    """
    def __init__(self, batchs, dataset, test_set, cache):
        # Variable definitions
        self.batchs = batchs
        self.cur_batch = 1
        self.cur_image = 0
        self.computed_images = 0
        self.lock_test = False
        self.dataset = dataset
        self.test_set = test_set
        self.cache = cache
        self.reseau = os.path.join(self.cache, "reseau.bin")
        self.delta = os.path.join(self.cache, "delta.bin")

        # Paths and metadata for each dataset
        # TODO: implement this more cleanly with a dictionary or even a datasets.json file
        if self.dataset == "mnist-train":
            self.nb_images = 60000
        elif self.dataset == "mnist-t10k":
            self.nb_images = 10000
        else:
            raise NotImplementedError

        if self.test_set == "mnist-train":
            self.test_images = "data/mnist/train-images-idx3-ubyte"
            self.test_labels = "data/mnist/train-labels-idx1-ubyte"
        elif self.test_set == "mnist-t10k":
            self.test_images = "data/mnist/t10k-images-idx3-ubyte"
            self.test_labels = "data/mnist/t10k-labels-idx1-ubyte"
        else:
            print(f"{self.test_set} test dataset unknown.")
            raise NotImplementedError

        # Remove the lock file that prevents several
        # simultaneous writes to the reseau.bin file
        if os.path.isfile(self.reseau + ".lock"):
            os.remove(self.reseau + ".lock")

    def test_network(self):
        """
        Test the network's performance before the next batch
        """
        if self.lock_test:
            return

        self.lock_test = True
        if not os.path.isfile("out/mnist_main"):
            subprocess.call(["./make.sh", "build", "mnist-main"])

        subprocess.call(
            [
                "out/mnist_main", "test",
                "--images", self.test_images,
                "--labels", self.test_labels,
                "--modele", self.reseau
            ])
        self.cur_batch += 1
        self.cur_image = 0
        self.computed_images = 0
        if self.cur_batch >= self.batchs:
            print("Done.")
            os._exit(0)

        self.lock_test = False
        return

    def patch(self):
        """
        Apply a patch to the network
        """
        # Wait for the lock to be released, then patch the network
        while self.is_patch_locked():
            time.sleep(0.1)

        with open(self.reseau + ".lock", "w", encoding="utf8") as file:
            file.write("")

        if not os.path.isfile("out/mnist_utils"):
            subprocess.call(["./make.sh", "build", "mnist-utils"])
        subprocess.call(
            [
                "out/mnist_utils", "patch-network",
                "--network", self.reseau,
                "--delta", self.delta,
            ])

        os.remove(self.reseau + ".lock")

    def is_patch_locked(self):
        """
        Small shortcut to check whether the lock file is present
        """
        return os.path.isfile(self.reseau + ".lock")
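Client.get_job() is the piece that slices each batch into per-client jobs: it hands out the next slice of at most `performance` images until the batch is exhausted, then raises TryLaterError (wait for test_network()) or NoMoreJobAvailableError (all batches done). A toy run of that logic is sketched below; the shrunken image count, the cache path and the constructor values are assumptions made only for the illustration.

from structures import Client, Training, NoMoreJobAvailableError, TryLaterError

training = Training(batchs=2, dataset="mnist-train", test_set="mnist-t10k", cache="/tmp/demo")
training.nb_images = 10      # shrink the real 60000 images for the demo
worker = Client(performance=4, token="demo")

try:
    while True:
        worker.get_job(training)
        print("train images", worker.start, "to", worker.start + worker.nb_images - 1)
        training.computed_images += worker.nb_images   # the server does this in post_network()
except TryLaterError:
    print("batch done, wait for test_network() before the next one")
except NoMoreJobAvailableError:
    print("training finished")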