diff --git a/src/parallel/app.py b/src/parallel/app.py index 2ad45ae..d4d0bb0 100644 --- a/src/parallel/app.py +++ b/src/parallel/app.py @@ -6,18 +6,19 @@ import os import time import random import subprocess -from secrets import token_urlsafe from threading import Thread +from secrets import token_urlsafe from flask import Flask, request, send_from_directory, session -from structures import clients, Client, NoMoreJobAvailableError, TryLaterError, Training +from structures import (Client, NoMoreJobAvailableError, Training, + TryLaterError, clients) # Définitions de variables DATASET = "mnist-train" TEST_SET = "mnist-t10k" SECRET = str(random.randint(1000, 10000)) -CACHE = ".cache" # À remplacer avec un chemin absolu +CACHE = "/tmp/parallel/app_cache" # À remplacer avec un chemin absolu BATCHS = 10 RESEAU = os.path.join(CACHE, "reseau.bin") @@ -28,13 +29,13 @@ os.makedirs(CACHE, exist_ok=True) if not os.path.isfile(RESEAU): if not os.path.isfile("out/main"): subprocess.call(["./make.sh", "build", "main"]) - subprocess.call - ([ + subprocess.call( + [ "out/main", "train", "--epochs", "0", "--images", "data/mnist/train-images-idx3-ubyte", "--labels", "data/mnist/train-labels-idx1-ubyte", - "--out", RESEAU, + "--out", RESEAU ]) print(f" * Created {RESEAU}") else: diff --git a/src/parallel/client.py b/src/parallel/client.py index 734093f..f506f0a 100644 --- a/src/parallel/client.py +++ b/src/parallel/client.py @@ -12,7 +12,7 @@ import psutil import requests # Définition de constantes -CACHE = ".cache" # Replace with an absolute path +CACHE = "/tmp/parallel/client_cache" # Replace with an absolute path DELTA = os.path.join(CACHE, "delta_shared.bin") RESEAU = os.path.join(CACHE, "reseau_shared.bin") SECRET = input("SECRET : ") diff --git a/src/parallel/structures.py b/src/parallel/structures.py index 08aefc4..2f3b18d 100644 --- a/src/parallel/structures.py +++ b/src/parallel/structures.py @@ -3,6 +3,7 @@ Description des structures. """ import os +import sys import time import subprocess @@ -101,6 +102,9 @@ class Training: ]) self.cur_batch += 1 self.cur_image = 0 + if self.cur_batch >= self.batchs: + print("Done.") + sys.exit() def patch(self): @@ -116,8 +120,8 @@ class Training: if not os.path.isfile("out/main"): subprocess.call(["./make.sh", "build", "utils"]) - subprocess.call - ([ + subprocess.call( + [ "out/utils", "patch-network", "--network", self.reseau, "--delta", self.delta,