Implement LocalServer
This commit is contained in:
parent
a2d46b906b
commit
44df60fca2
@ -5,32 +5,74 @@ import readline
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
np.seterr(divide='ignore', invalid='ignore')
|
||||||
|
|
||||||
def random_word(model, k=5, dist=100):
|
class Server():
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_random_word(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_rank(self, guess):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_temp(self, guess):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _help(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _reveal_word(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class LocalServer(Server):
|
||||||
|
def __init__(self, word=None, file="models/selected_word2vec_model.bin"):
|
||||||
|
self.model = KeyedVectors.load_word2vec_format(
|
||||||
|
file,
|
||||||
|
binary=True,
|
||||||
|
unicode_errors="ignore"
|
||||||
|
)
|
||||||
|
self.word = word
|
||||||
|
self.nearest = []
|
||||||
|
|
||||||
|
def init_word(self, k=1, dist=100):
|
||||||
|
while (self.word is None or len(self.word) < 5
|
||||||
|
or '-' in self.word or '_' in self.word):
|
||||||
base_words = [
|
base_words = [
|
||||||
model.index_to_key[random.randint(0, len(model))]
|
self.model.index_to_key[random.randint(0, len(self.model))]
|
||||||
for _ in range(k)
|
for _ in range(k)
|
||||||
]
|
]
|
||||||
|
|
||||||
complete_list = base_words.copy()
|
complete_list = base_words.copy()
|
||||||
for word in base_words:
|
for word in base_words:
|
||||||
complete_list += [i[0] for i in model.most_similar(word, topn=dist)]
|
complete_list += [i[0] for i in self.model.most_similar(word, topn=dist)]
|
||||||
|
|
||||||
rk_words = model.rank_by_centrality(complete_list)
|
rk_words = self.model.rank_by_centrality(complete_list)
|
||||||
return rk_words[random.randint(0,5)%len(rk_words)][1]
|
|
||||||
|
|
||||||
|
self.word = rk_words[random.randint(0,5)%len(rk_words)][1]
|
||||||
|
self.nearest = [word]+[i[0] for i in self.model.most_similar(self.word, topn=1000)]
|
||||||
|
|
||||||
def cemantix(model, word=None):
|
def get_rank(self, guess):
|
||||||
while word is None or len(word) < 5 or '-' in word or '_' in word:
|
if guess not in self.nearest:
|
||||||
word = random_word(model, k=1, dist=0) # augment dist for a "smoother selection"
|
|
||||||
|
|
||||||
nearest = [word]+[i[0] for i in model.most_similar(word, topn=1000)]
|
|
||||||
guesses = [] # guess, temp, rank
|
|
||||||
def get_rank(guess):
|
|
||||||
if guess not in nearest:
|
|
||||||
return None
|
return None
|
||||||
return 1000 - nearest.index(guess)
|
return 1000 - self.nearest.index(guess)
|
||||||
|
|
||||||
|
def get_temp(self, guess):
|
||||||
|
return round(self.model.distance(self.word, guess)*100, 2)
|
||||||
|
|
||||||
|
def _help(self, rk):
|
||||||
|
return self.nearest[rk]
|
||||||
|
|
||||||
|
def _reveal_word(self):
|
||||||
|
return self.word
|
||||||
|
|
||||||
|
|
||||||
|
def cemantix(server: Server):
|
||||||
|
server.init_word()
|
||||||
|
|
||||||
|
guesses = [] # guess, temp, rank
|
||||||
def formatted_status(guesses, last=None):
|
def formatted_status(guesses, last=None):
|
||||||
text = ""
|
text = ""
|
||||||
for w, temp, rank in guesses:
|
for w, temp, rank in guesses:
|
||||||
@ -46,15 +88,19 @@ def cemantix(model, word=None):
|
|||||||
return text[:-1]
|
return text[:-1]
|
||||||
|
|
||||||
def tried(word, guessed):
|
def tried(word, guessed):
|
||||||
return word in [i[0] for i in guessed]
|
return (word in [i[0] for i in guessed])
|
||||||
|
|
||||||
def interpret_command(cmd, guesses):
|
def interpret_command(cmd, guesses):
|
||||||
match cmd:
|
match cmd:
|
||||||
case "clear":
|
case "clear":
|
||||||
guesses = [g for g in guesses if g[1] <= 75.]
|
guesses = [g for g in guesses if g[1] <= 75.]
|
||||||
case "help":
|
case "help":
|
||||||
|
try:
|
||||||
best_rk = max([rk for _, _, rk in guesses if rk is not None]+[749])
|
best_rk = max([rk for _, _, rk in guesses if rk is not None]+[749])
|
||||||
print("Maybe try "+Back.YELLOW+Fore.BLACK+nearest[999-best_rk]+Style.RESET_ALL)
|
print("Maybe try "+Back.YELLOW+Fore.BLACK+server._help(999-best_rk)
|
||||||
|
+Style.RESET_ALL)
|
||||||
|
except NotImplementedError:
|
||||||
|
print(Fore.RED+"No help available"+Style.RESET_ALL)
|
||||||
case _:
|
case _:
|
||||||
print(Fore.RED+"Unknown command"+Style.RESET_ALL)
|
print(Fore.RED+"Unknown command"+Style.RESET_ALL)
|
||||||
|
|
||||||
@ -69,21 +115,24 @@ def cemantix(model, word=None):
|
|||||||
guesses = interpret_command(guess[:-2], guesses)
|
guesses = interpret_command(guess[:-2], guesses)
|
||||||
continue
|
continue
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
print("The word was "+Style.BRIGHT+word+Style.RESET_ALL)
|
try:
|
||||||
|
print("The word was "+Style.BRIGHT+server._reveal_word()+Style.RESET_ALL)
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
print("Goodbye!")
|
print("Goodbye!")
|
||||||
return -1
|
return -1
|
||||||
try:
|
try:
|
||||||
dist = round(round(model.distance(word, guess), 4)*100, 2)
|
dist = server.get_temp(guess)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(Fore.RED+"Key not present"+Style.RESET_ALL)
|
print(Fore.RED+"Key not present"+Style.RESET_ALL)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not tried(guess, guesses):
|
if not tried(guess, guesses):
|
||||||
guesses.append((guess, dist, get_rank(guess)))
|
guesses.append((guess, dist, server.get_rank(guess)))
|
||||||
guesses.sort(key=lambda x:-x[1])
|
guesses.sort(key=lambda x:-x[1])
|
||||||
print(chr(27) + "[2J")
|
print(chr(27) + "[2J")
|
||||||
print(formatted_status(guesses, last=guess))
|
print(formatted_status(guesses, last=guess))
|
||||||
if guess == word:
|
if server.get_rank(guess) == 1000:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
print(Fore.GREEN+"Correct!"+Style.RESET_ALL+f" {len(guesses)} tries.")
|
print(Fore.GREEN+"Correct!"+Style.RESET_ALL+f" {len(guesses)} tries.")
|
||||||
return len(guesses)
|
return len(guesses)
|
||||||
@ -100,12 +149,7 @@ def main():
|
|||||||
help="Specify model to use")
|
help="Specify model to use")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model = KeyedVectors.load_word2vec_format(
|
cemantix(LocalServer(word=args.word, file=args.model))
|
||||||
args.model,
|
|
||||||
binary=True,
|
|
||||||
unicode_errors="ignore"
|
|
||||||
)
|
|
||||||
cemantix(model, word=args.word)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user