Skip to content

Commit

Permalink
Numpy support is unconditional
Browse files Browse the repository at this point in the history
  • Loading branch information
xhochy committed Aug 28, 2024
1 parent 1f698fd commit eae9cea
Showing 1 changed file with 31 additions and 79 deletions.
110 changes: 31 additions & 79 deletions turbodbc/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,16 @@
from itertools import islice
from typing import Any, cast

from turbodbc_intern import make_parameter_set, make_row_based_result_set

from .exceptions import Error, InterfaceError, translate_exceptions

_NO_NUMPY_SUPPORT_MSG = (
"This installation of turbodbc does not support NumPy extensions. "
"Please install the `numpy` package. If you have built turbodbc from source, "
"you may also need to reinstall turbodbc to compile the extensions."
)
_NO_ARROW_SUPPORT_MSG = (
"This installation of turbodbc does not support Apache Arrow extensions. "
"Please install the `pyarrow` package. If you have built turbodbc from source, "
"you may also need to reinstall turbodbc to compile the extensions."
import pyarrow as pa
from turbodbc_intern import (
make_arrow_result_set,
make_numpy_result_set,
make_parameter_set,
make_row_based_result_set,
set_numpy_parameters,
)


def _has_numpy_support() -> bool:
try:
import turbodbc_numpy_support # noqa: F401

return True
except ImportError:
return False


def _has_arrow_support() -> bool:
try:
import turbodbc.arrow_support # noqa: F401

return True
except ImportError:
return False
from .exceptions import InterfaceError, translate_exceptions


def _make_masked_arrays(result_batch):
Expand Down Expand Up @@ -202,36 +180,22 @@ def executemanycolumns(self, sql: str, columns):

self.impl.prepare(sql)

if _has_arrow_support():
import pyarrow as pa

def _num_chunks(c):
return c.num_chunks

if isinstance(columns, pa.Table):
from turbodbc.arrow_support import set_arrow_parameters # type: ignore

for column in columns.itercolumns():
if _num_chunks(column) != 1:
raise NotImplementedError(
"Chunked Arrays are " "not yet supported"
)
def _num_chunks(c):
return c.num_chunks

set_arrow_parameters(self.impl, columns)
return self._execute()
if isinstance(columns, pa.Table):
from turbodbc.arrow_support import set_arrow_parameters # type: ignore

# Workaround to give users a better error message without a need
# to import pyarrow
if columns.__class__.__module__.startswith("pyarrow"):
raise Error(_NO_ARROW_SUPPORT_MSG)
for column in columns.itercolumns():
if _num_chunks(column) != 1:
raise NotImplementedError("Chunked Arrays are " "not yet supported")

if not _has_numpy_support():
raise Error(_NO_NUMPY_SUPPORT_MSG)
set_arrow_parameters(self.impl, columns)
return self._execute()

_assert_numpy_column_preconditions(columns)

from numpy.ma import MaskedArray
from turbodbc_numpy_support import set_numpy_parameters

split_arrays = []
for column in columns:
Expand Down Expand Up @@ -319,10 +283,6 @@ def fetchnumpybatches(self):

def _numpy_batch_generator(self):
self._assert_valid_result_set()
if not _has_numpy_support():
raise Error(_NO_NUMPY_SUPPORT_MSG)

from turbodbc_numpy_support import make_numpy_result_set

numpy_result_set = make_numpy_result_set(self.impl.get_result_set())
first_run = True
Expand Down Expand Up @@ -351,22 +311,18 @@ def fetcharrowbatches(self, strings_as_dictionary=False, adaptive_integers=False
:return: generator of ``pyarrow.Table``
"""
self._assert_valid_result_set()
if _has_arrow_support():
from turbodbc.arrow_support import make_arrow_result_set # type: ignore

rs = make_arrow_result_set(
self.impl.get_result_set(), strings_as_dictionary, adaptive_integers
)
first_run = True
while True:
table = rs.fetch_next_batch()
is_empty_batch = len(table) == 0
if is_empty_batch and not first_run:
return # Let us return a typed result set at least once
first_run = False
yield table
else:
raise Error(_NO_ARROW_SUPPORT_MSG)
rs = make_arrow_result_set(
self.impl.get_result_set(), strings_as_dictionary, adaptive_integers
)
first_run = True
while True:
table = rs.fetch_next_batch()
is_empty_batch = len(table) == 0
if is_empty_batch and not first_run:
return # Let us return a typed result set at least once
first_run = False
yield table

def fetchallarrow(self, strings_as_dictionary=False, adaptive_integers=False):
"""
Expand All @@ -385,14 +341,10 @@ def fetchallarrow(self, strings_as_dictionary=False, adaptive_integers=False):
:return: ``pyarrow.Table``
"""
self._assert_valid_result_set()
if _has_arrow_support():
from turbodbc.arrow_support import make_arrow_result_set # type: ignore

return make_arrow_result_set(
self.impl.get_result_set(), strings_as_dictionary, adaptive_integers
).fetch_all()
else:
raise Error(_NO_ARROW_SUPPORT_MSG)
return make_arrow_result_set(
self.impl.get_result_set(), strings_as_dictionary, adaptive_integers
).fetch_all()

def nextset(self) -> bool:
"""
Expand Down

0 comments on commit eae9cea

Please sign in to comment.