diff --git a/oaib/Batch.py b/oaib/Batch.py
index 769c844..606317d 100644
--- a/oaib/Batch.py
+++ b/oaib/Batch.py
@@ -344,11 +344,9 @@ async def run(self, callback=None):
         # existing requests first.
         if not self.__stopped.is_set():
             self.log("FINISHING PROCESSING | 5 second timeout")
-            await wait(
-                [*self.__processing, *self.__callbacks, *self.__workers],
-                return_when=ALL_COMPLETED,
-                timeout=5
-            )
+            await gather(*self.__processing)
+            await gather(*self.__workers)
+            await gather(*self.__callbacks)
 
         await self.stop()
         self.log("RETURNING OUTPUT")
diff --git a/tests/test_large.py b/tests/test_large.py
new file mode 100644
index 0000000..2d18440
--- /dev/null
+++ b/tests/test_large.py
@@ -0,0 +1,22 @@
+import oaib
+
+
+async def test_long():
+    batch = oaib.Auto()
+
+    n = 20
+    m = 20
+    for i in range(n):
+        await batch.add(
+            "chat.completions.create",
+            model="gpt-4",
+            max_tokens=4000,
+            messages=[{"role": "user", "content": "say hello and goodbye " * m}]
+        )
+
+    chats = await batch.run()
+    assert len(
+        chats) == n, f"Chat batch should return {n} results, got {len(chats)}"
+
+    chat = chats.iloc[0].get("result")
+    assert chat['choices'], "Should get valid chat completions"