diff --git a/pyproject.toml b/pyproject.toml index 44404c683..adab8d4fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ llama-cpp = ["llama-cpp-python >= 0.2.0"] mistralai = ["mistralai >= 1.0.0"] ollama = ["ollama >= 0.1.7"] openai = ["openai >= 1.0.0"] -outlines = ["outlines >= 0.0.40"] +outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"] ray = ["ray[default] >= 2.31.0"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] vllm = [ @@ -99,7 +99,7 @@ faiss-gpu = ["faiss-gpu >= 1.7.2"] text-clustering = [ "umap-learn >= 0.5.6", "scikit-learn >= 1.4.1", - "matplotlib >= 3.8.3" # For the figure (even though it's optional) + "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] # minhash diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh index 767f6e6dd..0b2277f0f 100755 --- a/scripts/install_dependencies.sh +++ b/scripts/install_dependencies.sh @@ -9,10 +9,9 @@ python -m pip install uv uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]" if [ "${python_version}" != "(3, 12)" ]; then - uv pip install --system -e .[ray] + uv pip install --system -e .[ray] fi ./scripts/install_cpu_vllm.sh -uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git uv pip install --system -e ".[dev,tests]" diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py index dd4522813..d0874ed3d 100644 --- a/src/distilabel/steps/tasks/argilla_labeller.py +++ b/src/distilabel/steps/tasks/argilla_labeller.py @@ -382,10 +382,13 @@ def format_input( """Format the input into a chat message. Args: - input (Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]): The input to format. + input: The input to format. Returns: - ChatType: The formatted chat message. + The formatted chat message. + + Raises: + ValueError: If question or fields are not provided. """ input_keys = list(self.inputs.keys()) record = input[input_keys[0]] @@ -394,6 +397,11 @@ def format_input( examples = input.get(input_keys[3], self.example_records) guidelines = input.get(input_keys[4], self.guidelines) + if question is None: + raise ValueError("Question must be provided.") + if fields is None or any(field is None for field in fields): + raise ValueError("Fields must be provided.") + record = record.to_dict() if not isinstance(record, dict) else record question = question.serialize() if not isinstance(question, dict) else question fields = [ @@ -416,6 +424,7 @@ def format_input( if examples else False ) + prompt = self._template.render( fields=formatted_fields, question=formatted_question,