
Add StructuredGeneration task and support for grammar in InferenceEndpointsLLM #680

Merged
merged 13 commits into develop from inference-endpoints-structured-gen on Jun 3, 2024

Conversation


@alvarobartt alvarobartt commented May 29, 2024

Description

This PR adds the StructuredGeneration task, similar to the existing TextGeneration task, but it also expects a grammar as input and produces both the chat-like input and the grammar within format_input. To achieve that, the typing of Task.process and LLM.generate has been updated so that the received input(s) contain either only the chat (for most cases) or both the chat and the grammar (for the StructuredGeneration case).

Note

This is still a work in progress and subject to change, but at the moment this seems the most straightforward and intuitive way to do it.
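For illustration, here is a minimal sketch of the updated input typing. The aliases ChatType and Grammar are illustrative and not necessarily the names used in the codebase:

from typing import Any, Dict, List, Optional, Tuple

# Illustrative aliases; the actual type names in distilabel may differ.
ChatType = List[Dict[str, str]]  # OpenAI-style chat messages
Grammar = Dict[str, Any]         # e.g. {"type": "json", "value": {...}}

# Most tasks keep formatting their input as a plain chat...
standard_input: ChatType = [
    {"role": "user", "content": "Generate a character from an RPG game."}
]

# ...while `StructuredGeneration.format_input` also carries the row's grammar,
# so `LLM.generate` receives a (chat, grammar) tuple instead.
structured_input: Tuple[ChatType, Optional[Grammar]] = (
    [{"role": "user", "content": "Generate a character from an RPG game."}],
    {"type": "regex", "value": "(\\d{1,2})°C"},
)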

Additionally, this PR adds the grammar arg to InferenceEndpointsLLM so that it can be provided via the init or as a runtime parameter.

Note

The main difference between using the grammar arg and using the StructuredGeneration task is that the grammar arg is intended to be used with any task, whenever we want every output to match a certain format; e.g. in UltraFeedback we may want the output to match a certain regex to avoid output-parsing issues. The StructuredGeneration task, on the other hand, is intended for cases where each row has a different grammar and we want the output for a given instruction to follow that row's grammar; e.g. a function-calling scenario where each generation for a given instruction must match a certain function schema.

Examples

grammar at the LLM level (same grammar for every generation)

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
from pydantic import BaseModel


class Character(BaseModel):
    name: str
    description: str
    role: str
    weapon: str


with Pipeline(name="inference-endpoints-structured-generation") as pipeline:
    load_data = LoadDataFromDicts(
        name="load_data",
        data=[{"instruction": "Generate a character from an RPG game."}],
    )

    text_generation_cohere = TextGeneration(
        name="text_generation_cohere",
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
            api_key="***",  # type: ignore
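            # the same grammar (the JSON schema of `Character`) is applied to every generation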
            grammar={
                "type": "json",
                "value": Character.model_json_schema(),
            },
        ),
        use_system_prompt=False,
        input_batch_size=10,
        output_mappings={"model_name": "generation_model"},
    )

    load_data >> text_generation_cohere  # type: ignore


if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={  # type: ignore
            text_generation_cohere.name: {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 4096,
                        "stop_sequences": ["<EOS_TOKEN>", "<|END_OF_TURN_TOKEN|>"],
                    }
                }
            },
        },
    )
    if distiset is not None:
        distiset.push_to_hub(
            "distilabel-internal-testing/inference-endpoints-structured-generation",
            token="***",
        )

grammar via StructuredGeneration (one grammar per row)

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from pydantic import BaseModel


class Character(BaseModel):
    name: str
    description: str
    role: str
    weapon: str


class Animal(BaseModel):
    name: str
    species: str
    habitat: str
    diet: str


with Pipeline(name="inference-endpoints-structured-generation") as pipeline:
    load_data = LoadDataFromDicts(
        name="load_data",
        data=[
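            # each row carries its own grammar: JSON schemas for the first two rows, a regex for the third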
            {
                "instruction": "Generate a character from a RPG game.",
                "grammar": {
                    "type": "json",
                    "value": Character.model_json_schema(),
                },
            },
            {
                "instruction": "Generate an animal from a zoo.",
                "grammar": {
                    "type": "json",
                    "value": Animal.model_json_schema(),
                },
            },
            {
                "instruction": "What's the weather like today in Seattle in Celsius degrees?",
                "grammar": {
                    "type": "regex",
                    "value": "(\\d{1,2})°C",
                },
            },
        ],
    )

    task = StructuredGeneration(
        name="task",
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
            api_key="***",  # type: ignore
        ),
        use_system_prompt=False,
        output_mappings={"model_name": "generation_model"},
    )

    load_data >> task  # type: ignore


if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={  # type: ignore
            task.name: {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 4096,
                        "stop_sequences": ["<EOS_TOKEN>", "<|END_OF_TURN_TOKEN|>"],
                    }
                }
            },
        },
    )
    if distiset is not None:
        distiset.push_to_hub(
            "distilabel-internal-testing/inference-endpoints-structured-generation-multiple",
            token="***",
        )

- Now the `generate` method in the `LLM` can receive either a chat or a tuple with the chat and the grammar for that chat
- `grammar` is an arg at the `LLM` level, settable via the init or as a runtime parameter (see the sketch after this list)
- The `grammar` can be specified per row via the `StructuredGeneration` task; to apply a single global `grammar`, use the `grammar` arg of the `LLM` together with the `TextGeneration` task instead
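Since `grammar` is exposed as a runtime parameter, it should also be possible to provide it through `pipeline.run` instead of at init time. A hedged sketch, reusing the names from the first example above; this mirrors how `generation_kwargs` is passed there, and the exact parameter layout is an assumption:

if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            text_generation_cohere.name: {
                "llm": {
                    # assumption: `grammar` can be overridden here like any
                    # other runtime parameter of the LLM
                    "grammar": {
                        "type": "json",
                        "value": Character.model_json_schema(),
                    },
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 4096,
                    },
                }
            },
        },
    )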
@alvarobartt alvarobartt added this to the 1.2.0 milestone May 29, 2024
@alvarobartt alvarobartt self-assigned this May 29, 2024
@alvarobartt alvarobartt linked an issue May 29, 2024 that may be closed by this pull request
codspeed-hq bot commented May 31, 2024

CodSpeed Performance Report

Merging #680 will not alter performance

Comparing inference-endpoints-structured-gen (e7399d1) with develop (1624b1e)

Summary

✅ 1 untouched benchmarks

@alvarobartt alvarobartt marked this pull request as ready for review June 3, 2024 08:32
@plaguss plaguss merged commit 918c19f into develop Jun 3, 2024
7 checks passed
@plaguss plaguss deleted the inference-endpoints-structured-gen branch June 3, 2024 11:05
Successfully merging this pull request may close these issues.

[FEATURE] Add structured generation for InferenceEndpointsLLM