-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update argilla
integration to use argilla_sdk
v2
#705
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice start. Just a few high level comments.
Co-authored-by: Ben Burtenshaw <[email protected]>
For the moment it's being installed as `pip install git+https://github.com/argilla-io/argilla-python.git@main`
Edit: the issue was with the Argilla Server version as I was using 1.26.0 while 1.27.0 or higher was required 👍🏻
Install as from uuid import uuid4
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps import (
LoadDataFromDicts,
TextGenerationToArgilla,
)
from distilabel.steps.tasks import TextGeneration
if __name__ == "__main__":
with Pipeline(name="my-pipeline") as pipeline:
load_dataset = LoadDataFromDicts(
name="load_dataset",
data=[
{
"instruction": "Write a short story about a dragon that saves a princess from a tower.",
},
],
)
text_generation = TextGeneration(
name="text_generation",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
api_key="...", # type: ignore
),
num_generations=4,
group_generations=True,
)
text_generation_to_argilla = TextGenerationToArgilla(
name="text_generation_to_argilla",
api_url="...",
api_key="...", # type: ignore
dataset_name=f"text-generation-{uuid4()}",
dataset_workspace="admin",
)
( # type: ignore
load_dataset
>> text_generation
>> text_generation_to_argilla
)
pipeline.run(
parameters={
text_generation.name: { # type: ignore
"llm": {
"generation_kwargs": {
"max_new_tokens": 512,
"temperature": 0.7,
},
},
},
}
) The logs then look like: |
Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-705/ |
CodSpeed Performance ReportMerging #705 will not alter performanceComparing Summary
|
Description
This PR renames and updates
Argilla
toArgillaBase
, since now the client inargilla_sdk
(later to be renamed toargilla
only as per a recent discussion with @frascuchon) is namedArgilla
too. Besides that, the code has been updated to use the latest Python client instead not only forArgillaBase
but also for the subclassesTextGenerationToArgilla
andPreferenceToArgilla
.Warning
This change here implies that the
argilla
server version should be 1.27.0 or higher, otherwise theargilla_sdk
won't work.Closes argilla-io/argilla#4880