Skip to content

Commit

Permalink
Improve bench and tb download scripts
Browse files Browse the repository at this point in the history
Bench: 3091908
  • Loading branch information
SzilBalazs committed Aug 3, 2023
1 parent 51efb18 commit b923809
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
41 changes: 29 additions & 12 deletions scripts/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,41 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#


import concurrent
from concurrent.futures import ThreadPoolExecutor
import argparse
import subprocess


THREAD_COUNT = 4
from concurrent.futures import ThreadPoolExecutor


def task():
output = subprocess.check_output(["./WhiteCore-v0-2", 'bench']).decode("utf8")
output = subprocess.check_output(["./WhiteCore", 'bench']).decode("utf8")
return output


executor = ThreadPoolExecutor(max_workers=THREAD_COUNT)
futures = []
for i in range(THREAD_COUNT):
futures.append(executor.submit(task))
def parse_nps(output):
tokens = output.split()
nps_index = tokens.index('nps')
nps_value = int(tokens[nps_index - 1])
return nps_value


def main(thread_count):
executor = ThreadPoolExecutor(max_workers=thread_count)
futures = [executor.submit(task) for _ in range(thread_count)]

nps_values = []

for future in concurrent.futures.as_completed(futures):
output = future.result()
nps = parse_nps(output)
nps_values.append(nps)

average_nps = sum(nps_values) / len(nps_values)
print(f"Average NPS: {int(average_nps)}")


for future in concurrent.futures.as_completed(futures):
print(future.result())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Calculate average NPS.")
parser.add_argument("threads", type=int, help="Number of threads to use")
args = parser.parse_args()
main(args.threads)
17 changes: 11 additions & 6 deletions scripts/tb_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import urllib.request
import argparse
import re

man5 = "https://tablebase.lichess.ovh/tables/standard/3-4-5/"
Expand All @@ -42,16 +43,20 @@ def dl_tablebase(url):
urllib.request.urlretrieve(f"{url}{eg}", f"{eg}")


def main():
prompt_man5 = input("Download 5 man? (y/n)")
prompt_man6 = input("Download 6 man? (y/n)")
if prompt_man5 == "y":
def main(man5_flag, man6_flag):
if man5_flag:
print("Downloading 5 man...")
dl_tablebase(man5)
if prompt_man6 == "y":
if man6_flag:
print("Downloading 6 man...")
dl_tablebase(man6)


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(description='Downloads from Tablebase.')
parser.add_argument('--man5', action='store_true')
parser.add_argument('--man6', action='store_true')

args = parser.parse_args()

main(args.man5, args.man6)

0 comments on commit b923809

Please sign in to comment.