Skip to content

Commit

Permalink
Fix task leak (#1632)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Dec 1, 2023
1 parent 8d9871f commit 519d796
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
16 changes: 14 additions & 2 deletions nucliadb/nucliadb/common/maindb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,19 @@ async def count(self, match: str) -> int:

class Driver:
initialized = False
_abort_tasks: List[asyncio.Task] = []

async def initialize(self):
raise NotImplementedError()

async def finalize(self):
raise NotImplementedError()
while len(self._abort_tasks) > 0:
task = self._abort_tasks.pop()
if not task.done():
try:
await task
except Exception:
pass

async def begin(self) -> Transaction:
raise NotImplementedError()
Expand Down Expand Up @@ -94,4 +101,9 @@ async def transaction(
if error or wait_for_abort:
await txn.abort()
else:
asyncio.create_task(txn.abort())
self._async_abort(txn)

def _async_abort(self, txn: Transaction):
task = asyncio.create_task(txn.abort())
task.add_done_callback(lambda task: self._abort_tasks.remove(task))
self._abort_tasks.append(task)
38 changes: 34 additions & 4 deletions nucliadb/nucliadb/tests/unit/common/maindb/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import asyncio
from unittest import mock

import pytest
Expand All @@ -34,14 +35,43 @@ async def commit(self, **kw):


@pytest.fixture(scope="function")
def driver() -> Driver: # type: ignore
def txn():
return TransactionTest()


@pytest.fixture(scope="function")
def driver(txn) -> Driver: # type: ignore
driver = Driver()
with mock.patch.object(
driver, "begin", new=mock.AsyncMock(return_value=TransactionTest())
):
with mock.patch.object(driver, "begin", new=mock.AsyncMock(return_value=txn)):
yield driver


@pytest.mark.asyncio
async def test_driver_async_abort(driver, txn):
async with driver.transaction(wait_for_abort=False):
pass

assert len(driver._abort_tasks) == 1
await asyncio.sleep(0.1)

txn.abort.assert_called_once()
assert len(driver._abort_tasks) == 0


@pytest.mark.asyncio
async def test_driver_finalize_aborts_transactions(driver, txn):
async with driver.transaction(wait_for_abort=False):
pass

assert len(driver._abort_tasks) == 1

await driver.finalize()

txn.abort.assert_called_once()

assert len(driver._abort_tasks) == 0


@pytest.mark.asyncio
async def test_transaction_handles_txn_begin_errors(driver):
driver.begin.side_effect = ValueError()
Expand Down

1 comment on commit 519d796

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 519d796 Previous: 84cebd9 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12962.239042936462 iter/sec (stddev: 9.640725398110143e-8) 12982.011001408568 iter/sec (stddev: 3.088010895258411e-7) 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.