
Implement "Improving Text Embeddings with LLMs" #683

Merged · 19 commits · Jun 12, 2024

Conversation

@alvarobartt (Member) commented on May 30, 2024

Description

This PR implements all the tasks mentioned in the paper Improving Text Embeddings with Large Language Models, so that one can reproduce the data generation process for training embedding models with sentence-transformers.

Closes #682

Example

Below is a complete example showing all the implemented tasks and how to connect them:

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks.improving_text_embeddings import (
    BitextRetrievalGenerator,
    EmbeddingTaskGenerator,
    GenerateLongTextMatchingData,
    GenerateShortTextMatchingData,
    GenerateTextClassificationData,
    GenerateTextRetrievalData,
    MonolingualTripletGenerator,
)

with Pipeline(name="improving-text-embeddings-with-llms") as pipeline:
    brainstorm_retrieval = EmbeddingTaskGenerator(
        category="text-retrieval",
        flatten_tasks=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        num_generations=1,
        group_generations=True,
        output_mappings={"model_name": "brainstorm_model"},
    )

    generate_retrieval = GenerateTextRetrievalData(
        add_raw_output=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "generation_model"},
    )

    brainstorm_retrieval >> generate_retrieval  # type: ignore

    brainstorm_classification = EmbeddingTaskGenerator(
        category="text-classification",
        flatten_tasks=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        num_generations=1,
        group_generations=True,
        output_mappings={"model_name": "brainstorm_model"},
    )

    generate_classification = GenerateTextClassificationData(
        add_raw_output=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "generation_model"},
    )

    brainstorm_classification >> generate_classification  # type: ignore

    brainstorm_matching_short = EmbeddingTaskGenerator(
        category="text-matching-short",
        flatten_tasks=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        num_generations=1,
        group_generations=True,
        output_mappings={"model_name": "brainstorm_model"},
    )

    generate_matching_short = GenerateShortTextMatchingData(
        add_raw_output=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "generation_model"},
    )

    brainstorm_matching_short >> generate_matching_short  # type: ignore

    brainstorm_matching_long = EmbeddingTaskGenerator(
        category="text-matching-long",
        flatten_tasks=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        num_generations=1,
        group_generations=True,
        output_mappings={"model_name": "brainstorm_model"},
    )

    generate_matching_long = GenerateLongTextMatchingData(
        add_raw_output=True,
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "generation_model"},
    )

    brainstorm_matching_long >> generate_matching_long  # type: ignore

    bitext_retrieval_generator = BitextRetrievalGenerator(
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "bitext_model"},
    )

    monolingual_triplet_generator = MonolingualTripletGenerator(
        llm=InferenceEndpointsLLM(
            model_id="CohereForAI/c4ai-command-r-plus",
            tokenizer_id="CohereForAI/c4ai-command-r-plus",
        ),
        output_mappings={"model_name": "monolingual_model"},
    )


if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            step_name: {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 4096,
                        "stop_sequences": ["<EOS_TOKEN>", "<|END_OF_TURN_TOKEN|>"],
                    }
                }
            }
            for step_name in pipeline.dag
        },
    )
    if distiset is not None:
        distiset.push_to_hub(
            "distilabel-internal-testing/alvarobartt-improving-text-embeddings-with-llms-full",
        )
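The example above relies on two patterns: connecting steps into a DAG with the `>>` operator, and building the per-step runtime `parameters` dict with a comprehension over the step names. As a minimal, self-contained sketch of both ideas (this is NOT distilabel's actual implementation — the `Step` class and its attributes here are hypothetical):

```python
# Hypothetical mini-model of distilabel's step chaining and run parameters.
# Only the two patterns from the example above are illustrated.

class Step:
    def __init__(self, name: str):
        self.name = name
        self.successors: list["Step"] = []

    def __rshift__(self, other: "Step") -> "Step":
        # "a >> b" registers b as a successor of a, mirroring how steps
        # are connected into a DAG inside a Pipeline context.
        self.successors.append(other)
        return other


brainstorm = Step("brainstorm_retrieval")
generate = Step("generate_retrieval")
brainstorm >> generate  # generate now runs downstream of brainstorm

# Per-step runtime parameters, built with a dict comprehension over all
# step names, analogous to the pipeline.run(...) call above.
all_steps = [brainstorm, generate]
parameters = {
    step.name: {"llm": {"generation_kwargs": {"temperature": 0.7}}}
    for step in all_steps
}
```

In distilabel itself, the iteration is over `pipeline.dag`, so every step in the pipeline receives the same `generation_kwargs` without listing the steps by hand.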

What's missing?

  • Add docstrings for the implemented tasks
  • Add unit tests for the implemented tasks
  • Improve structuring to avoid code duplication
  • Move the templates to separate files rather than having them as plain strings
  • Fix the naming (cc @gabrielmbmb @plaguss for help)
  • Run a couple more experiments using the structured_output arg within the InferenceEndpointsLLM as an example

@alvarobartt alvarobartt added this to the 1.2.0 milestone May 30, 2024
@alvarobartt alvarobartt self-assigned this May 30, 2024
codspeed-hq bot commented on May 31, 2024

CodSpeed Performance Report

Merging #683 will not alter performance

Comparing improving-text-embeddings-with-llms (3c97218) with develop (a0d7e93)

Summary

✅ 1 untouched benchmarks

@alvarobartt alvarobartt marked this pull request as ready for review June 4, 2024 08:47
@plaguss (Contributor) left a comment


@alvarobartt, I'm fine with the naming. Maybe I would prefer moving the prompts to jinja templates as we have with other cases, but looks good to me anyway!

@alvarobartt (Member, Author) replied:

> @alvarobartt, I'm fine with the naming. Maybe I would prefer moving the prompts to jinja templates as we have with other cases, but looks good to me anyway!

Yes, see the "What's missing?" section in the PR description for what remains beyond the naming 🙂

@alvarobartt alvarobartt changed the title [WIP] Implement "Improving Text Embeddings with LLMs" Implement "Improving Text Embeddings with LLMs" Jun 7, 2024
@alvarobartt alvarobartt linked an issue Jun 11, 2024 that may be closed by this pull request
@alvarobartt alvarobartt merged commit 0e8c752 into develop Jun 12, 2024
7 checks passed
@alvarobartt alvarobartt deleted the improving-text-embeddings-with-llms branch June 12, 2024 08:09
Projects: Done

Successfully merging this pull request may close these issues.

[FEATURE] Implement "Improving Text Embeddings with LLMs"
2 participants