Skip to content

Commit

Permalink
Gather HF_TOKEN internally when calling Distiset.push_to_hub if tok…
Browse files Browse the repository at this point in the history
…en is None. (#707)

* Add a way to automatically gather the HF_TOKEN when calling distiset.push_to_hub and move the constant value to the distilabel.utils module

* Update src/distilabel/distiset.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Refactor function to obtain huggingface token and move it to its own module

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Jun 11, 2024
1 parent 1d53ee8 commit a0d7e93
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 13 deletions.
7 changes: 7 additions & 0 deletions src/distilabel/distiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
size_categories_parser,
)
from distilabel.utils.files import list_files_in_dir
from distilabel.utils.huggingface import get_hf_token

DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs"
PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml"
Expand Down Expand Up @@ -81,7 +82,13 @@ def push_to_hub(
Whether to generate a dataset card or not. Defaults to True.
**kwargs:
Additional keyword arguments to pass to the `push_to_hub` method of the `datasets.Dataset` object.
Raises:
ValueError: If no token is provided and couldn't be retrieved automatically.
"""
if token is None:
token = get_hf_token(self.__class__.__name__, "token")

for name, dataset in self.items():
dataset.push_to_hub(
repo_id=repo_id,
Expand Down
18 changes: 5 additions & 13 deletions src/distilabel/llms/huggingface/inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import os
import random
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union

from pydantic import (
Expand All @@ -33,6 +32,10 @@
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import FormattedInput, Grammar, StandardInput
from distilabel.utils.huggingface import (
_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME,
get_hf_token,
)
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:
Expand All @@ -41,9 +44,6 @@
from transformers import PreTrainedTokenizer


_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME = "HF_TOKEN"


class InferenceEndpointsLLM(AsyncLLM):
"""InferenceEndpoints LLM implementation running the async API client.
Expand Down Expand Up @@ -207,7 +207,6 @@ def load(self) -> None: # noqa: C901
from huggingface_hub import (
AsyncInferenceClient,
InferenceClient,
constants,
get_inference_endpoint,
)
except ImportError as ie:
Expand All @@ -217,14 +216,7 @@ def load(self) -> None: # noqa: C901
) from ie

if self.api_key is None:
if not Path(constants.HF_TOKEN_PATH).exists():
raise ValueError(
f"To use `{self.__class__.__name__}` an API key must be provided via"
" `api_key` attribute or runtime parameter, set the environment variable"
f" `{self._api_key_env_var}` or use the `huggingface-hub` CLI to login"
" with `huggingface-cli login`."
)
self.api_key = SecretStr(open(constants.HF_TOKEN_PATH).read().strip())
self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key"))

if self.model_id is not None:
client = InferenceClient()
Expand Down
53 changes: 53 additions & 0 deletions src/distilabel/utils/huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Final

from huggingface_hub import constants

# Name of the environment variable checked first when resolving the token.
_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME: Final[str] = "HF_TOKEN"


def get_hf_token(cls_name: str, token_arg: str) -> str:
    """Get the token for the Hugging Face API.

    The token is looked up in the `HF_TOKEN` environment variable first. If the
    variable is unset or blank, the token file located at
    `huggingface_hub.constants.HF_TOKEN_PATH` is read instead. Surrounding
    whitespace is stripped from the value in both cases.

    Args:
        cls_name: Name of the class/function that requires the token. Only used
            to build the error message.
        token_arg: Argument name to use in the error message, normally
            "token" or "api_key".

    Raises:
        ValueError: If no non-empty token is found in either the environment
            variable or the token file.

    Returns:
        The token for the Hugging Face API.
    """
    # A variable that is set but empty (e.g. `HF_TOKEN=`) can never be a valid
    # token, so it is treated the same as an unset variable.
    token = os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME, "").strip()
    if token:
        return token

    # Read the file directly instead of checking for existence first: this
    # avoids a race between the check and the read. Any error other than a
    # missing file (permissions, path is a directory, ...) still propagates.
    try:
        token = Path(constants.HF_TOKEN_PATH).read_text(encoding="utf-8").strip()
    except FileNotFoundError:
        token = ""

    if not token:
        raise ValueError(
            f"To use `{cls_name}` an API key must be provided via"
            f" `{token_arg}`, set the environment variable"
            f" `{_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME}` or use the `huggingface-hub` CLI to login"
            " with `huggingface-cli login`."
        )
    return token

0 comments on commit a0d7e93

Please sign in to comment.