diff --git a/src/distilabel/llm/base.py b/src/distilabel/llm/base.py index 1306984e2b..0bd9189bd5 100644 --- a/src/distilabel/llm/base.py +++ b/src/distilabel/llm/base.py @@ -247,8 +247,13 @@ class _TextGenerationResult: """An object used to transfer the text generation results from the `_GenerationProcess` to the `_BridgeThread`.""" - def __init__(self, generations) -> None: + def __init__( + self, + generations: Union[List[List["LLMOutput"]], None] = None, + exception: Union[Exception, None] = None, + ) -> None: self.generations = generations + self.exception = exception class _GenerationProcess(mp.Process): @@ -314,13 +319,26 @@ def run(self) -> None: # Perform generation logger.debug(f"Process with '{name}' received request...") - results = llm.generate( - inputs=request.inputs, num_generations=request.num_generations - ) + try: + generations = llm.generate( + inputs=request.inputs, num_generations=request.num_generations + ) + except Exception as e: + logger.error( + f"Process with '{name}' failed to perform generation with error: {e}" + ) + generations = e - generations = results.result() if isinstance(results, Future) else results + if isinstance(generations, Exception): + text_generation_result = _TextGenerationResult(exception=generations) + elif isinstance(generations, Future): + text_generation_result = _TextGenerationResult( + generations=generations.result() + ) + else: + text_generation_result = _TextGenerationResult(generations=generations) - self._result_queue.put(_TextGenerationResult(generations)) + self._result_queue.put(text_generation_result) def stop(self) -> None: """Stops the infinite loop of the generation process.""" @@ -361,7 +379,7 @@ def __init__(self, process_llm: "ProcessLLM") -> None: self._model_name = process_llm._model_name - super().__init__() + super().__init__(daemon=True) def _wait_llm_loaded(self) -> None: """Waits for the generation process to load the `LLM`.""" @@ -407,10 +425,19 @@ def _process_request(self) -> bool: 
self._call_generation_process(tg_request) # Get the text generation result from the child process - generation_result = self._get_result_generation_process() + logger.debug( + f"Bridge thread waiting for generation result with request id {text_generation_request_id}..." + ) + generation_result = self._result_queue.get() + if generation_result == -1: + return True - # Set the result of the text generation request - tg_request.future.set_result(generation_result.generations) + if generation_result.exception is not None: + # Set the exception of the text generation request + tg_request.future.set_exception(generation_result.exception) + else: + # Set the result of the text generation request + tg_request.future.set_result(generation_result.generations) return False @@ -426,9 +453,15 @@ def run(self) -> None: if should_stop: break + logger.debug("Bridge thread stopped!") + def stop(self) -> None: """Stops the infinite loop of the bridge thread.""" self._text_generation_request_ids_queue.put(-1) + # This is for making sure that if the bridge thread has sent a request to the + # generation process, and the generation process is stopped before sending the + # result, the bridge thread will not get blocked waiting for the result. + self._result_queue.put(-1) class ProcessLLM: @@ -488,8 +521,12 @@ def _start_bridge_thread(self) -> None: if self._bridge_thread is None: self._generation_process = _GenerationProcess(self) self._generation_process.start() + pid = self._generation_process.pid + logger.debug(f"Generation process with PID {pid} started!") + self._bridge_thread = _BridgeThread(self) self._bridge_thread.start() + logger.debug(f"Bridge thread for process with PID {pid} started!") def _add_text_generation_request( self, @@ -577,6 +614,7 @@ class LLMPool: - If `num_generations` is less than the number of `LLM`s, then `num_generations` LLMs will be chosen randomly and each of them will perform 1 generation. 
+ - If `num_generations` is equal to the number of `LLM`s, then each `LLM` will perform 1 generation. diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 4309067807..308cbe3331 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -209,28 +209,7 @@ def _get_batch_generations( ) batch_generations = [] if isinstance(outputs, Future): - try: - # Result of future is `List[List[LLMOutput]]` (first list contains `batch_size` - # elements, and the second list contains `num_generations` elements) - batch_generations.extend(outputs.result()) - except Exception as e: - logger.error( - f"An error occured when getting the result from the generator: {e}" - ) - batch_generations.extend( - [ - [ - LLMOutput( - model_name=self.generator.model_name, - prompt_used=None, - raw_output=None, - parsed_output=None, - ) - for _ in range(num_generations) - ] - for _ in range(num_batches) - ] - ) + batch_generations.extend(outputs.result()) else: batch_generations = outputs return self._process_batch_generations(