diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 3538bccd37..f5e8cc7b3a 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -33,6 +33,7 @@ size_categories_parser, ) from distilabel.utils.files import list_files_in_dir +from distilabel.utils.huggingface import get_hf_token DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs" PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml" @@ -81,7 +82,13 @@ def push_to_hub( Whether to generate a dataset card or not. Defaults to True. **kwargs: Additional keyword arguments to pass to the `push_to_hub` method of the `datasets.Dataset` object. + + Raises: + ValueError: If no token is provided and couldn't be retrieved automatically. """ + if token is None: + token = get_hf_token(self.__class__.__name__, "token") + for name, dataset in self.items(): dataset.push_to_hub( repo_id=repo_id, diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 015c022b1b..73b13ee4d0 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -16,7 +16,6 @@ import os import random import warnings -from pathlib import Path from typing import TYPE_CHECKING, Any, List, Optional, Union from pydantic import ( @@ -33,6 +32,10 @@ from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, Grammar, StandardInput +from distilabel.utils.huggingface import ( + _INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME, + get_hf_token, +) from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -41,9 +44,6 @@ from transformers import PreTrainedTokenizer -_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME = "HF_TOKEN" - - class InferenceEndpointsLLM(AsyncLLM): """InferenceEndpoints LLM implementation running the async API client. 
@@ -207,7 +207,6 @@ def load(self) -> None: # noqa: C901 from huggingface_hub import ( AsyncInferenceClient, InferenceClient, - constants, get_inference_endpoint, ) except ImportError as ie: @@ -217,14 +216,7 @@ def load(self) -> None: # noqa: C901 ) from ie if self.api_key is None: - if not Path(constants.HF_TOKEN_PATH).exists(): - raise ValueError( - f"To use `{self.__class__.__name__}` an API key must be provided via" - " `api_key` attribute or runtime parameter, set the environment variable" - f" `{self._api_key_env_var}` or use the `huggingface-hub` CLI to login" - " with `huggingface-cli login`." - ) - self.api_key = SecretStr(open(constants.HF_TOKEN_PATH).read().strip()) + self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) if self.model_id is not None: client = InferenceClient() diff --git a/src/distilabel/utils/huggingface.py b/src/distilabel/utils/huggingface.py new file mode 100644 index 0000000000..7a637a831c --- /dev/null +++ b/src/distilabel/utils/huggingface.py @@ -0,0 +1,53 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from typing import Final + +from huggingface_hub import constants + +_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME: Final[str] = "HF_TOKEN" + + +def get_hf_token(cls_name: str, token_arg: str) -> str: + """Get the token for the hugging face API. 
+ + Tries to extract it from the environment variable; if it is not found, + it tries to read it from the file using 'huggingface_hub', + and if not possible raises a ValueError. + + Args: + cls_name: Name of the class/function that requires the token. + token_arg: Argument name to use in the error message, normally + is "token" or "api_key". + + Raises: + ValueError: If the token couldn't be found in the environment variable + or in the 'huggingface_hub' token file. + + Returns: + The token for the Hugging Face API. + """ + token = os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) + if token is None: + if not Path(constants.HF_TOKEN_PATH).exists(): + raise ValueError( + f"To use `{cls_name}` an API key must be provided via" + f" `{token_arg}`, set the environment variable" + f" `{_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME}` or use the `huggingface-hub` CLI to login" + " with `huggingface-cli login`." + ) + with open(constants.HF_TOKEN_PATH) as f: + token = f.read().strip() + return token