Skip to content

Commit

Permalink
Improved support for cursor factories with Psycopg 2 - closes #89
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Aug 25, 2024
1 parent a8a1bf1 commit d0a3c5a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.3.3 (unreleased)

- Improved support for cursor factories with Psycopg 2

## 0.3.2 (2024-07-17)

- Fixed error with asyncpg and pgvector < 0.7
Expand Down
4 changes: 3 additions & 1 deletion pgvector/psycopg2/register.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import psycopg2
from psycopg2.extensions import cursor
from .halfvec import register_halfvec_info
from .sparsevec import register_sparsevec_info
from .vector import register_vector_info


def register_vector(conn_or_curs=None):
cur = conn_or_curs.cursor() if hasattr(conn_or_curs, 'cursor') else conn_or_curs
conn = conn_or_curs if hasattr(conn_or_curs, 'cursor') else conn_or_curs.connection
cur = conn.cursor(cursor_factory=cursor)

# use to_regtype to get first matching type in search path
cur.execute("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('halfvec'), to_regtype('sparsevec'))")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_psycopg2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pgvector.psycopg2 import register_vector, SparseVector
import psycopg2
from psycopg2.extras import DictCursor, NamedTupleCursor
from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor

conn = psycopg2.connect(dbname='pgvector_python_test')
conn.autocommit = True
Expand Down Expand Up @@ -56,14 +56,14 @@ def test_sparsevec(self):
assert res[1][0] is None

def test_cursor_factory(self):
for cursor_factory in [DictCursor, NamedTupleCursor]:
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(dbname='pgvector_python_test')
cur = conn.cursor(cursor_factory=cursor_factory)
register_vector(cur)
conn.close()

def test_cursor_factory_connection(self):
for cursor_factory in [DictCursor, NamedTupleCursor]:
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(dbname='pgvector_python_test', cursor_factory=cursor_factory)
register_vector(conn)
conn.close()

0 comments on commit d0a3c5a

Please sign in to comment.