From e2301e9dba91fa962d673fdc3b3f0002856a3ba7 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 6 Oct 2024 19:20:24 -0700 Subject: [PATCH] Use `importlib.util.find_spec` to check if `lm_eval` is installed instead of trying to import it (#1023) Use importlib.util.find_spec to check if lm_eval is installed instead of trying to import it There is a circular dependency when trying to import lm_eval inside torchao. The chain is like this: torchao -> lm_eval -> transformers.pipelines -> torchao And results in the following error: RuntimeError: Failed to import transformers.pipelines because of the following error (look up to see its traceback): cannot import name 'quantize_' from partially initialized module 'torchao.quantization' which 1. causes _lm_eval_available to be erroneously set to False, even if lm_eval is available 2. interrupts lm_eval's initialization, leaving it partially initialized you can observe this with: >>> import torchao >>> import lm_eval.__main__ >>> import lm_eval.api.registry >> lm_eval.api.registry AttributeError: module 'lm_eval' has no attribute 'api' Having a bare except clause here was suppressing this circular import error, which from glancing around seems kind of like a general pattern in this code base. It might be worth reconsidering this pattern. --- torchao/quantization/utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0df6174d0..0beadfe5d 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -3,7 +3,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Tuple +import importlib.util +from typing import Dict, List, Optional import torch from torch.utils._python_dispatch import TorchDispatchMode @@ -40,12 +41,7 @@ "recommended_inductor_config_setter" ] -try: - import lm_eval # pyre-ignore[21] # noqa: F401 - - _lm_eval_available = True -except: - _lm_eval_available = False +_lm_eval_available = importlib.util.find_spec("lm_eval") is not None # basic SQNR def compute_error(x, y):