From b923809e1f7634c7e338a304d1852564af8bbd0d Mon Sep 17 00:00:00 2001 From: SzilBalazs Date: Thu, 3 Aug 2023 17:11:31 +0200 Subject: [PATCH] Improve bench and tb download scripts Bench: 3091908 --- scripts/bench.py | 41 +++++++++++++++++++++++++++++------------ scripts/tb_dl.py | 17 +++++++++++------ 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/scripts/bench.py b/scripts/bench.py index c460cae..7d3ba59 100644 --- a/scripts/bench.py +++ b/scripts/bench.py @@ -15,24 +15,41 @@ # along with this program. If not, see . # - import concurrent -from concurrent.futures import ThreadPoolExecutor +import argparse import subprocess - - -THREAD_COUNT = 4 +from concurrent.futures import ThreadPoolExecutor def task(): - output = subprocess.check_output(["./WhiteCore-v0-2", 'bench']).decode("utf8") + output = subprocess.check_output(["./WhiteCore", 'bench']).decode("utf8") return output -executor = ThreadPoolExecutor(max_workers=THREAD_COUNT) -futures = [] -for i in range(THREAD_COUNT): - futures.append(executor.submit(task)) +def parse_nps(output): + tokens = output.split() + nps_index = tokens.index('nps') + nps_value = int(tokens[nps_index - 1]) + return nps_value + + +def main(thread_count): + executor = ThreadPoolExecutor(max_workers=thread_count) + futures = [executor.submit(task) for _ in range(thread_count)] + + nps_values = [] + + for future in concurrent.futures.as_completed(futures): + output = future.result() + nps = parse_nps(output) + nps_values.append(nps) + + average_nps = sum(nps_values) / len(nps_values) + print(f"Average NPS: {int(average_nps)}") + -for future in concurrent.futures.as_completed(futures): - print(future.result()) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Calculate average NPS.") + parser.add_argument("threads", type=int, help="Number of threads to use") + args = parser.parse_args() + main(args.threads) diff --git a/scripts/tb_dl.py b/scripts/tb_dl.py index f7fd144..5918e09 100644 --- a/scripts/tb_dl.py +++ b/scripts/tb_dl.py @@ -16,6 +16,7 @@ # import urllib.request +import argparse import re man5 = "https://tablebase.lichess.ovh/tables/standard/3-4-5/" @@ -42,16 +43,20 @@ def dl_tablebase(url): urllib.request.urlretrieve(f"{url}{eg}", f"{eg}") -def main(): - prompt_man5 = input("Download 5 man? (y/n)") - prompt_man6 = input("Download 6 man? (y/n)") - if prompt_man5 == "y": +def main(man5_flag, man6_flag): + if man5_flag: print("Downloading 5 man...") dl_tablebase(man5) - if prompt_man6 == "y": + if man6_flag: print("Downloading 6 man...") dl_tablebase(man6) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description='Downloads from Tablebase.') + parser.add_argument('--man5', action='store_true') + parser.add_argument('--man6', action='store_true') + + args = parser.parse_args() + + main(args.man5, args.man6)