Skip to content

Commit

Permalink
Fix ProcessLLM deadlock when used in LLMPool
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Dec 20, 2023
1 parent ac492ee commit fec1dca
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 32 deletions.
59 changes: 49 additions & 10 deletions src/distilabel/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,13 @@ class _TextGenerationResult:
"""An object used to transfer the text generation results from the `_GenerationProcess`
to the `_BridgeThread`."""

def __init__(self, generations) -> None:
def __init__(
self,
generations: Union[List[List["LLMOutput"]], None] = None,
exception: Union[Exception, None] = None,
) -> None:
self.generations = generations
self.exception = exception


class _GenerationProcess(mp.Process):
Expand Down Expand Up @@ -314,13 +319,27 @@ def run(self) -> None:

# Perform generation
logger.debug(f"Process with '{name}' received request...")
results = llm.generate(
inputs=request.inputs, num_generations=request.num_generations
)
try:
raise Exception("testttt")
generations = llm.generate(
inputs=request.inputs, num_generations=request.num_generations
)
except Exception as e:
logger.error(
f"Process with '{name}' failed to perform generation with error: {e}"
)
generations = e

generations = results.result() if isinstance(results, Future) else results
if isinstance(generations, Exception):
text_generation_result = _TextGenerationResult(exception=generations)
elif isinstance(generations, Future):
text_generation_result = _TextGenerationResult(
generations=generations.result()
)
else:
text_generation_result = _TextGenerationResult(generations=generations)

self._result_queue.put(_TextGenerationResult(generations))
self._result_queue.put(text_generation_result)

def stop(self) -> None:
"""Stops the infinite loop of the generation process."""
Expand Down Expand Up @@ -361,7 +380,7 @@ def __init__(self, process_llm: "ProcessLLM") -> None:

self._model_name = process_llm._model_name

super().__init__()
super().__init__(daemon=True)

def _wait_llm_loaded(self) -> None:
"""Waits for the generation process to load the `LLM`."""
Expand Down Expand Up @@ -407,10 +426,19 @@ def _process_request(self) -> bool:
self._call_generation_process(tg_request)

# Get the text generation result from the child process
generation_result = self._get_result_generation_process()
logger.debug(
f"Bridge thread waiting for generation result with request id {text_generation_request_id}..."
)
generation_result = self._result_queue.get()
if generation_result == -1:
return True

# Set the result of the text generation request
tg_request.future.set_result(generation_result.generations)
if generation_result.exception is not None:
# Set the exception of the text generation request
tg_request.future.set_exception(generation_result.exception)
else:
# Set the result of the text generation request
tg_request.future.set_result(generation_result.generations)

return False

Expand All @@ -426,9 +454,15 @@ def run(self) -> None:
if should_stop:
break

logger.debug("Bridge thread stopped!")

def stop(self) -> None:
"""Stops the infinite loop of the bridge thread."""
self._text_generation_request_ids_queue.put(-1)
# This is for making sure that if the bridge thread has sent a request to the
# generation process, and the generation process is stopped before sending the
# result, the bridge thread will not get blocked waiting for the result.
self._result_queue.put(-1)


class ProcessLLM:
Expand Down Expand Up @@ -488,8 +522,12 @@ def _start_bridge_thread(self) -> None:
if self._bridge_thread is None:
self._generation_process = _GenerationProcess(self)
self._generation_process.start()
pid = self._generation_process.pid
logger.debug(f"Generation process with PID {pid} started!")

self._bridge_thread = _BridgeThread(self)
self._bridge_thread.start()
logger.debug("Bridge thread for process with PID {pid} started!")

def _add_text_generation_request(
self,
Expand Down Expand Up @@ -577,6 +615,7 @@ class LLMPool:
- If `num_generations` is less than the number of `LLM`s, then `num_generations` LLMs
will be chosen randomly and each of them will perform 1 generation.
- If `num_generations` is equal to the number of `LLM`s, then each `LLM` will perform
1 generation.
Expand Down
23 changes: 1 addition & 22 deletions src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,28 +209,7 @@ def _get_batch_generations(
)
batch_generations = []
if isinstance(outputs, Future):
try:
# Result of future is `List[List[LLMOutput]]` (first list contains `batch_size`
# elements, and the second list contains `num_generations` elements)
batch_generations.extend(outputs.result())
except Exception as e:
logger.error(
f"An error occured when getting the result from the generator: {e}"
)
batch_generations.extend(
[
[
LLMOutput(
model_name=self.generator.model_name,
prompt_used=None,
raw_output=None,
parsed_output=None,
)
for _ in range(num_generations)
]
for _ in range(num_batches)
]
)
batch_generations.extend(outputs.result())
else:
batch_generations = outputs
return self._process_batch_generations(
Expand Down

0 comments on commit fec1dca

Please sign in to comment.