Add TogetherInferenceLLM (#215)
* Add `TogetherInferenceLLM` and `_TOGETHER_AVAILABLE_FLAG`

* Update docstrings of `TogetherInferenceLLM`

* Add `TogetherInferenceLLM` in `distilabel.llm` init

* Access `TogetherInferenceLLM` output via dict

* Fix bug affecting `TextGenerationTask` in `_to_argilla_record`

* Add `TogetherInferenceLLM` documentation

* Add `examples/pipeline-together-inference.py`

* Update `model` argument docstring

Co-authored-by: Agus <[email protected]>

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
Co-authored-by: Agus <[email protected]>
3 people authored Jan 5, 2024
1 parent 3160a09 commit 53ca00c
Showing 9 changed files with 378 additions and 29 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ In addition, the following extras are available:
- `openai`: for using OpenAI API models via the `OpenAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`.
- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client.
- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/).

## Example
1 change: 1 addition & 0 deletions docs/index.md
@@ -23,6 +23,7 @@ In addition, the following extras are available:
- `openai`: for using OpenAI API models via the `OpenAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`.
- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client.
- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/).

## Quick example
21 changes: 21 additions & 0 deletions docs/snippets/technical-reference/llm/together_inference_generate.py
@@ -0,0 +1,21 @@
from distilabel.tasks import TextGenerationTask
from distilabel.llm import TogetherInferenceLLM

llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-70b-chat",
task=TextGenerationTask(),
max_new_tokens=512,
temperature=0.3,
prompt_format="llama2",
)
output = llm.generate(
[{"input": "Explain me the theory of relativity as if you were a pirate."}]
)
# >>> print(output[0][0]["parsed_output"]["generations"])
# Ahoy matey! Yer lookin' fer a tale of the theory of relativity, eh? Well,
# settle yerself down with a pint o' grog and listen close, for this be a story
# of the sea of time and space!
# Ye see, matey, the theory of relativity be tellin' us that time and space ain't
# fixed things, like the deck o' a ship or the stars in the sky. Nay, they be like
# the ocean itself, always changin' and flowin' like the tides.
# Now, imagine ...
16 changes: 13 additions & 3 deletions docs/technical-reference/llms.md
@@ -9,7 +9,7 @@ In this section we will see what's an `LLM` and the different `LLM`s implementat

The [`LLM`][distilabel.llm.base.LLM] class encapsulates the functionality for interacting with a large language model.

It distinguishes between *task* specifications and configurable parameters that influence the LLM's behavior.
It distinguishes between *task* specifications and configurable parameters that influence the LLM behavior.

For illustration purposes, we employ the [`TextGenerationTask`][distilabel.tasks.text_generation.base.TextGenerationTask] in this section and guide you to the dedicated [`Tasks`](../technical-reference/tasks.md) section for comprehensive details.

@@ -28,7 +28,7 @@ Let's briefly introduce the general parameters we may find[^1]:

- `top_k` and `top_p`: `top_k` limits the number of highest-probability tokens the model may choose from when generating the next token, while `top_p` limits the candidate tokens for the next token in terms of the sum of their probabilities.

- `frequency_penalty` and `presence_penalty`: the frequency penalty penalizes tokens that have already appeard in the generated text, limiting the possibility of those appearing again, and the `presence_penalty` penalizes regardless of hte frequency.
- `frequency_penalty` and `presence_penalty`: the frequency penalty penalizes tokens that have already appeared in the generated text, limiting the possibility of those appearing again, and the `presence_penalty` penalizes regardless of the frequency.

- `prompt_format` and `prompt_formatting_fn`: these two parameters allow tweaking the prompt of our models; for example, `prompt_format` directs the `LLM` to format the prompt according to one of the predefined formats, while `prompt_formatting_fn` lets us pass a function that will be applied to the prompt before generation, for extra control over what we feed to the model.

@@ -160,6 +160,17 @@ Let's see how to interact with these LLMs:
--8<-- "docs/snippets/technical-reference/llm/inference_endpoint_generate.py"
```

### Together Inference

Together offers a product named Together Inference, which exposes a number of models for diverse tasks such as chat, text generation, code, or image generation; these are served through their API either as serverless endpoints or as dedicated instances.

See their release post for more details: [Announcing Together Inference Engine – the fastest inference available](https://www.together.ai/blog/together-inference-engine-v1).


```python
--8<-- "docs/snippets/technical-reference/llm/together_inference_generate.py"
```
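
For reference, below is a minimal sketch of the raw `together` client calls that `TogetherInferenceLLM` wraps, mirroring the usage in `src/distilabel/llm/together.py` further down in this diff (the concrete argument values are illustrative):

```python
import os

import together  # installed via the `together` extra, i.e. `pip install together`

together.api_key = os.getenv("TOGETHER_API_KEY")

# The models exposed by Together Inference (what `available_models` is built from).
model_names = [model["name"] for model in together.Models.list()]

# A single completion request with the same arguments `TogetherInferenceLLM` forwards.
output = together.Complete.create(
    prompt="Explain me the theory of relativity as if you were a pirate.",
    model="togethercomputer/llama-2-70b-chat",
    max_tokens=128,
    stop=None,
    temperature=0.3,
    top_k=1,
    top_p=1.0,
    repetition_penalty=1.0,
    logprobs=0,
)
for choice in output["output"]["choices"]:
    print(choice["text"].strip())
```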

## `ProcessLLM` and `LLMPool`

By default, `distilabel` uses a single process, so the generation loop is usually bottlenecked by the model inference time and the Python GIL. To overcome this limitation, we provide the `ProcessLLM` class, which loads an `LLM` in a different process, avoiding the GIL and making it possible to parallelize the generation loop. Creating a `ProcessLLM` is as easy as:
@@ -176,4 +187,3 @@ You can directly use a `ProcessLLM` as the `generator` or `labeller` in a `Pipel
```python
--8<-- "docs/snippets/technical-reference/llm/llmpool.py"
```
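
As a rough sketch, a `TogetherInferenceLLM` can be loaded in a separate process in the same way, assuming `ProcessLLM` is exported from `distilabel.llm` and takes a `task` plus a `load_llm_fn` callable (neither is shown in this diff; the helper below is illustrative, not part of this commit):

```python
import os

from distilabel.llm import ProcessLLM, TogetherInferenceLLM
from distilabel.llm.base import LLM
from distilabel.tasks import TextGenerationTask
from distilabel.tasks.base import Task


def load_together_llm(task: Task) -> LLM:
    # Instantiated inside the spawned process, so the wrapped LLM's generation
    # loop does not block the main process.
    return TogetherInferenceLLM(
        model="togethercomputer/llama-2-70b-chat",
        api_key=os.getenv("TOGETHER_API_KEY"),
        task=task,
        max_new_tokens=512,
        temperature=0.3,
        prompt_format="llama2",
    )


llm = ProcessLLM(task=TextGenerationTask(), load_llm_fn=load_together_llm)
```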

74 changes: 74 additions & 0 deletions examples/pipeline-together-inference.py
@@ -0,0 +1,74 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

from datasets import Dataset
from distilabel.llm import TogetherInferenceLLM
from distilabel.pipeline import Pipeline
from distilabel.tasks import TextGenerationTask

if __name__ == "__main__":
dataset = Dataset.from_dict(
{
"input": ["Explain me the theory of relativity as if you were a pirate."],
}
)

llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-70b-chat",
api_key=os.getenv("TOGETHER_API_KEY", None),
task=TextGenerationTask(),
prompt_format="llama2",
)
pipeline = Pipeline(generator=llm)

start = time.time()
dataset = pipeline.generate(
dataset=dataset,
shuffle_before_labelling=False,
num_generations=2,
skip_dry_run=True,
display_progress_bar=False,
) # type: ignore
end = time.time()
print("Elapsed", end - start)

# Push to the HuggingFace Hub
dataset.push_to_hub(
os.getenv("HF_REPO_ID"), # type: ignore
split="train",
private=True,
token=os.getenv("HF_TOKEN", None),
)

try:
from uuid import uuid4

import argilla as rg

rg.init(
api_url=os.getenv("ARGILLA_API_URL"),
api_key=os.getenv("ARGILLA_API_KEY"),
)

# Convert into an Argilla dataset and push it to Argilla
rg_dataset = dataset.to_argilla()
rg_dataset.push_to_argilla(
name=f"my-dataset-{uuid4()}",
workspace="admin",
)
except ImportError:
pass
2 changes: 2 additions & 0 deletions src/distilabel/llm/__init__.py
@@ -18,6 +18,7 @@
from distilabel.llm.huggingface.transformers import TransformersLLM
from distilabel.llm.llama_cpp import LlamaCppLLM
from distilabel.llm.openai import OpenAILLM
from distilabel.llm.together import TogetherInferenceLLM
from distilabel.llm.vllm import vLLM

__all__ = [
@@ -29,6 +30,7 @@
"InferenceEndpointsLLM",
"TransformersLLM",
"LlamaCppLLM",
"TogetherInferenceLLM",
"OpenAILLM",
"vLLM",
]
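
`src/distilabel/llm/together.py` below imports `_TOGETHER_AVAILABLE` from `distilabel.utils.imports`, whose diff is not expanded in this view. A minimal sketch of how such an availability flag is typically defined, stated purely as an assumption about that module:

```python
# Hypothetical sketch; the actual definition of `_TOGETHER_AVAILABLE` in
# `distilabel.utils.imports` is not shown in this diff.
import importlib.util

_TOGETHER_AVAILABLE: bool = importlib.util.find_spec("together") is not None
```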
215 changes: 215 additions & 0 deletions src/distilabel/llm/together.py
@@ -0,0 +1,215 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Union

from distilabel.llm.base import LLM
from distilabel.llm.utils import LLMOutput
from distilabel.logger import get_logger
from distilabel.utils.imports import _TOGETHER_AVAILABLE

if _TOGETHER_AVAILABLE:
import together

if TYPE_CHECKING:
from distilabel.tasks.base import Task
from distilabel.tasks.prompt import SupportedFormats


logger = get_logger()


class TogetherInferenceLLM(LLM):
def __init__(
self,
task: "Task",
model: str,
api_key: Union[str, None] = None,
max_new_tokens: int = 128,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 1,
stop: Union[List[str], None] = None,
logprobs: int = 0,
num_threads: Union[int, None] = None,
prompt_format: Union["SupportedFormats", None] = None,
prompt_formatting_fn: Union[Callable[..., str], None] = None,
) -> None:
"""Initializes the OpenAILLM class.
Args:
task (Task): the task to be performed by the LLM.
model (str): the model to be used for generation.
max_new_tokens (int, optional): the maximum number of tokens to be generated.
Defaults to 128.
temperature (float, optional): the temperature to be used for generation. From the Together
Inference docs: "A decimal number that determines the degree of randomness in the response.
A value of 0 will always yield the same output. A temperature much less than 1 favors more
correctness and is appropriate for question answering or summarization. A value approaching
1 introduces more randomness in the output.". Defaults to 1.0.
repetition_penalty (float, optional): the repetition penalty to be used for generation. From the
Together Inference docs: "Controls the diversity of generated text by reducing the likelihood
of repeated sequences. Higher values decrease repetition.". Defaults to 1.0.
top_p (float, optional): the top-p value to be used for generation. From the Together
Inference docs: "used to dynamically adjust the number of choices for each predicted
token based on the cumulative probabilities. It specifies a probability threshold,
below which all less likely tokens are filtered out. This technique helps to maintain
diversity and generate more fluent and natural-sounding text.". Defaults to 1.0.
top_k (int, optional): the top-k value to be used for generation. From the Together Inference
docs: "used to limit the number of choices for the next predicted word or token. It specifies
the maximum number of tokens to consider at each step, based on their probability of occurrence.
This technique helps to speed up the generation process and can improve the quality of the
generated text by focusing on the most likely options.". Defaults to 1.
stop (List[str], optional): strings used to delimit the generation process, so that when the
model generates any of the provided strings, the generation process is considered completed.
Defaults to None.
logprobs (int, optional): the number of logprobs to be returned for each token. From the
Together Inference docs: "An integer that specifies how many top token log probabilities
are included in the response for each token generation step.". Defaults to 0.
num_threads (Union[int, None], optional): the number of threads to be used
for parallel generation. If `None`, no parallel generation will be performed.
Defaults to `None`.
prompt_format (Union[SupportedFormats, None], optional): the format to be used
for the prompt. If `None`, the default format of the task will be used, available
formats are `openai`, `chatml`, `llama2`, `zephyr`, and `default`. Defaults to `None`,
but `default` (concatenation of `system_prompt` and `formatted_prompt` with a line-break)
will be used if no `prompt_formatting_fn` is provided.
prompt_formatting_fn (Union[Callable[..., str], None], optional): a function to be
applied to the prompt before generation. If `None`, no formatting will be applied.
Defaults to `None`.
Raises:
AssertionError: if the provided `model` is not available in Together Inference.
Examples:
>>> from distilabel.tasks.text_generation import TextGenerationTask as Task
>>> from distilabel.llm import TogetherInferenceLLM
>>> task = Task()
>>> llm = TogetherInferenceLLM(model="togethercomputer/llama-2-7b", task=task, prompt_format="llama2")
"""
if not _TOGETHER_AVAILABLE:
raise ImportError(
"`TogetherInferenceLLM` cannot be used as `together` is not installed, please "
" install it with `pip install together`."
)

together.api_key = api_key or os.getenv("TOGETHER_API_KEY", None)
if together.api_key is None:
raise ValueError(
"No `api_key` provided, please provide one or set the `TOGETHER_API_KEY` "
"environment variable."
)

super().__init__(
task=task,
num_threads=num_threads,
prompt_format=prompt_format,
prompt_formatting_fn=prompt_formatting_fn,
)

assert (
model in self.available_models
), f"Provided `model` is not available in Together Inference, available models are {self.available_models}"
self.model = model

self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.stop = stop
self.logprobs = logprobs

def __rich_repr__(self) -> Generator[Any, None, None]:
yield from super().__rich_repr__()
yield (
"parameters",
{
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"top_p": self.top_p,
"top_k": self.top_k,
"stop": self.stop,
"logprobs": self.logprobs,
},
)

@cached_property
def available_models(self) -> List[str]:
"""Returns the list of available models in Together Inference."""
# TODO: exclude the image models
return [model["name"] for model in together.Models.list()]

@property
def model_name(self) -> str:
"""Returns the name of the Together Inference model."""
return self.model

def _generate(
self,
inputs: List[Dict[str, Any]],
num_generations: int = 1,
) -> List[List[LLMOutput]]:
"""Generates `num_generations` for each input in `inputs`.
Args:
inputs (List[Dict[str, Any]]): the inputs to be used for generation.
num_generations (int, optional): the number of generations to be performed for each
input. Defaults to 1.
Returns:
List[List[LLMOutput]]: the generated outputs.
"""
prompts = self._generate_prompts(inputs, default_format=None)
outputs = []
for prompt in prompts:
batch = []
for _ in range(num_generations):
output = together.Complete.create(
prompt=prompt,
model=self.model,
max_tokens=self.max_new_tokens,
stop=self.stop,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
logprobs=self.logprobs,
)
if output["output"]["choices"] is not None:
for choice in output["output"]["choices"]:
try:
parsed_response = self.task.parse_output(
choice["text"].strip()
)
except Exception as e:
logger.error(
f"Error parsing Together Inference response: {e}"
)
parsed_response = None
batch.append(
LLMOutput(
model_name=self.model_name,
prompt_used=prompt,
raw_output=choice["text"],
parsed_output=parsed_response,
)
)
if len(batch) > 0:
outputs.append(batch)
return outputs
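
For completeness, a short sketch of consuming the generated outputs, using the `LLMOutput` fields built in `_generate` above and assuming the public `generate` method accepts `num_generations` just like `_generate` does (prompt and model taken from the examples in this commit):

```python
import os

from distilabel.llm import TogetherInferenceLLM
from distilabel.tasks import TextGenerationTask

llm = TogetherInferenceLLM(
    model="togethercomputer/llama-2-70b-chat",
    api_key=os.getenv("TOGETHER_API_KEY"),
    task=TextGenerationTask(),
    max_new_tokens=256,
    prompt_format="llama2",
)

# One inner list per input, with one `LLMOutput` per generation.
outputs = llm.generate(
    [{"input": "Explain me the theory of relativity as if you were a pirate."}],
    num_generations=2,
)
for generations in outputs:
    for generation in generations:
        # `LLMOutput` entries are accessed as a dict, as noted in the commit message.
        print(generation["model_name"])
        print(generation["raw_output"])
        if generation["parsed_output"] is not None:
            print(generation["parsed_output"]["generations"])
```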