diff --git a/sharktank/sharktank/examples/validate_llama_ref_model.py b/sharktank/sharktank/examples/validate_llama_ref_model.py index c8b061534..8b6a8d40c 100644 --- a/sharktank/sharktank/examples/validate_llama_ref_model.py +++ b/sharktank/sharktank/examples/validate_llama_ref_model.py @@ -15,6 +15,7 @@ def main(args: list[str]): from ..utils import cli + torch.no_grad().__enter__() parser = cli.create_parser()