Use blackformatter across the project
mergennachin committed Apr 16, 2024
1 parent 34699a6 commit d7f2b28
Showing 12 changed files with 472 additions and 336 deletions.
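The hunks below are mechanical reformatting: the diff is the output of running black over the repository, with no behavioral changes intended. As a rough, non-authoritative illustration (not part of this commit), the snippet below feeds one of the old cli.py lines through black's programmatic entry point; it assumes the black package is installed, and black.format_str / black.Mode are its documented API for formatting a string in memory.

import black

# One of the pre-commit lines from cli.py: single quotes and a stray space
# before the closing bracket.
src = "disallowed_args = ['output_pte_path', 'output_dso_path' ]\n"

# black normalizes quotes and bracket spacing, and wraps lines at 88 characters
# by default, matching the formatting seen in the new versions of these files.
print(black.format_str(src, mode=black.Mode()))
# -> disallowed_args = ["output_pte_path", "output_dso_path"]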
118 changes: 35 additions & 83 deletions cli.py
@@ -15,22 +15,23 @@

strict = False


def check_args(args, command_name: str):
global strict

# chat and generate support the same options
if command_name in ["generate", "chat", "gui"]:
if command_name in ["generate", "chat", "gui"]:
# examples, can add more. Note that attributes convert dash to _
disallowed_args = ['output_pte_path', 'output_dso_path' ]
disallowed_args = ["output_pte_path", "output_dso_path"]
elif command_name == "export":
# examples, can add more. Note that attributes convert dash to _
disallowed_args = ['pte_path', 'dso_path' ]
disallowed_args = ["pte_path", "dso_path"]
elif command_name == "eval":
# TBD
disallowed_args = []
else:
raise RuntimeError(f"{command_name} is not a valid command")

for disallowed in disallowed_args:
if hasattr(args, disallowed):
text = f"command {command_name} does not support option {disallowed.replace('_', '-')}"
@@ -39,7 +40,7 @@ def check_args(args, command_name: str):
else:
print(f"Warning: {text}")


def cli_args():
import argparse

@@ -48,8 +49,8 @@ def cli_args():
parser.add_argument(
"--seed",
type=int,
default=1234, # set None for release
help="Initialize torch seed"
default=1234, # set None for release
help="Initialize torch seed",
)
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
@@ -78,55 +79,31 @@ def cli_args():
"--chat",
action="store_true",
help="Use torchat to for an interactive chat session.",
)
)
parser.add_argument(
"--gui",
action="store_true",
help="Use torchat to for an interactive gui-chat session.",
)
parser.add_argument(
"--num-samples",
type=int,
default=1,
help="Number of samples.")
parser.add_argument(
"--max-new-tokens",
type=int,
default=200,
help="Maximum number of new tokens."
)
parser.add_argument("--num-samples", type=int, default=1, help="Number of samples.")
parser.add_argument(
"--top-k",
type=int,
default=200,
help="Top-k for sampling.")
"--max-new-tokens", type=int, default=200, help="Maximum number of new tokens."
)
parser.add_argument("--top-k", type=int, default=200, help="Top-k for sampling.")
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="Temperature for sampling."
"--temperature", type=float, default=0.8, help="Temperature for sampling."
)
parser.add_argument(
"--compile",
action="store_true",
help="Whether to compile the model."
"--compile", action="store_true", help="Whether to compile the model."
)
parser.add_argument(
"--compile-prefill",
action="store_true",
help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
)
parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
parser.add_argument(
"--profile",
type=Path,
default=None,
help="Profile path."
)
parser.add_argument(
"--speculate-k",
type=int,
default=5,
help="Speculative execution depth."
"--speculate-k", type=int, default=5, help="Speculative execution depth."
)
parser.add_argument(
"--draft-checkpoint-path",
@@ -163,31 +140,18 @@ def cli_args():
type=Path,
default=None,
help="Model checkpoint path.",
)
parser.add_argument(
"--output-pte-path",
type=str,
default=None,
help="Filename"
)
parser.add_argument("--output-pte-path", type=str, default=None, help="Filename")
parser.add_argument("--output-dso-path", type=str, default=None, help="Filename")
parser.add_argument(
"--output-dso-path",
type=str,
default=None,
help="Filename"
)
parser.add_argument(
"--dso-path",
type=Path,
default=None,
help="Use the specified AOTI DSO model."
"--dso-path", type=Path, default=None, help="Use the specified AOTI DSO model."
)
parser.add_argument(
"--pte-path",
type=Path,
default=None,
help="Use the specified Executorch PTE model."
)
help="Use the specified Executorch PTE model.",
)
parser.add_argument(
"-d",
"--dtype",
@@ -196,48 +160,36 @@
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
"--quantize",
type=str,
default="{ }",
help="Quantization options."
"--quantize", type=str, default="{ }", help="Quantization options."
)
parser.add_argument(
"--device",
type=str,
default=default_device,
help="Device to use"
)
parser.add_argument(
"--params-table",
type=str,
default=None,
help="Device to use"
"--device", type=str, default=default_device, help="Device to use"
)
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
parser.add_argument(
'--tasks',
nargs='+',
"--tasks",
nargs="+",
type=str,
default=["hellaswag"],
help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2'
help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
)
parser.add_argument(
'--limit', type=int,
default=None,
help='number of samples to evaluate'
"--limit", type=int, default=None, help="number of samples to evaluate"
)
parser.add_argument(
'--max-seq-length',
"--max-seq-length",
type=int,
default=None,
help='maximum length sequence to evaluate')

help="maximum length sequence to evaluate",
)

args = parser.parse_args()

if (Path(args.quantize).is_file()):
if Path(args.quantize).is_file():
with open(args.quantize, "r") as f:
args.quantize = json.loads(f.read())

if args.seed:
torch.manual_seed(args.seed)
torch.manual_seed(args.seed)

return args
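For orientation only, and not part of the diff: a minimal sketch of driving the parser defined above, assuming cli.py is importable as a module named cli and that torch is installed (cli_args() seeds torch via --seed, which defaults to 1234). Every flag used here is declared in cli_args() above.

import sys
from cli import cli_args  # assumption: cli.py is on the import path

# Simulate a command line using flags defined in cli_args() above.
sys.argv = [
    "cli.py",
    "--prompt", "Hello, my name is",
    "--max-new-tokens", "64",
    "--temperature", "0.8",
    "--device", "cpu",
]

args = cli_args()  # also applies torch.manual_seed(args.seed)
print(args.prompt, args.max_new_tokens, args.device)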
57 changes: 36 additions & 21 deletions eval.py
@@ -25,25 +25,32 @@

try:
import lm_eval

lm_eval_available = True
except:
lm_eval_available = False

from build.builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
from build.builder import (
_initialize_model,
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)
from generate import encode_tokens, model_forward

if lm_eval_available:
try: # lm_eval version 0.4
try: # lm_eval version 0.4
from lm_eval.models.huggingface import HFLM as eval_wrapper
from lm_eval.tasks import get_task_dict
from lm_eval.evaluator import evaluate
except: #lm_eval version 0.3
except: # lm_eval version 0.3
from lm_eval import base
from lm_eval import tasks
from lm_eval import evaluator
eval_wrapper=base.BaseLM
get_task_dict=tasks.get_task_dict
evaluate=evaluator.evaluate

eval_wrapper = base.BaseLM
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate


def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
@@ -84,20 +91,22 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(

return seq, input_pos, max_seq_length


class GPTFastEvalWrapper(eval_wrapper):
"""
A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
"""

def __init__(
self,
model: Transformer,
tokenizer,
max_seq_length: Optional[int]=None,
max_seq_length: Optional[int] = None,
):
super().__init__()
self._model = model
self._tokenizer = tokenizer
self._device = torch.device('cuda')
self._device = torch.device("cuda")
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length

@property
@@ -121,8 +130,7 @@ def device(self):
return self._device

def tok_encode(self, string: str, **kwargs):
encoded = encode_tokens(self._tokenizer,
string, bos=True, device=self._device)
encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device)
# encoded is a pytorch tensor, but some internal logic in the
# eval harness expects it to be a list instead
# TODO: verify this for multi-batch as well
@@ -138,19 +146,20 @@ def _model_call(self, inps):
inps = inps.squeeze(0)

max_new_tokens = 1
seq, input_pos, max_seq_length = \
seq, input_pos, max_seq_length = (
setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
self._model,
inps,
max_new_tokens,
self.max_length,
)
)
x = seq.index_select(0, input_pos).view(1, -1)
logits = model_forward(self._model, x, input_pos)
return logits

def _model_generate(self, context, max_length, eos_token_id):
raise Exception('unimplemented')
raise Exception("unimplemented")


@torch.no_grad()
@@ -185,8 +194,8 @@ def eval(
except:
pass

if 'hendrycks_test' in tasks:
tasks.remove('hendrycks_test')
if "hendrycks_test" in tasks:
tasks.remove("hendrycks_test")
tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]
task_dict = get_task_dict(tasks)

@@ -212,7 +221,7 @@ def main(args) -> None:

builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)

checkpoint_path = args.checkpoint_path
checkpoint_dir = args.checkpoint_dir
params_path = args.params_path
@@ -223,12 +232,12 @@ def main(args) -> None:
pte_path = args.pte_path
quantize = args.quantize
device = args.device
model_dtype = args.dtype
model_dtype = args.dtype
tasks = args.tasks
limit = args.limit
max_seq_length = args.max_seq_length
use_tiktoken = args.tiktoken

print(f"Using device={device}")
set_precision(builder_args.precision)

@@ -240,9 +249,13 @@ def main(args) -> None:
)

if compile:
assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model"
assert not (
builder_args.dso_path or builder_args.pte_path
), "cannot compile exported model"
global model_forward
model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
model_forward = torch.compile(
model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
)
torch._inductor.config.coordinate_descent_tuning = True

t1 = time.time()
@@ -268,11 +281,13 @@ def main(args) -> None:
for task, res in result["results"].items():
print(f"{task}: {res}")

if __name__ == '__main__':

if __name__ == "__main__":

def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
cli()
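As a hedged illustration of how the wrapper defined in eval.py plugs into lm-eval (the 0.4-style API imported above), the sketch below mirrors what eval() itself does; model and tokenizer are assumed to have been built already, for example via _initialize_model and _initialize_tokenizer from build.builder.

import torch
from lm_eval.tasks import get_task_dict
from lm_eval.evaluator import evaluate
from eval import GPTFastEvalWrapper  # assumption: eval.py is importable

# `model` is a Transformer and `tokenizer` its matching tokenizer, built elsewhere.
# Note the wrapper above hard-codes torch.device("cuda"), so a GPU is assumed.
wrapper = GPTFastEvalWrapper(model, tokenizer, max_seq_length=2048)

with torch.no_grad():
    result = evaluate(wrapper, get_task_dict(["hellaswag"]), limit=100)

for task, res in result["results"].items():
    print(f"{task}: {res}")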