-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_information.py
60 lines (49 loc) · 2.14 KB
/
compute_information.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from synqgen.ni import HFNIEstimator
import click
import json
import torch
# Lookup table from the floating-point names accepted by ``--dtype`` to the
# torch dtype objects of the same name.
MAP = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("float32", "float16", "bfloat16")
}
@click.command()
@click.argument("collection_file")
@click.argument("model_checkpoint")
@click.option("--context-percentage", type=float, default=0)
@click.option("--context-tokens", type=int, default=0)
@click.option("--dtype", type=str, default="float32")
def main(collection_file,
         model_checkpoint,
         context_percentage,
         context_tokens,
         dtype):
    """Run ``HFNIEstimator`` over a JSONL collection and write results as JSONL.

    Each object yielded by ``estimator.information_from_generator`` is dumped
    as one JSON line into
    ``results/<model>_<dataset>_P<context_percentage>_<dtype>.jsonl``.

    Args:
        collection_file: Path to the input file; every non-blank line must be
            a JSON document.
        model_checkpoint: Checkpoint identifier passed to ``HFNIEstimator``.
        context_percentage: Forwarded unchanged to the estimator and embedded
            in the output file name.
        context_tokens: Forwarded unchanged to the estimator. It is NOT part
            of the output file name, so runs that differ only in this value
            overwrite each other.
        dtype: ``"int8"`` or ``"int4"`` to request 8-bit / 4-bit loading, or
            one of the keys of ``MAP`` to select a torch dtype.

    Raises:
        click.BadParameter: If ``dtype`` is not a supported value.
    """
    # Local import keeps this block self-contained; the file-level imports
    # do not include ``os``.
    import os

    if dtype == "int8":
        _dtype_option = {"load_in_8bit": True}
    elif dtype == "int4":
        _dtype_option = {"load_in_4bit": True}
    elif dtype in MAP:
        _dtype_option = {"torch_dtype": MAP[dtype]}
    else:
        # Fail with a readable CLI error instead of a bare KeyError.
        supported = ", ".join(["int8", "int4", *MAP])
        raise click.BadParameter(
            f"unsupported dtype {dtype!r}; choose one of: {supported}",
            param_hint="--dtype",
        )

    def read_jsonl(file_path):
        """Return a zero-argument callable producing a fresh record generator."""
        def gen():
            with open(file_path, encoding="utf-8") as f:
                for line in f:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    if line.strip():
                        yield json.loads(line)
        return gen

    # Create the output directory up front so a missing ``results/`` does not
    # abort the run only after the model has been loaded.
    os.makedirs("results", exist_ok=True)

    estimator = HFNIEstimator(
        model_checkpoint,
        cache_dir="hf_cache",
        model_kwargs={
            # NOTE(review): presumably places the whole model on device 0 -
            # confirm against how HFNIEstimator consumes ``model_kwargs``.
            "device_map": 0,
            **_dtype_option,
        },
    )

    _model_name = model_checkpoint.replace("/", "_")
    # Drops the last 12 characters of the path before flattening separators.
    # NOTE(review): presumably strips a fixed file-name suffix - confirm the
    # expected naming of collection files.
    _dataset_name = collection_file[:-12].replace("/", "_")
    _out_name = f"{_model_name}_{_dataset_name}_P{context_percentage}_{dtype}.jsonl"

    with open(f"results/{_out_name}", "w", encoding="utf-8") as f:
        # The estimator receives the generator factory itself, not an
        # already-started generator.
        for out in estimator.information_from_generator(
                read_jsonl(collection_file),
                context_percentage=context_percentage,
                context_tokens=context_tokens,
                max_samples=10_000,
                max_documents=10_000):
            f.write(f"{json.dumps(out)}\n")
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()