Skip to content

Commit

Permalink
Use importlib.util.find_spec to check if lm_eval is installed ins…
Browse files Browse the repository at this point in the history
…tead of trying to import it (#1023)

Use importlib.util.find_spec to check if lm_eval is installed instead of trying to import it

There is a circular dependency when trying to import lm_eval inside torchao. The chain is like this:

torchao -> lm_eval -> transformers.pipelines -> torchao

And results in the following error:

RuntimeError: Failed to import transformers.pipelines because of the following error (look up to see its traceback):
cannot import name 'quantize_' from partially initialized module 'torchao.quantization'

which

1. causes _lm_eval_available to be erroneously set to False, even if lm_eval is available
2. interrupts lm_eval's initialization, leaving it partially initialized

you can observe this with:

>>> import torchao
>>> import lm_eval.__main__
>>> import lm_eval.api.registry
>> lm_eval.api.registry
AttributeError: module 'lm_eval' has no attribute 'api'

Having a bare except clause here was suppressing this circular import error, which from glancing around seems kind of like a general pattern in this code base. It might be worth reconsidering this pattern.
  • Loading branch information
ringohoffman authored Oct 7, 2024
1 parent c187f87 commit e2301e9
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import importlib.util
from typing import Dict, List, Optional

import torch
from torch.utils._python_dispatch import TorchDispatchMode
Expand Down Expand Up @@ -40,12 +41,7 @@
"recommended_inductor_config_setter"
]

try:
import lm_eval # pyre-ignore[21] # noqa: F401

_lm_eval_available = True
except:
_lm_eval_available = False
_lm_eval_available = importlib.util.find_spec("lm_eval") is not None

# basic SQNR
def compute_error(x, y):
Expand Down

0 comments on commit e2301e9

Please sign in to comment.