diff --git a/src/cemantix.py b/src/cemantix.py index 9e8b9cf..ab4ad6f 100644 --- a/src/cemantix.py +++ b/src/cemantix.py @@ -5,32 +5,74 @@ import readline import random import time +import numpy as np +np.seterr(divide='ignore', invalid='ignore') -def random_word(model, k=5, dist=100): - base_words = [ - model.index_to_key[random.randint(0, len(model))] - for _ in range(k) - ] +class Server(): + def __init__(self): + pass - complete_list = base_words.copy() - for word in base_words: - complete_list += [i[0] for i in model.most_similar(word, topn=dist)] + def set_random_word(self): + pass - rk_words = model.rank_by_centrality(complete_list) - return rk_words[random.randint(0,5)%len(rk_words)][1] + def get_rank(self, guess): + pass + def get_temp(self, guess): + pass -def cemantix(model, word=None): - while word is None or len(word) < 5 or '-' in word or '_' in word: - word = random_word(model, k=1, dist=0) # augment dist for a "smoother selection" + def _help(self): + raise NotImplementedError - nearest = [word]+[i[0] for i in model.most_similar(word, topn=1000)] - guesses = [] # guess, temp, rank - def get_rank(guess): - if guess not in nearest: + def _reveal_word(self): + raise NotImplementedError + +class LocalServer(Server): + def __init__(self, word=None, file="models/selected_word2vec_model.bin"): + self.model = KeyedVectors.load_word2vec_format( + file, + binary=True, + unicode_errors="ignore" + ) + self.word = word + self.nearest = [] + + def init_word(self, k=1, dist=100): + while (self.word is None or len(self.word) < 5 + or '-' in self.word or '_' in self.word): + base_words = [ + self.model.index_to_key[random.randint(0, len(self.model))] + for _ in range(k) + ] + + complete_list = base_words.copy() + for word in base_words: + complete_list += [i[0] for i in self.model.most_similar(word, topn=dist)] + + rk_words = self.model.rank_by_centrality(complete_list) + + self.word = rk_words[random.randint(0,5)%len(rk_words)][1] + self.nearest = [word]+[i[0] for i in self.model.most_similar(self.word, topn=1000)] + + def get_rank(self, guess): + if guess not in self.nearest: return None - return 1000 - nearest.index(guess) + return 1000 - self.nearest.index(guess) + def get_temp(self, guess): + return round(self.model.distance(self.word, guess)*100, 2) + + def _help(self, rk): + return self.nearest[rk] + + def _reveal_word(self): + return self.word + + +def cemantix(server: Server): + server.init_word() + + guesses = [] # guess, temp, rank def formatted_status(guesses, last=None): text = "" for w, temp, rank in guesses: @@ -46,15 +88,19 @@ def cemantix(model, word=None): return text[:-1] def tried(word, guessed): - return word in [i[0] for i in guessed] + return (word in [i[0] for i in guessed]) def interpret_command(cmd, guesses): match cmd: case "clear": guesses = [g for g in guesses if g[1] <= 75.] case "help": - best_rk = max([rk for _, _, rk in guesses if rk is not None]+[749]) - print("Maybe try "+Back.YELLOW+Fore.BLACK+nearest[999-best_rk]+Style.RESET_ALL) + try: + best_rk = max([rk for _, _, rk in guesses if rk is not None]+[749]) + print("Maybe try "+Back.YELLOW+Fore.BLACK+server._help(999-best_rk) + +Style.RESET_ALL) + except NotImplementedError: + print(Fore.RED+"No help available"+Style.RESET_ALL) case _: print(Fore.RED+"Unknown command"+Style.RESET_ALL) @@ -69,21 +115,24 @@ def cemantix(model, word=None): guesses = interpret_command(guess[:-2], guesses) continue except (EOFError, KeyboardInterrupt): - print("The word was "+Style.BRIGHT+word+Style.RESET_ALL) + try: + print("The word was "+Style.BRIGHT+server._reveal_word()+Style.RESET_ALL) + except NotImplementedError: + pass print("Goodbye!") return -1 try: - dist = round(round(model.distance(word, guess), 4)*100, 2) + dist = server.get_temp(guess) except KeyError: print(Fore.RED+"Key not present"+Style.RESET_ALL) continue if not tried(guess, guesses): - guesses.append((guess, dist, get_rank(guess))) + guesses.append((guess, dist, server.get_rank(guess))) guesses.sort(key=lambda x:-x[1]) print(chr(27) + "[2J") print(formatted_status(guesses, last=guess)) - if guess == word: + if server.get_rank(guess) == 1000: time.sleep(1) print(Fore.GREEN+"Correct!"+Style.RESET_ALL+f" {len(guesses)} tries.") return len(guesses) @@ -100,12 +149,7 @@ def main(): help="Specify model to use") args = parser.parse_args() - model = KeyedVectors.load_word2vec_format( - args.model, - binary=True, - unicode_errors="ignore" - ) - cemantix(model, word=args.word) + cemantix(LocalServer(word=args.word, file=args.model)) if __name__ == "__main__": main()