Skip to content

Commit

Permalink
Always use context manager for transactions (#2290)
Browse files (browse the repository at this point in the history)
* Always use context manager for transactions

* Remove wait_for_abort

* Remove transaction.begin(), only context manager

* Better pool timeout handling

* Add PG pool metrics (#2295)
  • Loading branch information
javitonino authored Jul 9, 2024
1 parent 538b47a commit 85e9504
Show file tree
Hide file tree
Showing 33 changed files with 751 additions and 1,119 deletions.
8 changes: 4 additions & 4 deletions nucliadb/src/nucliadb/common/datamanagers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ async def get_kv_pb(


@contextlib.asynccontextmanager
async def with_rw_transaction(wait_for_abort: bool = True):
async def with_rw_transaction():
driver = get_driver()
async with driver.transaction(read_only=False, wait_for_abort=wait_for_abort) as txn:
async with driver.transaction(read_only=False) as txn:
yield txn


Expand All @@ -51,7 +51,7 @@ async def with_rw_transaction(wait_for_abort: bool = True):


@contextlib.asynccontextmanager
async def with_ro_transaction(wait_for_abort: bool = True):
async def with_ro_transaction():
driver = get_driver()
async with driver.transaction(read_only=True, wait_for_abort=wait_for_abort) as ro_txn:
async with driver.transaction(read_only=True) as ro_txn:
yield ro_txn
34 changes: 2 additions & 32 deletions nucliadb/src/nucliadb/common/maindb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,36 +74,6 @@ async def finalize(self):
except Exception:
pass

async def begin(self, read_only: bool = False) -> Transaction:
raise NotImplementedError()

@asynccontextmanager
async def transaction(
self, wait_for_abort: bool = True, read_only: bool = False
) -> AsyncGenerator[Transaction, None]:
"""
Use to make sure transaction is always aborted.
:param wait_for_abort: If True, wait for abort to finish before returning.
If False, abort is done in background (unless there
is an error)
"""
txn: Optional[Transaction] = None
error: bool = False
try:
txn = await self.begin(read_only=read_only)
yield txn
except Exception:
error = True
raise
finally:
if txn is not None and txn.open:
if error or wait_for_abort:
await txn.abort()
else:
self._async_abort(txn)

def _async_abort(self, txn: Transaction):
task = asyncio.create_task(txn.abort())
task.add_done_callback(lambda task: self._abort_tasks.remove(task))
self._abort_tasks.append(task)
async def transaction(self, read_only: bool = False) -> AsyncGenerator[Transaction, None]:
yield Transaction()
8 changes: 5 additions & 3 deletions nucliadb/src/nucliadb/common/maindb/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#
import glob
import os
from typing import Optional
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from nucliadb.common.maindb.driver import (
DEFAULT_BATCH_SCAN_LIMIT,
Expand Down Expand Up @@ -212,7 +213,8 @@ async def initialize(self):
async def finalize(self):
pass

async def begin(self, read_only: bool = False) -> LocalTransaction:
@asynccontextmanager
async def transaction(self, read_only: bool = False) -> AsyncGenerator[Transaction, None]:
if self.url is None:
raise AttributeError("Invalid url")
return LocalTransaction(self.url, self)
yield LocalTransaction(self.url, self)
88 changes: 74 additions & 14 deletions nucliadb/src/nucliadb/common/maindb/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncGenerator, Optional, Union
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Optional

import backoff
import psycopg
Expand All @@ -42,12 +44,33 @@
);
"""

logger = logging.getLogger(__name__)

# Request Metrics
pg_observer = metrics.Observer(
"pg_client",
labels={"type": ""},
)

# Pool metrics
POOL_METRICS_COUNTERS = {
# Requests for a connection to the pool
"requests_num": metrics.Counter("pg_client_pool_requests_total"),
"requests_queued": metrics.Counter("pg_client_pool_requests_queued_total"),
"requests_errors": metrics.Counter("pg_client_pool_requests_errors_total"),
"requests_wait_ms": metrics.Counter("pg_client_pool_requests_queued_seconds_total"),
"usage_ms": metrics.Counter("pg_client_pool_requests_usage_seconds_total"),
# Pool opening a connection to PG
"connections_num": metrics.Counter("pg_client_pool_connections_total"),
"connections_ms": metrics.Counter("pg_client_pool_connections_seconds_total"),
}
POOL_METRICS_GAUGES = {
"pool_size": metrics.Gauge("pg_client_pool_connections_open"),
# The two below most likely change too rapidly to be useful in a metric
"pool_available": metrics.Gauge("pg_client_pool_connections_available"),
"requests_waiting": metrics.Gauge("pg_client_pool_requests_waiting"),
}


class DataLayer:
def __init__(self, connection: psycopg.AsyncConnection):
Expand Down Expand Up @@ -137,7 +160,6 @@ async def abort(self):
await self.connection.rollback()
finally:
self.open = False
await self.driver.pool.putconn(self.connection)

async def commit(self):
with pg_observer({"type": "commit"}):
Expand All @@ -148,7 +170,6 @@ async def commit(self):
raise
finally:
self.open = False
await self.driver.pool.putconn(self.connection)

async def batch_get(self, keys: list[str], for_update: bool = True):
return await self.data_layer.batch_get(keys, select_for_update=for_update)
Expand Down Expand Up @@ -269,26 +290,65 @@ async def initialize(self):
await self.pool.open()

# check if table exists
async with self.pool.connection() as conn:
async with self._get_connection() as conn:
await conn.execute(CREATE_TABLE)

self.initialized = True
self.metrics_task = asyncio.create_task(self._report_metrics_task())

async def finalize(self):
async with self._lock:
await self.pool.close()
self.initialized = False

@backoff.on_exception(backoff.expo, RETRIABLE_EXCEPTIONS, jitter=backoff.random_jitter, max_tries=3)
async def begin(self, read_only: bool = False) -> Union[PGTransaction, ReadOnlyPGTransaction]:
self.metrics_task.cancel()

async def _report_metrics_task(self):
while True:
self._report_metrics()
await asyncio.sleep(60)

def _report_metrics(self):
if not self.initialized:
return

metrics = self.pool.pop_stats()
for key, metric in POOL_METRICS_COUNTERS.items():
value = metrics.get(key, 0)
if key.endswith("_ms"):
value /= 1000
metric.counter.inc(value)

for key, metric in POOL_METRICS_GAUGES.items():
value = metrics.get(key, 0)
metric.set(value)

@asynccontextmanager
async def transaction(self, read_only: bool = False) -> AsyncGenerator[Transaction, None]:
if read_only:
return ReadOnlyPGTransaction(self)
yield ReadOnlyPGTransaction(self)
else:
timeout = self.acquire_timeout_ms / 1000
conn = await self.pool.getconn(timeout=timeout)
with pg_observer({"type": "begin"}):
return PGTransaction(self, conn)
async with self._get_connection() as conn:
yield PGTransaction(self, conn)

def _get_connection(self) -> InstrumentedAcquireContext:
@asynccontextmanager
async def _get_connection(self) -> AsyncGenerator[psycopg.AsyncConnection, None]:
timeout = self.acquire_timeout_ms / 1000
return InstrumentedAcquireContext(self.pool.connection(timeout=timeout))
# Manual retry loop since backoff.on_exception does not play well with async context managers
retries = 0
while True:
with pg_observer({"type": "acquire"}):
try:
async with self.pool.connection(timeout=timeout) as conn:
yield conn
return
except psycopg_pool.PoolTimeout as e:
logger.warning(
f"Timeout getting connection from the pool, backing off. Retries = {retries}"
)
if retries < 3:
await asyncio.sleep(1)
retries += 1
else:
raise e
except Exception as e:
raise e
Loading

2 comments on commit 85e9504

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 85e9504 Previous: 0d03d9f Ratio
tests/search/unit/search/test_fetch.py::test_highligh_error 2926.53900792355 iter/sec (stddev: 0.0000027998403812798556) 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) 0.97

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 85e9504 Previous: 0d03d9f Ratio
tests/search/unit/search/test_fetch.py::test_highligh_error 3054.209676031114 iter/sec (stddev: 0.000005055790648691623) 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) 0.93

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.