diff --git a/src/parallel/client.py b/src/parallel/client.py index 6e429b3..431486a 100755 --- a/src/parallel/client.py +++ b/src/parallel/client.py @@ -4,17 +4,20 @@ Client se connectant au serveur Flask afin de fournir de la puissance de calcul. """ import json import os -import sys -import time +import shutil import subprocess +import sys +import tempfile +import time import psutil import requests # Définition de constantes -CACHE = "/tmp/parallel/client_cache" # Replace with an absolute path +CACHE = tempfile.mkdtemp() DELTA = os.path.join(CACHE, "delta_shared.bin") RESEAU = os.path.join(CACHE, "reseau_shared.bin") +PROTOCOL = "https" if len(sys.argv) > 1: HOST = sys.argv[1] @@ -46,7 +49,7 @@ def authenticate(): performance = get_performance() data = {"performance": performance, "secret": SECRET} # Les données d'identification seront ensuite stockées dans un cookie de l'objet session - req = session.post(f"http://{HOST}/authenticate", json=data) + req = session.post(f"{PROTOCOL}://{HOST}/authenticate", json=data) data = json.loads(req.text) if data["status"] != "ok": @@ -60,7 +63,7 @@ def download_network(): """ Récupère le réseau depuis le serveur """ - with session.get(f"http://{HOST}/get_network", stream=True) as req: + with session.get(f"{PROTOCOL}://{HOST}/get_network", stream=True) as req: req.raise_for_status() with open(RESEAU, "wb") as file: for chunk in req.iter_content(chunk_size=8192): @@ -73,7 +76,7 @@ def send_delta_network(continue_=False): """ with open(DELTA, "rb") as file: files = {"file": file} - req = session.post(f"http://{HOST}/post_network", files=files) + req = session.post(f"{PROTOCOL}://{HOST}/post_network", files=files) req_data = json.loads(req.text) # Actions à effectuer en fonction de la réponse @@ -146,8 +149,11 @@ def __main__(): # train_shared s'appelle récursivement sur lui même jusqu'à la fin du programme try: train_shared(dataset, start, nb_elem, epochs=1, out=DELTA) - except requests.exceptions.ConnectionError: + except requests.exceptions.ConnectionError and json.decoder.JSONDecodeError: + # requests.exceptions.ConnectionError -> Host disconnected + # json.decoder.JSONDecodeError -> Host disconnected but nginx handles it print("Host disconnected") + shutil.rmtree(CACHE) if __name__ == "__main__":