Skip to content

Commit

Permalink
Add python==3.12 (#615)
Browse files Browse the repository at this point in the history
* Add Python 3.12

* Add `install_dependencies.sh` script

* Update to `ruff==0.4.5`

* Apply format

* Update commands

* Update to `argilla >= 1.29.0`

* Update to setup tmate in 3.12

* Update `vllm` dependency

* Use `uv` to install dependencies

* Update dependencies

* Fix regex message for 3.12
  • Loading branch information
gabrielmbmb authored May 31, 2024
1 parent ac41e7f commit 42efe6d
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 36 deletions.
13 changes: 3 additions & 10 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false

steps:
Expand All @@ -48,17 +48,10 @@ jobs:

- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
python_version=$(python -c "import sys; print(sys.version_info[:2])")
pip install -e .[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,vllm]
if [ "${python_version}" != "(3, 8)" ]; then
pip install -e .[mistralai,instructor]
fi;
pip install git+https://github.com/argilla-io/LLM-Blender.git
run: ./scripts/install_dependencies.sh

- name: Setup tmate session
if: ${{ github.event_name == 'workflow_dispatch' && matrix.python-version == '3.11' && github.event.inputs.tmate_session == 'true' }}
if: ${{ github.event_name == 'workflow_dispatch' && matrix.python-version == '3.12' && github.event.inputs.tmate_session == 'true' }}
uses: mxschmitt/action-tmate@v3

- name: Lint
Expand Down
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ repos:
- --fuzzy-match-generates-todo

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.4
rev: v0.4.5
hooks:
- id: ruff
args:
- --fix
args: [--fix]
- id: ruff-format

ci:
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ sources = src/distilabel tests

.PHONY: format
format:
ruff --fix $(sources)
ruff check --fix $(sources)
ruff format $(sources)

.PHONY: lint
lint:
ruff $(sources)
ruff check $(sources)
ruff format --check $(sources)

.PHONY: unit-tests
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand Down Expand Up @@ -45,7 +46,7 @@ distilabel = "distilabel.cli.app:app"
"distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin"

[project.optional-dependencies]
dev = ["ruff == 0.2.2", "pre-commit >= 3.5.0"]
dev = ["ruff == 0.4.5", "pre-commit >= 3.5.0"]
docs = [
"mkdocs-material >= 9.5.0",
"mkdocstrings[python] >= 0.24.0",
Expand All @@ -61,7 +62,7 @@ tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"]

# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
argilla = ["argilla >= 1.23.0"]
argilla = ["argilla >= 1.29.0"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.19.0"]
Expand All @@ -74,7 +75,7 @@ ollama = ["ollama >= 0.1.7"]
openai = ["openai >= 1.0.0"]
outlines = ["outlines >= 0.0.40"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = ["vllm >= 0.2.1", "filelock >= 3.13.4"]
vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
Expand Down
12 changes: 12 additions & 0 deletions scripts/install_dependencies.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash

python_version=$(python -c "import sys; print(sys.version_info[:2])")

python -m pip install uv

uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai]"
if [ "${python_version}" != "(3, 8)" ]; then
uv pip install --system -e .[mistralai,instructor]
fi

uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git
6 changes: 2 additions & 4 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def load(self) -> None:

@property
@abstractmethod
def inputs(self) -> List[str]:
...
def inputs(self) -> List[str]: ...

@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput":
...
def process(self, *inputs: StepInput) -> "StepOutput": ...
9 changes: 3 additions & 6 deletions src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,15 @@ def _set_routing_batch_function(
routing_batch_function._step = self

@overload
def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction":
...
def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": ...

@overload
def __rshift__(
self, other: List["DownstreamConnectableSteps"]
) -> List["DownstreamConnectableSteps"]:
...
) -> List["DownstreamConnectableSteps"]: ...

@overload
def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable":
...
def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": ...

def __rshift__(
self,
Expand Down
9 changes: 3 additions & 6 deletions src/distilabel/steps/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,23 @@ def step(
inputs: Union[List[str], None] = None,
outputs: Union[List[str], None] = None,
step_type: Literal["normal"] = "normal",
) -> Callable[..., Type["Step"]]:
...
) -> Callable[..., Type["Step"]]: ...


@overload
def step(
inputs: Union[List[str], None] = None,
outputs: Union[List[str], None] = None,
step_type: Literal["global"] = "global",
) -> Callable[..., Type["GlobalStep"]]:
...
) -> Callable[..., Type["GlobalStep"]]: ...


@overload
def step(
inputs: None = None,
outputs: Union[List[str], None] = None,
step_type: Literal["generator"] = "generator",
) -> Callable[..., Type["GeneratorStep"]]:
...
) -> Callable[..., Type["GeneratorStep"]]: ...


def step(
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/steps/argilla/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import sys
from typing import TYPE_CHECKING, List

import pytest
Expand Down Expand Up @@ -83,7 +84,9 @@ def test_with_errors(self, caplog) -> None:

with pytest.raises(
TypeError,
match="Can't instantiate abstract class Argilla with abstract methods inputs, process",
match="Can't instantiate abstract class Argilla with abstract methods inputs, process"
if sys.version_info < (3, 12)
else "Can't instantiate abstract class Argilla without an implementation for abstract methods 'inputs', 'process'",
):
Argilla(name="step", pipeline=Pipeline(name="unit-test-pipeline")) # type: ignore

Expand Down
5 changes: 4 additions & 1 deletion tests/unit/steps/tasks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from dataclasses import field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -77,7 +78,9 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:

with pytest.raises(
TypeError,
match="Can't instantiate abstract class Task with abstract methods format_input, format_output",
match="Can't instantiate abstract class Task with abstract methods format_input, format_output"
if sys.version_info < (3, 12)
else "Can't instantiate abstract class Task without an implementation for abstract methods 'format_input', 'format_output'",
):
Task(name="task", llm=DummyLLM()) # type: ignore

Expand Down

0 comments on commit 42efe6d

Please sign in to comment.