Fix test_bench
xuzhao9 committed Jun 20, 2024
1 parent b2b4158 commit a6390be
Showing 1 changed file with 18 additions and 18 deletions.
test_bench.py (18 additions, 18 deletions)
@@ -42,11 +42,10 @@ def pytest_generate_tests(metafunc):
 
     if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
         paths = _list_model_paths()
-        model_names = [os.path.basename(path) for path in paths]
         metafunc.parametrize(
-            "model_name",
-            model_names,
-            ids=model_names,
+            "model_path",
+            paths,
+            ids=[os.path.basename(path) for path in paths],
             scope="class",
         )
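
For context, a minimal standalone sketch (not part of this commit) of the conftest-style parametrization the hunk above relies on: metafunc.parametrize binds each collected path to a model_path argument, while ids= keeps the readable basenames in the test IDs. The _list_model_paths body below is a hypothetical stand-in for the repository's real helper.

import os

def _list_model_paths():
    # Hypothetical stand-in: the real helper walks the benchmark's model directories.
    return ["/bench/models/alexnet", "/bench/models/resnet50"]

def pytest_generate_tests(metafunc):
    if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
        paths = _list_model_paths()
        metafunc.parametrize(
            "model_path",                                # each test method receives the full path
            paths,
            ids=[os.path.basename(p) for p in paths],    # test IDs stay human-readable
            scope="class",
        )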

@@ -62,13 +61,14 @@ def pytest_generate_tests(metafunc):
 )
 class TestBenchNetwork:
 
-    def test_train(self, model_name, device, compiler, benchmark):
+    def test_train(self, model_path, device, benchmark):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="train",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
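
As a hedged illustration of why passing the full model_path helps: a loader can resolve the model's metadata file relative to its own directory instead of reconstructing the path from a bare name. The file name metadata.yaml and the not_implemented layout below are assumptions for the sketch (requires PyYAML), not details taken from this commit; the repository's real get_metadata_from_yaml and skip_by_metadata live elsewhere.

import os
import yaml

def get_metadata_from_yaml(model_path):
    # Assumption: each model directory ships a metadata.yaml next to its code.
    metadata_file = os.path.join(model_path, "metadata.yaml")
    if not os.path.exists(metadata_file):
        return {}
    with open(metadata_file) as f:
        return yaml.safe_load(f) or {}

def skip_by_metadata(test, device, extra_args, metadata):
    # Assumption: metadata may list device/test combinations that are not implemented.
    # extra_args is kept only to mirror the call site in the diff above.
    for entry in metadata.get("not_implemented", []):
        if entry.get("device", device) == device and entry.get("test", test) == test:
            return True
    return False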
@@ -91,13 +91,14 @@ def test_train(self, model_name, device, compiler, benchmark):
         except NotImplementedError:
             print(f"Test train on {device} is not implemented, skipping...")
 
-    def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
+    def test_eval(self, model_path, device, benchmark, pytestconfig):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="eval",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -110,16 +111,15 @@ def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
 
             task.make_model_instance(test="eval", device=device)
 
-            with task.no_grad(disable_nograd=pytestconfig.getoption("disable_nograd")):
-                benchmark(task.invoke)
-                benchmark.extra_info["machine_state"] = get_machine_state()
-                benchmark.extra_info["batch_size"] = task.get_model_attribute(
-                    "batch_size"
-                )
-                benchmark.extra_info["precision"] = task.get_model_attribute(
-                    "dargs", "precision"
-                )
-                benchmark.extra_info["test"] = "eval"
+            benchmark(task.invoke)
+            benchmark.extra_info["machine_state"] = get_machine_state()
+            benchmark.extra_info["batch_size"] = task.get_model_attribute(
+                "batch_size"
+            )
+            benchmark.extra_info["precision"] = task.get_model_attribute(
+                "dargs", "precision"
+            )
+            benchmark.extra_info["test"] = "eval"
 
         except NotImplementedError:
             print(f"Test eval on {device} is not implemented, skipping...")
