Skip to content

Commit

Permalink
Merge branch 'develop' into cache-per-step
Browse files (browse the repository at this point in the history)
  • Loading branch information
gabrielmbmb committed Oct 7, 2024
2 parents 95f7618 + d5c0484 commit 333e346
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,7 +84,7 @@ llama-cpp = ["llama-cpp-python >= 0.2.0"]
mistralai = ["mistralai >= 1.0.0"]
ollama = ["ollama >= 0.1.7"]
openai = ["openai >= 1.0.0"]
outlines = ["outlines >= 0.0.40"]
outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
Expand All @@ -99,7 +99,7 @@ faiss-gpu = ["faiss-gpu >= 1.7.2"]
text-clustering = [
"umap-learn >= 0.5.6",
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3" # For the figure (even though it's optional)
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]

# minhash
Expand Down
3 changes: 1 addition & 2 deletions scripts/install_dependencies.sh
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,10 +9,9 @@ python -m pip install uv
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"

if [ "${python_version}" != "(3, 12)" ]; then
uv pip install --system -e .[ray]
uv pip install --system -e .[ray]
fi

./scripts/install_cpu_vllm.sh
uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git

uv pip install --system -e ".[dev,tests]"
13 changes: 11 additions & 2 deletions src/distilabel/steps/tasks/argilla_labeller.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -382,10 +382,13 @@ def format_input(
"""Format the input into a chat message.
Args:
input (Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]): The input to format.
input: The input to format.
Returns:
ChatType: The formatted chat message.
The formatted chat message.
Raises:
ValueError: If question or fields are not provided.
"""
input_keys = list(self.inputs.keys())
record = input[input_keys[0]]
Expand All @@ -394,6 +397,11 @@ def format_input(
examples = input.get(input_keys[3], self.example_records)
guidelines = input.get(input_keys[4], self.guidelines)

if question is None:
raise ValueError("Question must be provided.")
if fields is None or any(field is None for field in fields):
raise ValueError("Fields must be provided.")

record = record.to_dict() if not isinstance(record, dict) else record
question = question.serialize() if not isinstance(question, dict) else question
fields = [
Expand All @@ -416,6 +424,7 @@ def format_input(
if examples
else False
)

prompt = self._template.render(
fields=formatted_fields,
question=formatted_question,
Expand Down

0 comments on commit 333e346

Please sign in to comment.