Skip to content

Commit

Permalink
Filter early positions
Browse files Browse the repository at this point in the history
Bench: 9543419
  • Loading branch information
SzilBalazs committed Sep 1, 2023
1 parent 8e060dd commit c5d8215
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions train/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,53 @@ def combine_data(directory, output):
print(f"Directory {directory} does not exist.")
return

files = glob.glob(f'{directory}/*.plain')
files = glob.glob(f"{directory}/*.plain")

if not files:
print("No .plain files found in the given directory.")
return

stats = [0, 0, 0]
filtered_early = 0

with open(output, 'w') as outfile:
with open(output, "w") as outfile:
for file in files:
with open(file, 'r') as infile:
with open(file, "r") as infile:
print(f"Reading {file}...")
for line in infile:
outfile.write(line)
tokens = line.split(";")
wdl = int(tokens[4]) + 1
fen, ply, move, score, wdl, tmp = line.split(";")

if int(ply) <= 10:
filtered_early += 1
continue

wdl = int(wdl) + 1
stats[wdl] += 1
outfile.write(line)

total = sum(stats)
print(f"Black wins {stats[0] / total * 100}% - "
f"Draws {stats[1] / total * 100}% - "
f"Wins wins {stats[2] / total * 100}%")
print(f"Filtered early positions: {filtered_early}")
print(f"Data from {len(files)} files has been successfully combined into {output}")


def split_data(input_path, rate):
    """Ask the WhiteCore binary to split a combined data file.

    Starts ``./WhiteCore`` (resolved relative to the current working
    directory) and feeds it two commands on standard input: a
    ``split input <input_path> rate <rate>`` line followed by ``quit``.

    Args:
        input_path: Path of the data file, inserted verbatim into the
            ``split`` command.
        rate: Value inserted verbatim after ``rate`` in the ``split`` command.

    Raises:
        FileNotFoundError: If ``./WhiteCore`` does not exist in the current
            working directory.
    """
    # Function-scope import keeps this block self-contained.
    import subprocess

    # The commands are written directly to the engine's stdin instead of being
    # spliced into an `echo ... | ./WhiteCore` shell string, so characters such
    # as quotes, `$` or backticks in input_path are never interpreted by a shell.
    commands = f"split input {input_path} rate {rate}\nquit\n"

    # check=False: as before, the engine's exit status is not inspected.
    subprocess.run(["./WhiteCore"], input=commands, text=True, check=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Combine plain files and shuffle them')
parser.add_argument('directory', type=str, help='the directory to get plain files')
parser.add_argument('output', type=str, help='the output file')
parser = argparse.ArgumentParser(description="Combine plain files and shuffle them")
parser.add_argument("directory", type=str, help="the directory to get plain files")
parser.add_argument("split_rate", type=int, help="train:validation data ratio")
args = parser.parse_args()

combine_data(args.directory, args.output)
for file in ["data.plain", "train.plain", "validation.plain"]:
try:
os.remove(file)
except FileNotFoundError:
pass

combine_data(args.directory, "data.plain")
split_data("data.plain", args.split_rate)
Binary file modified weights/master.bin
Binary file not shown.

0 comments on commit c5d8215

Please sign in to comment.