Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support decimal and time data types #133

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ The transaction is created when the first SQL statement is executed.
exits the *with* context and the queries succeed, otherwise
`prestodb.dbapi.Connection.rollback()` will be called.

# Improved Python types

If you enable the flag `experimental_python_types`, the client will convert the results of the query to the
corresponding Python types. For example, if the query returns a `DECIMAL` column, the result will be a `Decimal` object.

The limitations of each Python type are described in the
[Python types documentation](https://docs.python.org/3/library/datatypes.html). If the query returns a value that
cannot be converted to the corresponding Python type because of one of these limitations, the client raises a
`prestodb.exceptions.DataError` exception.

```python
import prestodb
import pytz
from datetime import datetime

conn = prestodb.dbapi.connect(
experimental_python_types=True
...
)

cur = conn.cursor()

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp with time zone"
```

# Running Tests

There is a helper script, `run`, that provides commands to run tests.
Expand Down
69 changes: 64 additions & 5 deletions prestodb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import logging
import os
from typing import Any, Dict, List, Optional, Text, Tuple, Union # NOQA for mypy types
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple, Union
import pytz

import prestodb.redirect
import requests
Expand Down Expand Up @@ -457,10 +461,11 @@ class PrestoResult(object):
https://docs.python.org/3/library/stdtypes.html#generator-types
"""

def __init__(self, query, rows=None):
def __init__(self, query, rows=None, experimental_python_types = False):
self._query = query
self._rows = rows or []
self._rownumber = 0
self._experimental_python_types = experimental_python_types

@property
def rownumber(self):
Expand All @@ -471,15 +476,67 @@ def __iter__(self):
# Initial fetch from the first POST request
for row in self._rows:
self._rownumber += 1
yield row
if not self._experimental_python_types:
yield row
else:
yield self._map_to_python_types(row, self._query.columns)
self._rows = None

# Subsequent fetches from GET requests until next_uri is empty.
while not self._query.is_finished():
rows = self._query.fetch()
for row in rows:
self._rownumber += 1
yield row
if not self._experimental_python_types:
yield row
else:
yield self._map_to_python_types(row, self._query.columns)

@classmethod
def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
    """Convert one raw result value into the matching Python object.

    ``item`` is a ``(value, column)`` pair. ``column`` is a column
    description whose ``["typeSignature"]["rawType"]`` entry names the SQL
    type of ``value``.

    Conversions, selected by the raw type name:

    * list values -> list, each element converted recursively using the
      first type argument of the column's type signature
    * ``decimal`` -> ``decimal.Decimal``
    * ``date`` -> ``datetime.date``
    * ``timestamp with time zone`` -> timezone-aware ``datetime.datetime``
    * other ``timestamp`` types -> naive ``datetime.datetime``
    * ``time with time zone`` -> ``datetime.time`` with a fixed UTC offset
    * other ``time`` types -> naive ``datetime.time``

    ``None`` is returned as ``None``; values of any other type are returned
    unchanged.

    Raises:
        prestodb.exceptions.DataError: if ``value`` cannot be parsed into
            the Python type associated with its raw type.
    """
    (value, data_type) = item

    if value is None:
        return None

    raw_type = data_type["typeSignature"]["rawType"]

    try:
        if isinstance(value, list):
            # Kept in a separate name so ``raw_type`` stays the type-name
            # string used by the error message below.
            element_type = {
                "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
            }
            return [
                cls._map_to_python_type((array_item, element_type))
                for array_item in value
            ]
        elif "decimal" in raw_type:
            return Decimal(value)
        elif raw_type == "date":
            return datetime.strptime(value, "%Y-%m-%d").date()
        elif raw_type == "timestamp with time zone":
            dt, tz = value.rsplit(' ', 1)
            # Numeric offsets ("+05:30") are handled by %z directly.
            if tz.startswith(('+', '-')):
                return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z")
            # NOTE(review): attaching a pytz zone with replace() is known to
            # select the zone's first (usually LMT) offset; pytz recommends
            # tz.localize(). Left as-is because changing it alters the
            # values callers receive -- confirm before switching.
            return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz))
        elif "timestamp" in raw_type:
            return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
        elif "time with time zone" in raw_type:
            matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
            if matches is None:
                # Raise instead of assert: asserts vanish under ``-O`` and
                # would bypass the DataError conversion below.
                raise ValueError("value does not end with a +HH:MM/-HH:MM offset")
            offset = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
            if matches.group(2) == '-':
                offset = -offset
            return datetime.strptime(matches.group(1), "%H:%M:%S.%f").time().replace(tzinfo=timezone(offset))
        elif "time" in raw_type:
            return datetime.strptime(value, "%H:%M:%S.%f").time()
        else:
            return value
    except (ValueError, ArithmeticError, pytz.UnknownTimeZoneError) as e:
        # ArithmeticError covers decimal.InvalidOperation from Decimal();
        # UnknownTimeZoneError covers unrecognised zone names.
        error_str = f"Could not convert '{value}' into the associated python type for '{raw_type}'"
        raise prestodb.exceptions.DataError(error_str) from e

def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
    """Convert every value of ``row`` using its column description.

    Values and ``columns`` are paired positionally; pairing stops at the
    shorter of the two sequences. Each pair is handed to
    ``_map_to_python_type`` and the converted values are returned as a new
    list in the same order.
    """
    return [
        self._map_to_python_type(value_and_column)
        for value_and_column in zip(row, columns)
    ]


class PrestoQuery(object):
Expand All @@ -489,6 +546,7 @@ def __init__(
self,
request, # type: PrestoRequest
sql, # type: Text
experimental_python_types = False,
):
# type: (...) -> None
self.auth_req = request.auth_req # type: Optional[Request]
Expand All @@ -502,7 +560,8 @@ def __init__(
self._cancelled = False
self._request = request
self._sql = sql
self._result = PrestoResult(self)
self._result = PrestoResult(self, experimental_python_types=experimental_python_types)
self._experimental_python_types = experimental_python_types

@property
def columns(self):
Expand Down Expand Up @@ -543,7 +602,7 @@ def execute(self):
self._warnings = getattr(status, "warnings", [])
if status.next_uri is None:
self._finished = True
self._result = PrestoResult(self, status.rows)
self._result = PrestoResult(self, status.rows, self._experimental_python_types)
while (
not self._finished and not self._cancelled
):
Expand Down
23 changes: 17 additions & 6 deletions prestodb/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import binascii
import datetime
from decimal import Decimal
import logging
import uuid
from typing import Any, List, Optional # NOQA for mypy types
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
isolation_level=IsolationLevel.AUTOCOMMIT,
experimental_python_types=False,
**kwargs,
):
self.host = host
Expand Down Expand Up @@ -107,6 +109,8 @@ def __init__(
self._request = None
self._transaction = None

self.experimental_python_types = experimental_python_types

@property
def isolation_level(self):
return self._isolation_level
Expand Down Expand Up @@ -171,7 +175,7 @@ def cursor(self):
request = self.transaction._request
else:
request = self._create_request()
return Cursor(self, request)
return Cursor(self, request, self.experimental_python_types)


class Cursor(object):
Expand All @@ -182,7 +186,7 @@ class Cursor(object):

"""

def __init__(self, connection, request):
def __init__(self, connection, request, experimental_python_types = False):
if not isinstance(connection, Connection):
raise ValueError(
"connection must be a Connection object: {}".format(type(connection))
Expand All @@ -193,6 +197,7 @@ def __init__(self, connection, request):
self.arraysize = 1
self._iterator = None
self._query = None
self._experimental_python_types = experimental_python_types

def __iter__(self):
return self._iterator
Expand Down Expand Up @@ -263,7 +268,7 @@ def execute(self, operation, params=None):
# TODO: Consider caching prepared statements if requested by caller
self._deallocate_prepared_statement(statement_name)
else:
self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
self._query = prestodb.client.PrestoQuery(self._request, sql=operation, experimental_python_types=self._experimental_python_types)
self._iterator = iter(self._query.execute())
return self

Expand All @@ -272,7 +277,7 @@ def _generate_unique_statement_name(self):

def _prepare_statement(self, statement: str, name: str) -> None:
sql = f"PREPARE {name} FROM {statement}"
query = prestodb.client.PrestoQuery(self._request, sql=sql)
query = prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types)
query.execute()

def _execute_prepared_statement(self, statement_name, params):
Expand All @@ -282,11 +287,11 @@ def _execute_prepared_statement(self, statement_name, params):
+ " USING "
+ ",".join(map(self._format_prepared_param, params))
)
return prestodb.client.PrestoQuery(self._request, sql=sql)
return prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types)

def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = "DEALLOCATE PREPARE " + statement_name
query = prestodb.client.PrestoQuery(self._request, sql=sql)
query = prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types)
query.execute()

def _format_prepared_param(self, param):
Expand Down Expand Up @@ -323,6 +328,9 @@ def _format_prepared_param(self, param):

if isinstance(param, datetime.datetime) and param.tzinfo is not None:
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
# named timezones
if hasattr(param.tzinfo, 'zone'):
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone)
# offset-based timezones
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))

Expand Down Expand Up @@ -356,6 +364,9 @@ def _format_prepared_param(self, param):
if isinstance(param, uuid.UUID):
return "UUID '%s'" % param

if isinstance(param, Decimal):
return "DECIMAL '%s'" % param

if isinstance(param, (bytes, bytearray)):
return "X'%s'" % binascii.hexlify(param).decode("utf-8")

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
)

require = ["pytz"]

kerberos_require = ["requests_kerberos"]

google_auth_require = ["google_auth"]

all_require = [kerberos_require, google_auth_require]
all_require = [require, kerberos_require, google_auth_require]

tests_require = all_require + ["httpretty", "pytest", "pytest-runner"]

Expand Down